diff --git a/client b/client index 33b6868..3a1243a 100644 Binary files a/client and b/client differ diff --git a/client.exe b/client.exe index cd76507..4621545 100644 Binary files a/client.exe and b/client.exe differ diff --git a/proxy/client.go b/proxy/client.go index ce95da7..7e98bce 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -10,30 +10,35 @@ type UClientProxy struct { clientAddr string } -func (client *UClientProxy) NewClientConnection() *PacketConnect { +func (client *UClientProxy) NewClientConnection(id uint32) *PacketConnect { conn, err := net.Dial("tcp", client.clientAddr) if err != nil { zlog.Infof("connect local client error: %s %s", err.Error(), client.clientAddr) return nil } - return NewPacketConnect(conn, "local") + return NewPacketConnect(conn, "local", id) } func (client *UClientProxy) NewTcpConnection(id uint32) *PacketConnect { conn, err := net.Dial("tcp", client.serverAddr) if err != nil { - zlog.Infof("connect server error: %s", err.Error()) + zlog.Fatalf("connect server error: %d", id) + return nil } - remote := NewPacketConnect(conn, "remote") + remote := NewPacketConnect(conn, "remote", id) remote.WriteHeader(id) if id == RemoteID { return remote } - local := client.NewClientConnection() + local := client.NewClientConnection(id) if local == nil { err = remote.Close() return nil } - go local.Forward(remote, 0) + zlog.Infof("Forward %s => %s", remote.LocalAddr(), remote.Name()) + go func() { + local.Forward(remote, 0) + //remote.SetReadDeadline(time.Now()) + }() go remote.Forward(local, 0) return nil } @@ -48,7 +53,7 @@ func (client *UClientProxy) ServeTCP(serverAddr string, clientAddr string) error zlog.Error("rcv package error:", data) continue } - id := packetEndian.Uint32(data[4:]) + id := packetEndian.Uint32(data[4:8]) go client.NewTcpConnection(id) } } diff --git a/proxy/packet_connect.go b/proxy/packet_connect.go index 532f91d..5eecc5a 100644 --- a/proxy/packet_connect.go +++ b/proxy/packet_connect.go @@ -1,34 +1,12 @@ package proxy import ( - "encoding/binary" "net" "strconv" "time" + "zproxy/zlog" ) -var ( - packetEndian = binary.LittleEndian - RemoteID uint32 = 1 - sid = RemoteID - maxSid uint32 = 1024 * 1024 - HeaderSize = 16 - Magic = "anki" - InitBuf = func() []byte { - buf := make([]byte, HeaderSize) - copy(buf[:len(Magic)], Magic) - return buf - }() -) - -func IncID() uint32 { - if sid > maxSid { - sid = RemoteID - } - sid++ - return sid -} - type PacketConnect struct { net.Conn header []byte @@ -38,16 +16,16 @@ type PacketConnect struct { name string } -func NewPacketConnect(conn net.Conn, name string) *PacketConnect { +func NewPacketConnect(conn net.Conn, name string, id uint32) *PacketConnect { pc := &PacketConnect{Conn: conn, header: make([]byte, HeaderSize)} - pc.SetRemote(name, 0) + pc.SetName(name, id) copy(pc.header, InitBuf) return pc } func (pc *PacketConnect) Name() string { return pc.name } -func (pc *PacketConnect) SetRemote(name string, id uint32) { +func (pc *PacketConnect) SetName(name string, id uint32) { pc.id = id pc.name = "(" + name + ")" + strconv.Itoa(int(id)) + "::" + pc.RemoteAddr().String() } @@ -62,6 +40,12 @@ func (pc *PacketConnect) Read(data []byte) int { pc.readErr = err return n } +func (pc *PacketConnect) SetReadDeadline(t time.Time) { + err := pc.Conn.SetReadDeadline(t) + if err != nil { + pc.readErr = err + } +} func (pc *PacketConnect) ReadHeader() []byte { pc.Read(pc.header) return pc.header @@ -76,12 +60,17 @@ func (pc *PacketConnect) Forward(dst *PacketConnect, rn int) { } defer func() { _ = pc.Close() + if Level > 0 { + zlog.Infof("forward %d %s %s => %s %s || %s", rn, pc.LocalAddr(), pc.Name(), dst.Name(), pc.readErr, dst.writeErr) + } }() - data := make([]byte, 1024) + data := make([]byte, MTU) for pc.readErr == nil && dst.writeErr == nil { - _ = pc.SetReadDeadline(time.Now().Add(time.Second)) n := pc.Read(data) + n = dst.Write(data[:n]) rn += n - dst.Write(data[:n]) + if Level > 1 { + zlog.Infof("rw %d %s => %s %s || %s", n, pc.Name(), dst.Name(), pc.readErr, dst.writeErr) + } } } diff --git a/proxy/server.go b/proxy/server.go index d10c9d5..f94f1fb 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -23,13 +23,13 @@ func (server *UServerProxy) NewTcpConnection(client *PacketConnect) { data := client.ReadHeader() if string(data[:4]) == "anki" { id := packetEndian.Uint32(data[4:8]) - client.SetRemote("remote", id) + client.SetName("remote", id) if id == RemoteID { if server.RemoteConn != nil && server.RemoteConn != client { _ = server.RemoteConn.Close() } + zlog.Infof("rcv remote %s => %s", client.Name(), client.LocalAddr()) server.RemoteConn = client - zlog.Infof("rcv %d %s", id, client.Name()) } else { server.RemoteMap[id] = client server.MessageQueue <- id @@ -37,12 +37,20 @@ func (server *UServerProxy) NewTcpConnection(client *PacketConnect) { return } if server.RemoteConn == nil { + _ = client.Close() return } id := IncID() server.RemoteConn.WriteHeader(id) + if server.RemoteConn.writeErr != nil { + zlog.Infof("remote disconnect %s", server.RemoteConn.RemoteAddr()) + _ = server.RemoteConn.Close() + server.RemoteConn = nil + return + } server.ClientMap[id] = client - zlog.Infof("%d %s", id, data) + client.SetName("local", id) + zlog.Infof("%d %d %s %s", len(server.RemoteMap), len(server.ClientMap), data, client.Name()) } func (server *UServerProxy) MessageLoop() { for { @@ -51,17 +59,15 @@ func (server *UServerProxy) MessageLoop() { remote := server.RemoteMap[mid] client := server.ClientMap[mid] if client == nil || remote == nil { - server.RemoteMap[mid] = nil - server.ClientMap[mid] = nil continue } go func() { client.Forward(remote, HeaderSize) - server.ClientMap[mid] = nil + delete(server.ClientMap, mid) }() go func() { remote.Forward(client, 0) - server.RemoteMap[mid] = nil + delete(server.RemoteMap, mid) }() } } @@ -74,10 +80,7 @@ func (server *UServerProxy) ServeTCP(listenAddr string) error { } zlog.Infof("listen %s", ln.Addr()) defer func(ln net.Listener) { - err := ln.Close() - if err != nil { - zlog.Error("close client error: ", err.Error()) - } + _ = ln.Close() }(ln) go server.MessageLoop() for { @@ -85,6 +88,6 @@ func (server *UServerProxy) ServeTCP(listenAddr string) error { if err != nil { return err } - server.NewTcpConnection(NewPacketConnect(conn, "client")) + go server.NewTcpConnection(NewPacketConnect(conn, "client", 0)) } } diff --git a/proxy/type.go b/proxy/type.go new file mode 100644 index 0000000..8c5e7a9 --- /dev/null +++ b/proxy/type.go @@ -0,0 +1,27 @@ +package proxy + +import "encoding/binary" + +var ( + Level = 1 + packetEndian = binary.LittleEndian + RemoteID uint32 = 1 + sid = RemoteID + maxSid uint32 = 1024 * 1024 + MTU = 65495 //1024 * 64 + HeaderSize = 16 + Magic = "anki" + InitBuf = func() []byte { + buf := make([]byte, HeaderSize) + copy(buf[:len(Magic)], Magic) + return buf + }() +) + +func IncID() uint32 { + if sid > maxSid { + sid = RemoteID + } + sid++ + return sid +} diff --git a/run/client/client.go b/run/client/client.go index 61c9194..d6f78ce 100644 --- a/run/client/client.go +++ b/run/client/client.go @@ -29,6 +29,7 @@ func main() { func init() { _addr := flag.String("addr", ServerAddr, "server addr") _port := flag.Int("port", ServerPort, "listen port") + _level := flag.Int("level", proxy.Level, "debug level") remoteAddr := flag.String("remote_addr", RemoteAddr, "server addr") remotePort := flag.Int("remote_port", RemotePort, "listen port") @@ -38,4 +39,5 @@ func init() { RemoteAddr = *remoteAddr RemotePort = *remotePort + proxy.Level = *_level } diff --git a/run/server/server.go b/run/server/server.go index 406ff3c..b43b989 100644 --- a/run/server/server.go +++ b/run/server/server.go @@ -24,7 +24,9 @@ func main() { func init() { _addr := flag.String("addr", ServerAddr, "server addr") _port := flag.Int("port", ServerPort, "listen port") + _level := flag.Int("level", proxy.Level, "debug level") flag.Parse() ServerAddr = *_addr ServerPort = *_port + proxy.Level = *_level } diff --git a/server b/server index ee65086..7d31f33 100644 Binary files a/server and b/server differ diff --git a/server.exe b/server.exe index 9aaf642..7569466 100644 Binary files a/server.exe and b/server.exe differ diff --git a/zlog/zlog.go b/zlog/zlog.go index f121cc4..3e39747 100644 --- a/zlog/zlog.go +++ b/zlog/zlog.go @@ -14,7 +14,7 @@ var ( ) func init() { - currentLevel = zap.DebugLevel + currentLevel = zap.FatalLevel cfg = zap.NewDevelopmentConfig() cfg.Development = true rebuildLoggerFromCfg()