zproxy/proxy/server.go
2024-02-16 00:07:57 +08:00

96 lines
2.2 KiB
Go

package proxy
import (
	"net"
	"sync"
	"time"

	"zproxy/zlog"
)
// UServerProxy is the server side of the proxy. A connection whose header is
// "anki" followed by RemoteID becomes the single control connection
// (RemoteConn). Any other accepted connection is a client: it gets a fresh id,
// which is announced on the control connection. A later "anki"+id connection
// (a remote data connection) is then paired with that client, and traffic is
// forwarded between the two.
//
// The maps and RemoteConn are touched by the per-connection goroutines that
// ServeTCP starts and by the forwarding goroutines that MessageLoop starts.
// They must only be accessed while holding mu.
type UServerProxy struct {
	RemoteMap    map[uint32]*PacketConnect // remote data connections keyed by id
	LocalMap     map[uint32]*PacketConnect // client connections keyed by the id announced for them
	RemoteConn   *PacketConnect            // control connection; nil until one connects
	MessageQueue chan uint32               // ids of remote data connections waiting to be paired

	mu sync.Mutex // guards RemoteMap, LocalMap and RemoteConn
}

// NewServerProxy returns a UServerProxy with empty connection maps and a
// MessageQueue buffered for 32 pending ids.
func NewServerProxy() *UServerProxy {
	return &UServerProxy{
		RemoteMap:    make(map[uint32]*PacketConnect),
		LocalMap:     make(map[uint32]*PacketConnect),
		MessageQueue: make(chan uint32, 32),
	}
}

// NewTcpConnection reads the header of a freshly accepted connection and
// registers the connection according to that header:
//
//   - "anki" + RemoteID: the connection becomes the control connection. Any
//     previous, different control connection is closed.
//   - "anki" + other id: a remote data connection. It is stored in RemoteMap
//     and its id is queued for MessageLoop.
//   - anything else: a client connection. It is stored in LocalMap under a
//     new id, and that id is written to the control connection. The client
//     is closed if there is no control connection or the write fails.
func (server *UServerProxy) NewTcpConnection(remote *PacketConnect) {
	data := remote.ReadHeader()
	// NOTE(review): assumes a well-formed header is at least 8 bytes ("anki" +
	// uint32 id). The length check keeps a short read from panicking on the
	// slice expressions; such data is handled as an ordinary client.
	if len(data) >= 8 && string(data[:4]) == "anki" {
		id := packetEndian.Uint32(data[4:8])
		remote.SetName("remote", id)
		if id == RemoteID {
			server.mu.Lock()
			if server.RemoteConn != nil && server.RemoteConn != remote {
				_ = server.RemoteConn.Close()
			}
			server.RemoteConn = remote
			server.mu.Unlock()
			zlog.Infof("rcv remote %s => %s", remote.Name(), remote.LocalAddr())
			return
		}
		server.mu.Lock()
		server.RemoteMap[id] = remote
		server.mu.Unlock()
		// Send outside the lock: MessageLoop takes mu after receiving, so
		// holding it here could deadlock once the queue buffer is full.
		server.MessageQueue <- id
		return
	}

	server.mu.Lock()
	ctrl := server.RemoteConn
	if ctrl == nil {
		server.mu.Unlock()
		_ = remote.Close()
		return
	}
	id := IncID()
	remote.SetName("remote", id)
	// Register the client before announcing its id. The matching data
	// connection can only be created after the announcement, so MessageLoop
	// always finds this entry.
	server.LocalMap[id] = remote
	server.mu.Unlock()

	ctrl.WriteHeader(id)
	if ctrl.writeErr != nil {
		zlog.Infof("remote disconnect %s", ctrl.RemoteAddr())
		_ = ctrl.Close()
		// No data connection will ever arrive for this id; don't leak the client.
		_ = remote.Close()
		server.mu.Lock()
		delete(server.LocalMap, id)
		// A newer control connection may have replaced ctrl meanwhile; keep it.
		if server.RemoteConn == ctrl {
			server.RemoteConn = nil
		}
		server.mu.Unlock()
		return
	}

	server.mu.Lock()
	nRemote, nLocal := len(server.RemoteMap), len(server.LocalMap)
	server.mu.Unlock()
	zlog.Infof("%d %d %s %s", nRemote, nLocal, data, remote.Name())
}

// MessageLoop handles each id received on MessageQueue:
//
//   - It pairs the remote data connection in RemoteMap with the client in
//     LocalMap registered under the same id.
//   - It starts two goroutines that call Forward in each direction.
//   - Each goroutine deletes its map entry once its Forward returns.
//   - A data connection with no matching client is closed and dropped.
//
// MessageLoop returns only if MessageQueue is closed.
func (server *UServerProxy) MessageLoop() {
	for mid := range server.MessageQueue {
		server.mu.Lock()
		remote := server.RemoteMap[mid]
		local := server.LocalMap[mid]
		orphan := remote != nil && local == nil
		if orphan {
			delete(server.RemoteMap, mid)
		}
		server.mu.Unlock()

		if orphan {
			// No client is waiting for this id; without this the data
			// connection would stay open in RemoteMap forever.
			_ = remote.Close()
			continue
		}
		if remote == nil {
			continue
		}

		go func() {
			// NOTE(review): the HeaderSize argument's meaning is defined by
			// PacketConnect.Forward — confirm there.
			local.Forward(remote, HeaderSize)
			server.mu.Lock()
			delete(server.LocalMap, mid)
			server.mu.Unlock()
		}()
		go func() {
			remote.Forward(local, 0)
			// Presumably forces the pending read in the other goroutine's
			// local.Forward to return, so both directions shut down.
			local.SetReadDeadline(time.Now())
			server.mu.Lock()
			delete(server.RemoteMap, mid)
			server.mu.Unlock()
		}()
	}
}
// ServeTCP listens for TCP connections on listenAddr and starts MessageLoop.
// Each accepted connection is handed to NewTcpConnection in its own goroutine.
//
// Per the original author, listenAddr should specify only a port. For
// net.Listen that means the ":<port>" form.
//
// It returns the error from net.Listen, or the first error from Accept.
func (server *UServerProxy) ServeTCP(listenAddr string) error {
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		zlog.Error("listen error:", err.Error())
		// Previously execution continued here and panicked on the nil ln.
		return err
	}
	defer func() {
		_ = ln.Close()
	}()
	zlog.Infof("listen %s", ln.Addr())

	go server.MessageLoop()
	for {
		conn, err := ln.Accept()
		if err != nil {
			return err
		}
		go server.NewTcpConnection(NewPacketConnect(conn, "client", 0))
	}
}