package auth import ( "bytes" "crypto/md5" "encoding/binary" "encoding/hex" "errors" "fmt" "git.qianqiusoft.com/library/glog" "git.qianqiusoft.com/qianqiusoft/light-apiengine/config" "git.qianqiusoft.com/qianqiusoft/light-apiengine/entitys" sysutils "git.qianqiusoft.com/qianqiusoft/light-apiengine/utils" "net" "strconv" "time" ) const ( __KEY = "Light#dauth-@*I2" CMD_NEW = "new" CMD_REMOVE = "remove" CMD_PINGPONG = "pg" CMD_PINGPONG_RESP = "pg_resp" ) var TCPClient *TcpClient type authPackage struct { Cmd string TokenStr string Content []byte } func (ap *authPackage)toBytes()[]byte{ buf := bytes.NewBuffer([]byte{}) b := []byte(ap.Cmd) buf.Write(uint32ToBytes(len(b))) buf.Write(b) buf.Write(uint32ToBytes(len(ap.Content))) buf.Write(ap.Content) return buf.Bytes() } type TcpClient struct { conn net.Conn // 连接 pchan chan *authPackage // 包chan done chan bool // 是否done exited bool // 退出 verified bool // 验证是否 } // 创建client func NewTcpClient()*TcpClient{ c := &TcpClient{ pchan: make(chan *authPackage, 100), done: make(chan bool), exited: false, verified: false, } return c } // 启动 func(c *TcpClient)Start() { go func() { defer func() { if p := recover(); p != nil { fmt.Println("ecover", p) } if c.conn != nil{ c.conn.Close() } c.restart() }() var err error = nil address := config.AppConfig.GetKey("auth_server") //fmt.Println("auth client start, dial address is", address) c.conn, err = net.Dial("tcp", address) if err != nil { //fmt.Println("Error dialing", err.Error()) return } fmt.Println("发送验证") sendVerify(c.conn) // 发送验证,不需要读取返回值,如果验证错误立刻关掉 fmt.Println("读取验证结果") vresp, err := readString(c.conn) if err != nil { fmt.Println("Error dialing", err.Error()) return } if vresp != "ok"{ // 验证失败 fmt.Println("verify is not ok", vresp) return } fmt.Println("验证成功") c.verified = true // send go func() { for { select { case data := <-c.pchan: glog.Infoln("写入数据") c.conn.SetWriteDeadline(time.Now().Add(time.Second * 2)) _, err := c.conn.Write(data.toBytes()) if err != nil { fmt.Println("写入内容错误", err.Error()) return } case <-c.done: glog.Infoln("发送数据done退出") return } } }() // receive for { cmd, err := readString(c.conn) // 读取命令 if err != nil { c.done <- true fmt.Println("读取命令错误", err.Error()) break } if cmd == CMD_NEW{ err = c.newHandler() } else if cmd == CMD_REMOVE{ err = c.removeHandler() } else if cmd == CMD_PINGPONG_RESP{ } else { fmt.Println("未知cmd", cmd) continue } if err != nil{ c.done <- true fmt.Println("处理错误", err.Error()) break } } }() } // 停止 func(c *TcpClient)Stop(){ c.exited = true c.conn.Close() } // 检测 func(c *TcpClient)restart(){ if c.exited{ // 已退出则不重启 return } go func(){ c.verified = false c.done = make(chan bool) c.pchan = make(chan *authPackage, 100) c.exited = false time.Sleep(3 * time.Second) c.Start() }() } // 发送bytes func(c *TcpClient)Send(cmd string, bytess []byte){ if !c.verified{ fmt.Println("未认证") return } glog.Infoln("发送指令1", cmd) c.pchan <- &authPackage{ Cmd: cmd, Content: bytess, } glog.Infoln("发送指令2", cmd) } // 发送token func(c *TcpClient)SendToken(token *entitys.Token){ glog.Infoln("发送新建token") bytess := tokenToBytes(token) c.Send(CMD_NEW, bytess) } // 处理创建 func(c *TcpClient)newHandler()error{ fmt.Println("处理新建") bytess, err := readBytes(c.conn) if err != nil { fmt.Println("读取token内容错误", err.Error()) return err } // 新建 token, err := bytesToToken(bytess) if err != nil{ glog.Infoln("bytesToToken 错误", err.Error()) return err } sysutils.GetGlobalTokenStore().Set(token.AccessToken, token) return nil } // 处理删除 func(c *TcpClient)removeHandler()error{ fmt.Println("处理删除") bytess, err := readBytes(c.conn) if err != nil { fmt.Println("读取token内容错误", err.Error()) return err } // 移除,此时bytess为tokenstring sysutils.GetGlobalTokenStore().Remove(string(bytess)) return nil } // 读取字符串 func readString(conn net.Conn)(string, error){ // 读长度 size, err := readUInt32(conn) if err != nil{ fmt.Println("读取长度失败,", err.Error()) return "", err } // 读字符串 b := make([]byte, size) n, err := conn.Read(b) if n != int(size){ return "", errors.New("读取长度不是" + strconv.Itoa(int(size))) } return string(b), nil } // 写入字符串 func writeString(conn net.Conn, str string)error{ if str == ""{ return errors.New("字符串为空") } bytess := []byte(str) size := len(bytess) // 发送长度 err := writeUInt32(conn, uint32(size)) if err != nil{ fmt.Println("发送长度失败,", err.Error()) return err } // 发送内容 n, err := conn.Write(bytess) if err != nil{ fmt.Println("发送内容失败,", err.Error()) return err } if n != size{ return errors.New("发送长度不是" + strconv.Itoa(int(size))) } return nil } // 读取bytes func readBytes(conn net.Conn)([]byte, error){ // 读长度 size, err := readUInt32(conn) if err != nil{ fmt.Println("读取长度失败,", err.Error()) return nil, err } // 读字符串 b := make([]byte, size) n, err := conn.Read(b) if n != int(size){ return nil, errors.New("读取长度不是" + strconv.Itoa(int(size))) } return b, nil } // 读取uint64 func readUInt32(conn net.Conn)(uint32, error){ b := make([]byte, 4) n, err := conn.Read(b) if err != nil{ fmt.Println("读取长度失败,", err.Error()) return 0, err } if n != 4{ return 0, errors.New("读取长度不是4") } size := binary.BigEndian.Uint32(b) return size, nil } // 写入长度 func writeUInt32(conn net.Conn, v uint32)error{ // 发送长度 b := make([]byte, 4) binary.BigEndian.PutUint32(b, v) n, err := conn.Write(b) if err != nil{ fmt.Println("发送长度失败,", err.Error()) return err } if n != 4{ return errors.New("发送长度不是4") } return nil } // 写入长度 func writeUInt64(conn net.Conn, v uint64)error{ // 发送长度 b := make([]byte, 8) binary.BigEndian.PutUint64(b, v) n, err := conn.Write(b) if err != nil{ fmt.Println("发送长度失败,", err.Error()) return err } if n != 4{ return errors.New("发送长度不是4") } return nil } // 读取uint64 func readStringByBytes(bytess []byte)(string, int, error) { size := binary.BigEndian.Uint32(bytess) return string(bytess[4 : 4+size]), int(size), nil } // int转bytes func uint32ToBytes(v int)[]byte{ b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(v)) return b } // int转bytes func uint64ToBytes(v int)[]byte{ b := make([]byte, 8) binary.BigEndian.PutUint32(b, uint32(v)) return b } // 转token func bytesToToken(content []byte)(*entitys.Token, error){ token := &entitys.Token{} var index int = 0 var size int var err error = nil fmt.Println("读取userid") token.UserId, size, err = readStringByBytes(content) if err != nil{ fmt.Println("读取userid错误") return nil, err } index += 4 + size fmt.Println("读取AccessToken") token.AccessToken, size, err = readStringByBytes(content[index:]) if err != nil{ fmt.Println("读取AccessToken错误") return nil, err } index += 4 + size fmt.Println("读取RefreshToken") token.RefreshToken, size, err = readStringByBytes(content[index:]) if err != nil{ fmt.Println("读取RefreshToken错误") return nil, err } index += 4 + size fmt.Println("读取LoginID") token.LoginID, size, err = readStringByBytes(content[index:]) if err != nil{ fmt.Println("读取LoginID错误") return nil, err } index += 4 + size fmt.Println("读取timestamp") token.TimeStamp = binary.BigEndian.Uint64(content[index:]) index += 8 fmt.Println("读取ServerIp") token.ServerIp, size, err = readStringByBytes(content[index:]) if err != nil{ fmt.Println("读取ServerIp错误") return nil, err } index += 4 + size fmt.Println("读取Domain") token.Domain, size, err = readStringByBytes(content[index:]) if err != nil{ fmt.Println("读取Domain错误") return nil, err } index += 4 + size return token, nil } // 转bytes,包括token开头 func tokenToBytes(token *entitys.Token)[]byte{ buf := bytes.NewBuffer([]byte{}) t := []byte(token.UserId) buf.Write(uint32ToBytes(len(t))) buf.Write(t) t = []byte(token.AccessToken) buf.Write(uint32ToBytes(len(t))) buf.Write(t) t = []byte(token.RefreshToken) buf.Write(uint32ToBytes(len(t))) buf.Write(t) t = []byte(token.LoginID) buf.Write(uint32ToBytes(len(t))) buf.Write(t) buf.Write(uint64ToBytes(int(token.TimeStamp))) fmt.Println(token.ServerIp) t = []byte(token.ServerIp) buf.Write(uint32ToBytes(len(t))) buf.Write(t) fmt.Println(token.Domain) t = []byte(token.Domain) buf.Write(uint32ToBytes(len(t))) buf.Write(t) bytess := buf.Bytes() buf = bytes.NewBuffer([]byte{}) // 这里用reset是错误的 tokenstrbytes := []byte(token.AccessToken) buf.Write(uint32ToBytes(len(tokenstrbytes))) buf.Write(tokenstrbytes) buf.Write(uint32ToBytes(len(bytess))) buf.Write(bytess) return buf.Bytes() } // 发送验证 func sendVerify(conn net.Conn){ timestamp := time.Now().UnixNano() timestampStr := strconv.Itoa(int(timestamp)) seed := timestampStr + __KEY hashVal := hash(seed) writeUInt64(conn, uint64(timestamp)) writeString(conn, hashVal) } // md5 哈希 func hash(str string)string{ h := md5.New() h.Write([]byte(str)) return hex.EncodeToString(h.Sum(nil)) }