From 9e35441a46e7c32967ad085399b2da437b5d4906 Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 4 Sep 2020 23:54:13 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=E6=B3=A8=E5=86=8C=E5=92=8C=E5=BF=83?= =?UTF-8?q?=E8=B7=B3=E8=BD=AC=E7=A7=BB=E5=88=B0=20=E9=80=9A=E9=81=93?= =?UTF-8?q?=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dtu/channel.go | 193 +++++++++++++++++++++++++++++++++++++++++++++---- dtu/link.go | 104 -------------------------- 2 files changed, 179 insertions(+), 118 deletions(-) diff --git a/dtu/channel.go b/dtu/channel.go index 6aa3b5f..9cd932a 100644 --- a/dtu/channel.go +++ b/dtu/channel.go @@ -1,12 +1,15 @@ package dtu import ( + "bytes" + "encoding/hex" "errors" "fmt" "github.com/zgwit/dtu-admin/db" "github.com/zgwit/dtu-admin/model" "log" "net" + "regexp" "sync" "time" ) @@ -85,6 +88,29 @@ func (c *baseChannel) storeError(err error) error { return err } + +func (c *baseChannel) checkRegister(buf []byte) (string, error) { + n := len(buf) + if n < c.RegisterMin { + return "", fmt.Errorf("register package is too short %d %s", n, string(buf[:n])) + } + serial := string(buf[:n]) + if c.RegisterMax > 0 && c.RegisterMax >= c.RegisterMin && n > c.RegisterMax { + serial = string(buf[:c.RegisterMax]) + } + + // 正则表达式判断合法性 + if c.RegisterRegex != "" { + reg := regexp.MustCompile(`^` + c.RegisterRegex + `$`) + match := reg.MatchString(serial) + if !match { + return "", fmt.Errorf("register package format error %s", serial) + } + } + + return serial, nil +} + type Client struct { baseChannel @@ -145,6 +171,24 @@ func (c *Client) receive(conn net.Conn) { log.Println(e) break } + + //过滤心跳包 + if c.HeartBeatEnable && time.Now().Sub(c.client.lastTime) > time.Second*time.Duration(c.HeartBeatInterval) { + var b []byte + if c.HeartBeatIsHex { + var e error + b, e = hex.DecodeString(c.HeartBeatContent) + if e != nil { + log.Println(e) + } + } else { + b = []byte(c.HeartBeatContent) + } + if bytes.Compare(b, buf) == 0 { + continue + } + } + c.client.onData(buf[:n]) } @@ -222,13 +266,67 @@ func (c *Server) accept() { func (c *Server) receive(conn net.Conn) { link := newLink(c, conn) + defer link.Close() + + buf := make([]byte, 1024) + + //第一个包作为注册包 + if c.RegisterEnable { + n, e := conn.Read(buf) + if e != nil { + log.Println(e) + return + } + + serial, err := c.baseChannel.checkRegister(buf) + if err != nil { + _, _ = link.Send([]byte(err.Error())) + return + } + + //配置序列号 + link.Serial = serial + + //查找数据库同通道,同序列号链接,更新数据库中 addr online + var lnk model.Link + has, err := db.Engine.Where("channel_id=?", c.Id).And("serial=?", serial).Get(&lnk) + if err != nil { + _, _ = link.Send([]byte("数据库异常")) + log.Println(err) + return + } + if has { + l, _ := c.GetLink(lnk.Id) + if l != nil { + //如果同序号连接还在正常通讯,则关闭当前连接 + if l.conn != nil { + _, _ = link.Send([]byte(fmt.Sprintf("duplicate serial %s", serial))) + return + } + + //复制有用的历史数据 + link.Rx = l.Rx + link.Tx = l.Tx + + //复制watcher + } + + link.Id = lnk.Id + link.Name = lnk.Name + //link.Serial = lnk.Serial + } + - //未开启注册,则直接保存 - if !c.RegisterEnable { - c.StoreLink(link) + //处理剩余内容 + if c.RegisterMax > 0 && n > c.RegisterMax { + link.onData(buf[c.RegisterMax:]) + } } - buf := make([]byte, 1024) + //保存链接 + c.StoreLink(link) + + for link.conn != nil { n, e := conn.Read(buf) if e != nil { @@ -250,7 +348,11 @@ func (c *Server) receive(conn net.Conn) { } else { //有序号,等待5分钟,之后设为离线 time.AfterFunc(time.Minute*5, func() { - c.clients.Delete(link.Id) + lnk, _ := c.GetLink(link.Id) + //判断指针地址也行 + if lnk != nil && lnk.conn == nil { + c.clients.Delete(link.Id) + } }) } } @@ -309,22 +411,85 @@ func (c *PacketServer) receive() { key := addr.String() //找到连接,将消息发送过去 - var client *Link + var link *Link v, ok := c.packetIndexes.Load(key) if ok { - client = v.(*Link) - } else { - client = newPacketLink(c, c.packetConn, addr) + link = v.(*Link) + + //过滤心跳包 + if c.HeartBeatEnable && time.Now().Sub(link.lastTime) > time.Second*time.Duration(c.HeartBeatInterval) { + var b []byte + if c.HeartBeatIsHex { + var e error + b, e = hex.DecodeString(c.HeartBeatContent) + if e != nil { + log.Println(e) + } + } else { + b = []byte(c.HeartBeatContent) + } + if bytes.Compare(b, buf) == 0 { + continue + } + } - //根据ID保存 - if !c.RegisterEnable { - c.StoreLink(client) + } else { + link = newPacketLink(c, c.packetConn, addr) + + //第一个包作为注册包 + if c.RegisterEnable { + serial, err := c.baseChannel.checkRegister(buf) + if err != nil { + _, _ = link.Send([]byte(err.Error())) + return + } + + //配置序列号 + link.Serial = serial + + //查找数据库同通道,同序列号链接,更新数据库中 addr online + var lnk model.Link + has, err := db.Engine.Where("channel_id=?", c.Id).And("serial=?", serial).Get(&lnk) + if err != nil { + _, _ = link.Send([]byte("数据库异常")) + log.Println(err) + return + } + if has { + l, _ := c.GetLink(lnk.Id) + if l != nil { + //如果同序号连接还在正常通讯,则关闭当前连接 + if l.conn != nil { + _, _ = link.Send([]byte(fmt.Sprintf("duplicate serial %s", serial))) + return + } + + //复制有用的历史数据 + link.Rx = l.Rx + link.Tx = l.Tx + + //复制watcher + } + + link.Id = lnk.Id + link.Name = lnk.Name + //link.Serial = lnk.Serial + } + + + //处理剩余内容 + if c.RegisterMax > 0 && n > c.RegisterMax { + link.onData(buf[c.RegisterMax:]) + } } + //保存链接 + c.StoreLink(link) + //根据地址保存,收到UDP包之后,方便索引 - c.packetIndexes.Store(key, client) + c.packetIndexes.Store(key, link) } - client.onData(buf[:n]) + link.onData(buf[:n]) } } diff --git a/dtu/link.go b/dtu/link.go index 2c445c6..d4acf07 100644 --- a/dtu/link.go +++ b/dtu/link.go @@ -1,15 +1,10 @@ package dtu import ( - "bytes" - "encoding/hex" "errors" - "fmt" "github.com/zgwit/dtu-admin/db" "github.com/zgwit/dtu-admin/model" - "log" "net" - "regexp" "time" ) @@ -18,115 +13,18 @@ type Link struct { registerChecked bool - //RemoteAddr net.Addr - Rx int Tx int conn net.Conn lastTime time.Time - - channel Channel -} - -func (l *Link) checkRegister(buf []byte) error { - ch := l.channel.GetChannel() - - n := len(buf) - if n < ch.RegisterMin { - return fmt.Errorf("register package is too short %d %s", n, string(buf[:n])) - } - serial := string(buf[:n]) - if ch.RegisterMax > 0 && ch.RegisterMax >= ch.RegisterMin && n > ch.RegisterMax { - serial = string(buf[:ch.RegisterMax]) - } - - // 正则表达式判断合法性 - if ch.RegisterRegex != "" { - reg := regexp.MustCompile(`^` + ch.RegisterRegex + `$`) - match := reg.MatchString(serial) - if !match { - return fmt.Errorf("register package format error %s", serial) - } - } - - //配置序列号 - l.Serial = serial - - //查找数据库同通道,同序列号链接,更新数据库中 addr online - var link model.Link - has, err := db.Engine.Where("channel_id=?", ch.Id).And("serial=?", serial).Get(&link) - if err != nil { - return err - } - if has { - lnk, _ := l.channel.GetLink(link.Id) - if lnk != nil { - //如果同序号连接还在正常通讯,则关闭当前连接 - if lnk.conn != nil { - return fmt.Errorf("duplicate serial %s", serial) - } - - //复制有用的历史数据 - l.Rx = lnk.Rx - l.Tx = lnk.Tx - - //复制watcher - } - - l.Id = link.Id - l.Name = link.Name - l.Serial = link.Serial - } - - //保存链接 - l.channel.StoreLink(l) - - //处理剩余内容 - if ch.RegisterMax > 0 && n > ch.RegisterMax { - l.onData(buf[ch.RegisterMax:]) - } - - return nil } func (l *Link) onData(buf []byte) { l.Rx += len(buf) l.lastTime = time.Now() - ch := l.channel.GetChannel() - - //检查注册包(只有服务端是检测) - if !l.registerChecked && ch.RegisterEnable && ch.Role == "server" { - err := l.checkRegister(buf) - if err != nil { - log.Println(err) - _, _ = l.Send([]byte(err.Error())) - _ = l.Close() - return - } - l.registerChecked = true - return - } - - //检查心跳包, 判断上次收发时间,是否已经过去心跳间隔 - if ch.HeartBeatEnable && time.Now().Sub(l.lastTime) > time.Second*time.Duration(ch.HeartBeatInterval) { - var b []byte - if ch.HeartBeatIsHex { - var e error - b, e = hex.DecodeString(ch.HeartBeatContent) - if e != nil { - log.Println(e) - } - } else { - b = []byte(ch.HeartBeatContent) - } - if bytes.Compare(b, buf) == 0 { - return - } - } - //TODO 内容转发,暂时直接回复 _, _ = l.Send(buf) } @@ -170,7 +68,6 @@ func newLink(ch Channel, conn net.Conn) *Link { Online: true, OnlineAt: time.Now(), }, - channel: ch, conn: conn, } } @@ -187,7 +84,6 @@ func newPacketLink(ch Channel, conn net.PacketConn, addr net.Addr) *Link { Online: true, OnlineAt: time.Now(), }, - channel: ch, conn: &PackConn{ PacketConn: conn, addr: addr, -- GitLab