提交 28a4f526 编写于 作者: W wuhanqing

兼容websocket

上级 3c4627d1
......@@ -11,5 +11,6 @@ type IUser interface {
GetConnData() ([]byte, error)
GetConn() net.Conn
IsHttp([]byte) bool
MsgCount() int
Close() error
}
......@@ -11,6 +11,7 @@ import (
)
type DataPack struct {
protocol string
msgFormat string
}
......@@ -33,14 +34,19 @@ func NewDataPack(msgFormat string) *DataPack {
return &DataPack{msgFormat: msgFormat}
}
func (dp *DataPack) SetProtocol(protocol string) {
dp.protocol = protocol
}
func (dp DataPack) Pack(data protoreflect.ProtoMessage) (result []byte, err error) {
if dp.msgFormat == config.MSG_FORMAT_JSON {
result, err = json.Marshal(data)
return
}
if dp.msgFormat == config.MSG_FORMAT_PROTOBUF {
result, err = proto.Marshal(data)
return
}
if dp.protocol == PROTOCOL_WEBSOCKET {
result = WebSocketPackage(result)
}
if dp.msgFormat != config.MSG_FORMAT_JSON && dp.msgFormat != config.MSG_FORMAT_PROTOBUF {
err = fmt.Errorf("msgFormat(%v) can not Pack.", dp.msgFormat)
......@@ -49,6 +55,9 @@ func (dp DataPack) Pack(data protoreflect.ProtoMessage) (result []byte, err erro
}
func (dp DataPack) Unpack(data []byte, result protoreflect.ProtoMessage) error {
if dp.protocol == PROTOCOL_WEBSOCKET {
data = WebSocketParse(data)
}
if dp.msgFormat == config.MSG_FORMAT_JSON {
return json.Unmarshal(data, result)
}
......
......@@ -3,10 +3,19 @@ package model
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
)
const (
WebsocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
PROTOCOL_WEBSOCKET = "websocket"
)
type Request struct {
......@@ -33,7 +42,25 @@ func (r Request) ResponseJson(v interface{}) error {
return err
}
}
data := NewResponseOK(body).Json()
data := NewResponseOK(body).HttpJson()
_, err = r.conn.Write(data)
return err
}
func (r Request) ResponseWebSocket() error {
reqHeader := r.GetHttpRequest().Header
accept, err := r.getWebSocketNonceAccept([]byte(reqHeader["Sec-Websocket-Key"][0])) // Sec-WebSocket-Key to Sec-Websocket-Key
if err != nil {
return err
}
h := []string{
fmt.Sprintf("HTTP/1.1 %d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)),
"Connection: Upgrade",
"Upgrade: websocket",
fmt.Sprintf("Sec-WebSocket-Accept: %s", string(accept)),
}
htext := strings.Join(h, "\n") + "\n\r\n"
data := []byte(htext)
_, err = r.conn.Write(data)
return err
}
......@@ -73,3 +100,34 @@ func (r Request) RemoteAddr() net.Addr {
func (r Request) LocalAddr() net.Addr {
return r.conn.LocalAddr()
}
func (r Request) IsWebSocket() bool {
req := r.httpRequest
if req.Method != "GET" {
return false
}
h := r.httpRequest.Header
connVal := h.Get("Connection")
// if strings.ToUpper(connVal) == "UPGRADE" {
if connVal == "Upgrade" {
if h.Get(connVal) == "websocket" {
return true
}
}
return false
}
// getWebSocketNonceAccept computes the base64-encoded SHA-1 of the concatenation of
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
func (r Request) getWebSocketNonceAccept(nonce []byte) (expected []byte, err error) {
h := sha1.New()
if _, err = h.Write(nonce); err != nil {
return
}
if _, err = h.Write([]byte(WebsocketGUID)); err != nil {
return
}
expected = make([]byte, 28)
base64.StdEncoding.Encode(expected, h.Sum(nil))
return
}
......@@ -2,6 +2,7 @@ package model
import (
"bytes"
"fmt"
"net/http"
"strings"
......@@ -20,7 +21,7 @@ func NewResponseOK(body []byte) *Response {
return NewResponse(http.StatusOK, body)
}
func (r Response) GetHeader() []string {
func (r Response) getHeader() []string {
bodyLen := len(r.body)
return []string{
fmt.Sprintf("HTTP/1.1 %d %s", r.statusCode, http.StatusText(r.statusCode)),
......@@ -40,16 +41,16 @@ func (r Response) GetHeader() []string {
// https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
}
func (r Response) Json() []byte {
header := r.GetHeader()
func (r Response) HttpJson() []byte {
header := r.getHeader()
header = append(header, "Content-Type: application/json; charset=utf-8")
headerText := strings.Join(header, "\n")
return bytes.Join([][]byte{[]byte(headerText), r.body}, []byte("\n\r\n"))
// return []byte(headerText + "\n\r\n" + string(r.body))
}
func (r Response) Html() []byte {
header := r.GetHeader()
func (r Response) HttpHtml() []byte {
header := r.getHeader()
header = append(header, "Content-Type: text/html; charset=utf-8")
headerText := strings.Join(header, "\n")
return []byte(headerText + "\n\r\n" + string(r.body))
......
package model
import (
"bytes"
"encoding/binary"
"fmt"
)
func BytesCombine(pBytes ...[]byte) []byte {
return bytes.Join(pBytes, []byte(""))
}
func WebSocketParse(data []byte) []byte {
en_bytes := []byte("")
cn_bytes := make([]int, 0)
v := data[1] & 0x7f
p := 0
switch v {
case 0x7e:
p = 4
case 0x7f:
p = 10
default:
p = 2
}
mask := data[p : p+4]
data_tmp := data[p+4:]
nv := ""
nv_bytes := []byte("")
nv_len := 0
for k, v := range data_tmp {
nv = string(int(v ^ mask[k%4]))
// nv = fmt.Sprintf("%d", int(v^mask[k%4]))
nv_bytes = []byte(nv)
nv_len = len(nv_bytes)
if nv_len == 1 {
en_bytes = BytesCombine(en_bytes, nv_bytes)
} else {
en_bytes = BytesCombine(en_bytes, []byte("%s"))
cn_bytes = append(cn_bytes, int(v^mask[k%4]))
}
}
//处理中文
cn_str := make([]interface{}, 0)
if len(cn_bytes) > 2 {
clen := len(cn_bytes)
count := int(clen / 3)
for i := 0; i < count; i++ {
mm := i * 3
hh := make([]byte, 3)
h1, _ := IntToBytes(cn_bytes[mm], 1)
h2, _ := IntToBytes(cn_bytes[mm+1], 1)
h3, _ := IntToBytes(cn_bytes[mm+2], 1)
hh[0] = h1[0]
hh[1] = h2[0]
hh[2] = h3[0]
cn_str = append(cn_str, string(hh))
}
// TODO string to []byte
new := string(bytes.Replace(en_bytes, []byte("%s%s%s"), []byte("%s"), -1))
return []byte(fmt.Sprintf(new, cn_str...))
}
return en_bytes
}
func WebSocketPackage(data []byte) []byte {
lenth := len(data)
token := string(0x81)
if lenth < 126 {
token += string(lenth)
}
bb, _ := IntToBytes(0x81, 1)
b0 := bb[0]
b1 := byte(0)
framePos := 0
// fmt.Println("长度", lenth)
switch {
case lenth >= 65536:
writeBuf := make([]byte, 10)
writeBuf[framePos] = b0
writeBuf[framePos+1] = b1 | 127
binary.BigEndian.PutUint64(writeBuf[framePos+2:], uint64(lenth))
return BytesCombine(writeBuf, data)
case lenth > 125:
fmt.Println("》125")
writeBuf := make([]byte, 4)
writeBuf[framePos] = b0
writeBuf[framePos+1] = b1 | 126
binary.BigEndian.PutUint16(writeBuf[framePos+2:], uint16(lenth))
fmt.Println(writeBuf)
return BytesCombine(writeBuf, data)
default:
writeBuf := make([]byte, 2)
writeBuf[framePos] = b0
writeBuf[framePos+1] = b1 | byte(lenth)
return BytesCombine(writeBuf, data)
}
}
// 整形转换成字节
func IntToBytes(n int, b byte) ([]byte, error) {
switch b {
case 1:
tmp := int8(n)
bytesBuffer := bytes.NewBuffer([]byte{})
binary.Write(bytesBuffer, binary.BigEndian, &tmp)
return bytesBuffer.Bytes(), nil
case 2:
tmp := int16(n)
bytesBuffer := bytes.NewBuffer([]byte{})
binary.Write(bytesBuffer, binary.BigEndian, &tmp)
return bytesBuffer.Bytes(), nil
case 3, 4:
tmp := int32(n)
bytesBuffer := bytes.NewBuffer([]byte{})
binary.Write(bytesBuffer, binary.BigEndian, &tmp)
return bytesBuffer.Bytes(), nil
}
return nil, fmt.Errorf("IntToBytes b param is invaild")
}
......@@ -22,21 +22,47 @@ func MainHandler(u contract.IUser) error {
logger.Debug("---handler.MainHandler--error:", err)
return err
}
if u.IsHttp(data) {
// HTTP API 接口业务处理
err = HttpHandler(model.NewRequest(data, u.GetConn()))
dp := model.GetDataPack()
if u.IsHttp(data) && u.MsgCount() == 1 {
// HTTP API 接口业务处理。不支持HTTP 的 Keep-Alive
req := model.NewRequest(data, u.GetConn())
req.ParseHttp()
if req.IsWebSocket() {
// websocket 握手
dp.SetProtocol(model.PROTOCOL_WEBSOCKET)
return req.ResponseWebSocket()
}
err = HttpHandler(req)
if err != nil {
logger.Debug("---handler.MainHandler--HttpHandler--error:", err)
return err
}
// HTTP 一次请求响应后,立即关闭连接。不支持HTTP 的 Keep-Alive
return u.Close()
}
logger.Debug("---------TCP------u.MsgCount=", u.MsgCount())
// TODO IM即时通讯业务处理
dataStr := "Response:" + string(data)
// 接收 FROM_USER 发送给TO_USER
msg := model.Msg{}
err = dp.Unpack(data, &msg)
if err != nil {
return fmt.Errorf("unpack msg fail:%v", err)
}
logger.Debug("-----ReceivedMsg(%v)--msg.ChatType(%d)--", msg.String(), msg.ChatType)
if msg.ChatType == model.Msg_SINGLE {
// 单聊。发送给TO_USER
msg.Content += "--Msg_SINGLE--Response--"
data, err = dp.Pack(&msg)
return u.SendData(data)
}
return u.SendData([]byte(dataStr))
if msg.ChatType == model.Msg_GROUP {
// 群聊。发送给群里的每一个成员。
msg.Content += "--Msg_GROUP--Response--"
data, err = dp.Pack(&msg)
return u.SendData(data)
}
return fmt.Errorf("unknown ChatType")
//提取用户的消息(去除'\n')
// msg := string(data[:n-1])
......@@ -49,5 +75,5 @@ func MainHandler(u contract.IUser) error {
// } else {
// u.server.BroadCast(u, msg)
// }
// return nil
}
......@@ -7,9 +7,8 @@ import (
)
func HttpHandler(req *model.Request) error {
req.ParseHttp()
hreq := req.GetHttpRequest()
body := req.GetHttpBody()
fmt.Printf("\n--method(%s)--Header(%+v)--Body(%s)-\n", hreq.Method, hreq.Header, string(body))
fmt.Printf("\n--method(%s)--proto(%s)--Header(%+v)--Body(%s)-\n", hreq.Method, hreq.Proto, hreq.Header, string(body))
return req.ResponseJson(ResponseOk("hello response Json From struct")) //
}
......@@ -4,20 +4,21 @@ import (
"fmt"
"io"
"net"
"strings"
"github.com/iotames/easyim/contract"
"github.com/iotames/easyim/model"
"github.com/iotames/miniutils"
)
const ERR_CONNECT_LOST = "connect lost"
// User. 一个TCP连接。ClentSocket
type User struct {
Name string
Addr string
IsClosed bool
// data chan []byte
Message chan string
Name string
Addr string
IsClosed bool
message chan []byte
msgCount int
isActive chan bool
conn net.Conn
server contract.IServer
......@@ -31,7 +32,7 @@ func NewUser(conn net.Conn, s contract.IServer) *User {
u := &User{
Name: userAddr,
Addr: userAddr,
Message: make(chan string),
message: make(chan []byte),
isActive: make(chan bool),
conn: conn,
server: s,
......@@ -68,12 +69,23 @@ func (u User) ConnectLost() {
}
}
// IsHttp 判断本次TCP连接,是否为HTTP协议。
// 因本系统不支持HTTP 的 Keep-Alive。故一个HTTP协议的TCP连接,只能发送一次消息。
// 消息发送超过1次,一定不是HTTP协议。
func (u User) IsHttp(data []byte) bool {
method := string(data[:4])
if method == "POST" || method == "GET " {
return true
method := string(data[:7])
httpMethods := []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "UPDATE"}
isHttp := false
for _, m := range httpMethods {
if strings.Contains(method, m) {
isHttp = true
}
}
return false
return isHttp
}
func (u User) MsgCount() int {
return u.msgCount
}
func (u *User) Close() error {
......@@ -85,7 +97,7 @@ func (u *User) Close() error {
// u.ReceiveDataToSend([]byte("连接长时间不活跃,连接已断开")) // 异步操作消息还没发出去,连接就断开了
// u.SendData([]byte("连接长时间不活跃,连接已断开")) // OK 给用户发送消息,同步操作
//销毁用的资源
close(u.Message)
close(u.message)
//关闭连接
u.IsClosed = true
return u.conn.Close()
......@@ -94,8 +106,7 @@ func (u *User) Close() error {
// ReceiveDataToSend 接受消息,并通过channel发送给客户端。异步操作。支持并发。
// 当连接断开时,可能会继续发送异步消息。此时须使用同步锁
func (u *User) ReceiveDataToSend(d []byte) {
// u.data <- d
u.Message <- string(d)
u.message <- d
}
// GetConn 获取TCP连接
......@@ -141,30 +152,20 @@ func (u *User) GetConnData() (data []byte, err error) {
// 如果是命令行输入TCP消息,会包含换行符 \n
data = buf[:n]
u.msgCount += 1
return
}
// 监听当前User channel的 方法,一旦有消息,就直接发送给对端客户端
func (u *User) ListenMessage() {
for {
msg := <-u.Message
u.SendData([]byte(msg))
msg := <-u.message
u.SendData(msg)
}
}
// SendData 发送数据给客户端。同步操作
func (u User) SendData(d []byte) error {
var err error
dp := model.GetDataPack()
msg := model.Msg{}
err = dp.Unpack(d, &msg)
if msg.ChatType == model.Msg_SINGLE {
// 单聊。发送给TO_USER
_, err = u.conn.Write(d)
}
if msg.ChatType == model.Msg_GROUP {
// 群聊。发送给群里的每一个成员。除了自己
}
_, err := u.conn.Write(d)
return err
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册