||
- package auth
import (
	"bytes"
	"crypto/md5"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"time"

	"git.qianqiusoft.com/library/glog"
	"git.qianqiusoft.com/qianqiusoft/light-apiengine/config"
	"git.qianqiusoft.com/qianqiusoft/light-apiengine/entitys"
	sysutils "git.qianqiusoft.com/qianqiusoft/light-apiengine/utils"
)
// __KEY is the shared secret that sendVerify appends to the timestamp before
// hashing it for the handshake.
// NOTE(review): a secret hard-coded in source ships inside every binary;
// consider loading it from configuration.
const __KEY = "Light#dauth-@*I2"

// Command names exchanged with the auth server.
const (
	CMD_NEW           = "new"     // carries a serialised token (SendToken / newHandler)
	CMD_REMOVE        = "remove"  // carries an access-token string to drop (removeHandler)
	CMD_PINGPONG      = "pg"      // ping request; not sent anywhere in this file
	CMD_PINGPONG_RESP = "pg_resp" // ping reply; ignored by the receive loop in Start
)
- var TCPClient *TcpClient
- type authPackage struct {
- Cmd string
- TokenStr string
- Content []byte
- }
- func (ap *authPackage) toBytes() []byte {
- buf := bytes.NewBuffer([]byte{})
- b := []byte(ap.Cmd)
- buf.Write(uint32ToBytes(len(b)))
- buf.Write(b)
- buf.Write(uint32ToBytes(len(ap.Content)))
- buf.Write(ap.Content)
- return buf.Bytes()
- }
- type TcpClient struct {
- conn net.Conn // 连接
- pchan chan *authPackage // 包chan
- done chan bool // 是否done
- exited bool // 退出
- verified bool // 验证是否
- }
- // 创建client
- func NewTcpClient() *TcpClient {
- c := &TcpClient{
- pchan: make(chan *authPackage, 100),
- done: make(chan bool),
- exited: false,
- verified: false,
- }
- return c
- }
// Start launches one client session in a background goroutine and returns
// immediately. The session:
//  1. dials the address stored under the "auth_server" config key;
//  2. sends the handshake (sendVerify) and requires the reply "ok";
//  3. starts a sender goroutine that writes packages queued on pchan;
//  4. reads and dispatches server commands until a read or handler error.
//
// Whenever the session ends, including after a recovered panic, the sender is
// stopped, verified is cleared, the connection is closed and restart is asked
// to schedule a reconnect.
func (c *TcpClient) Start() {
	go func() {
		defer func() {
			if p := recover(); p != nil {
				fmt.Println("recover", p)
			}
			// Keep Send from queueing into a session that no longer exists.
			// This also covers Stop, where restart returns without resetting.
			c.verified = false
			if c.conn != nil {
				c.conn.Close()
			}
			c.restart()
		}()

		// Each session owns its done channel (closed exactly once below), and
		// its goroutines use local copies because restart swaps new channels
		// into c while old goroutines may still be running.
		done := make(chan bool)
		c.done = done
		pchan := c.pchan

		var err error
		address := config.AppConfig.GetKey("auth_server")
		c.conn, err = net.Dial("tcp", address)
		if err != nil {
			return
		}
		conn := c.conn

		fmt.Println("发送验证")
		// Write errors from the handshake are not checked here: a broken
		// connection shows up as an error when reading the reply.
		sendVerify(conn)
		fmt.Println("读取验证结果")
		vresp, err := readString(conn)
		if err != nil {
			fmt.Println("Error dialing", err.Error())
			return
		}
		if vresp != "ok" {
			// Handshake rejected.
			fmt.Println("verify is not ok", vresp)
			return
		}
		fmt.Println("验证成功")
		c.verified = true

		// Sender: drains pchan onto the connection until done is closed.
		go func() {
			for {
				select {
				case data := <-pchan:
					glog.Infoln("写入数据")
					conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
					if _, err := conn.Write(data.toBytes()); err != nil {
						fmt.Println("写入内容错误", err.Error())
						// Closing the connection makes the blocked reader fail,
						// which ends the session and schedules a reconnect.
						conn.Close()
						return
					}
				case <-done:
					glog.Infoln("发送数据done退出")
					return
				}
			}
		}()
		// Closing done (rather than sending on it) never blocks, even when the
		// sender already returned after a write error. Deferred calls run in
		// LIFO order, so this happens before the cleanup defer above.
		defer close(done)

		// Receiver: dispatch commands sent by the server.
		for {
			cmd, err := readString(conn)
			if err != nil {
				fmt.Println("读取命令错误", err.Error())
				return
			}
			switch cmd {
			case CMD_NEW:
				err = c.newHandler()
			case CMD_REMOVE:
				err = c.removeHandler()
			case CMD_PINGPONG_RESP:
				// Ping reply: nothing to do.
			default:
				fmt.Println("未知cmd", cmd)
				continue
			}
			if err != nil {
				fmt.Println("处理错误", err.Error())
				return
			}
		}
	}()
}
- // 停止
- func (c *TcpClient) Stop() {
- c.exited = true
- c.conn.Close()
- }
// restart schedules a new Start after a 3-second pause unless the client has
// been stopped. Before waiting it clears verified and installs fresh done and
// pchan channels; packages still queued on the old pchan are dropped.
// exited is checked both before and after the pause, so a Stop that arrives
// while waiting still cancels the reconnect. exited itself is never reset here,
// which would otherwise overwrite a concurrent Stop.
func (c *TcpClient) restart() {
	if c.exited {
		// Already stopped: do not reconnect.
		return
	}
	go func() {
		c.verified = false
		c.done = make(chan bool)
		c.pchan = make(chan *authPackage, 100)
		time.Sleep(3 * time.Second)
		if c.exited {
			// Stop was called during the pause.
			return
		}
		c.Start()
	}()
}
// Send queues a package (cmd plus payload bytess) for the sender goroutine
// started by Start. Until the handshake has succeeded (verified is false) it
// only prints a notice and drops the data. The send on c.pchan blocks while
// the queue is full (capacity 100 as created by NewTcpClient and restart).
func (c *TcpClient) Send(cmd string, bytess []byte) {
	if c.verified {
		glog.Infoln("发送指令1", cmd)
		pkg := &authPackage{Content: bytess, Cmd: cmd}
		c.pchan <- pkg
		glog.Infoln("发送指令2", cmd)
		return
	}
	fmt.Println("未认证")
}
- // 发送token
- func (c *TcpClient) SendToken(token *entitys.Token) {
- // glog.Infoln("发送新建token")
- bytess := tokenToBytes(token)
- c.Send(CMD_NEW, bytess)
- }
- // 处理创建
- func (c *TcpClient) newHandler() error {
- // fmt.Println("处理新建")
- bytess, err := readBytes(c.conn)
- if err != nil {
- fmt.Println("读取token内容错误", err.Error())
- return err
- }
- // 新建
- token, err := bytesToToken(bytess)
- if err != nil {
- glog.Infoln("bytesToToken 错误", err.Error())
- return err
- }
- sysutils.GetGlobalTokenStore().Set(token.AccessToken, token)
- return nil
- }
// removeHandler handles a CMD_REMOVE command: it reads the length-prefixed
// payload, which is the access-token string itself, and removes that key from
// the global token store. A read error is returned to the caller.
func (c *TcpClient) removeHandler() error {
	fmt.Println("处理删除")
	payload, err := readBytes(c.conn)
	if err == nil {
		sysutils.GetGlobalTokenStore().Remove(string(payload))
		return nil
	}
	fmt.Println("读取token内容错误", err.Error())
	return err
}
// readString reads one length-prefixed string from conn: a 4-byte big-endian
// length (via readUInt32) followed by exactly that many bytes.
// It returns an error if the length cannot be read or the stream ends before
// the whole payload arrives.
func readString(conn net.Conn) (string, error) {
	// Read the length prefix.
	size, err := readUInt32(conn)
	if err != nil {
		fmt.Println("读取长度失败,", err.Error())
		return "", err
	}
	// NOTE(review): size comes from the peer and is allocated up front (up to
	// 4 GiB); consider capping it if the server is not fully trusted.
	b := make([]byte, size)
	// TCP is a byte stream: one Read may return only part of the payload, so
	// keep reading until all size bytes have arrived.
	if _, err := io.ReadFull(conn, b); err != nil {
		return "", errors.New("读取长度不是" + strconv.Itoa(int(size)) + ": " + err.Error())
	}
	return string(b), nil
}
// writeString writes str to conn as a 4-byte big-endian length (via
// writeUInt32) followed by its bytes. Empty strings are rejected. It also
// returns an error if either write fails or the payload is only partly
// written.
func writeString(conn net.Conn, str string) error {
	if len(str) == 0 {
		return errors.New("字符串为空")
	}
	payload := []byte(str)

	if err := writeUInt32(conn, uint32(len(payload))); err != nil {
		fmt.Println("发送长度失败,", err.Error())
		return err
	}

	written, err := conn.Write(payload)
	switch {
	case err != nil:
		fmt.Println("发送内容失败,", err.Error())
		return err
	case written != len(payload):
		return errors.New("发送长度不是" + strconv.Itoa(len(payload)))
	}
	return nil
}
// readBytes reads one length-prefixed byte slice from conn: a 4-byte
// big-endian length (via readUInt32) followed by exactly that many bytes.
// It returns an error if the length cannot be read or the stream ends before
// the whole payload arrives.
func readBytes(conn net.Conn) ([]byte, error) {
	// Read the length prefix.
	size, err := readUInt32(conn)
	if err != nil {
		fmt.Println("读取长度失败,", err.Error())
		return nil, err
	}
	// NOTE(review): size comes from the peer and is allocated up front (up to
	// 4 GiB); consider capping it if the server is not fully trusted.
	b := make([]byte, size)
	// A single TCP Read may return fewer bytes than requested; io.ReadFull
	// keeps reading until the buffer is full or the stream fails.
	if _, err := io.ReadFull(conn, b); err != nil {
		return nil, errors.New("读取长度不是" + strconv.Itoa(int(size)) + ": " + err.Error())
	}
	return b, nil
}
- // 读取uint64
- func readUInt32(conn net.Conn) (uint32, error) {
- b := make([]byte, 4)
- n, err := conn.Read(b)
- if err != nil {
- fmt.Println("读取长度失败,", err.Error())
- return 0, err
- }
- if n != 4 {
- return 0, errors.New("读取长度不是4")
- }
- size := binary.BigEndian.Uint32(b)
- return size, nil
- }
- // 写入长度
- func writeUInt32(conn net.Conn, v uint32) error {
- // 发送长度
- b := make([]byte, 4)
- binary.BigEndian.PutUint32(b, v)
- n, err := conn.Write(b)
- if err != nil {
- fmt.Println("发送长度失败,", err.Error())
- return err
- }
- if n != 4 {
- return errors.New("发送长度不是4")
- }
- return nil
- }
- // 写入长度
- func writeUInt64(conn net.Conn, v uint64) error {
- // 发送长度
- b := make([]byte, 8)
- binary.BigEndian.PutUint64(b, v)
- n, err := conn.Write(b)
- if err != nil {
- fmt.Println("发送长度失败,", err.Error())
- return err
- }
- if n != 4 {
- return errors.New("发送长度不是4")
- }
- return nil
- }
// readStringByBytes decodes a length-prefixed string at the start of bytess:
// a 4-byte big-endian byte count followed by that many bytes. It returns the
// string and its length in bytes, not counting the prefix. Bytes after the
// string are ignored. It returns an error, instead of panicking, when bytess
// is too short for the prefix or the declared payload.
func readStringByBytes(bytess []byte) (string, int, error) {
	if len(bytess) < 4 {
		return "", 0, errors.New("length prefix truncated")
	}
	size := binary.BigEndian.Uint32(bytess)
	// Compare in uint64 so a huge declared size cannot overflow the bound.
	if uint64(size) > uint64(len(bytess)-4) {
		return "", 0, errors.New("payload truncated, need " + strconv.FormatUint(uint64(size), 10) + " bytes")
	}
	return string(bytess[4 : 4+int(size)]), int(size), nil
}
- // int转bytes
- func uint32ToBytes(v int) []byte {
- b := make([]byte, 4)
- binary.BigEndian.PutUint32(b, uint32(v))
- return b
- }
// uint64ToBytes encodes v as 8 big-endian bytes after converting it with
// uint64(v), so negative values produce their two's-complement bit pattern.
// This is the inverse of binary.BigEndian.Uint64, which bytesToToken uses for
// the token timestamp. (The previous version wrote only 4 bytes with
// PutUint32, leaving the value in the high half and zeroing the low half.)
func uint64ToBytes(v int) []byte {
	b := make([]byte, 8)
	binary.BigEndian.PutUint64(b, uint64(v))
	return b
}
// bytesToToken decodes the token payload of a CMD_NEW command. Fields are read
// in this order:
//   - UserId, AccessToken, RefreshToken, LoginID: each a 4-byte big-endian
//     length followed by that many bytes;
//   - TimeStamp: an 8-byte big-endian uint64;
//   - ServerIp, Domain: length-prefixed like the first four.
//
// Truncated or inconsistent input yields an error instead of a panic. Any
// bytes after Domain are ignored.
func bytesToToken(content []byte) (*entitys.Token, error) {
	token := &entitys.Token{Lock: new(sync.RWMutex)}
	index := 0 // offset of the next unread byte in content

	// nextString decodes the length-prefixed string at index and advances
	// index past it. name only labels the log line and the error.
	nextString := func(name string) (string, error) {
		if len(content)-index < 4 {
			fmt.Println("读取" + name + "错误")
			return "", errors.New("token content truncated at " + name)
		}
		size := binary.BigEndian.Uint32(content[index:])
		start := index + 4
		// Compare in uint64 so a huge declared size cannot overflow int.
		if uint64(size) > uint64(len(content)-start) {
			fmt.Println("读取" + name + "错误")
			return "", errors.New("token content truncated at " + name)
		}
		end := start + int(size)
		index = end
		return string(content[start:end]), nil
	}

	var err error
	for _, f := range []struct {
		name string
		dst  *string
	}{
		{"userid", &token.UserId},
		{"AccessToken", &token.AccessToken},
		{"RefreshToken", &token.RefreshToken},
		{"LoginID", &token.LoginID},
	} {
		if *f.dst, err = nextString(f.name); err != nil {
			return nil, err
		}
	}

	if len(content)-index < 8 {
		fmt.Println("读取timestamp错误")
		return nil, errors.New("token content truncated at timestamp")
	}
	token.TimeStamp = binary.BigEndian.Uint64(content[index:])
	index += 8

	if token.ServerIp, err = nextString("ServerIp"); err != nil {
		return nil, err
	}
	if token.Domain, err = nextString("Domain"); err != nil {
		return nil, err
	}
	return token, nil
}
- // 转bytes,包括token开头
- func tokenToBytes(token *entitys.Token) []byte {
- buf := bytes.NewBuffer([]byte{})
- t := []byte(token.UserId)
- buf.Write(uint32ToBytes(len(t)))
- buf.Write(t)
- t = []byte(token.AccessToken)
- buf.Write(uint32ToBytes(len(t)))
- buf.Write(t)
- t = []byte(token.RefreshToken)
- buf.Write(uint32ToBytes(len(t)))
- buf.Write(t)
- t = []byte(token.LoginID)
- buf.Write(uint32ToBytes(len(t)))
- buf.Write(t)
- buf.Write(uint64ToBytes(int(token.TimeStamp)))
- fmt.Println(token.ServerIp)
- t = []byte(token.ServerIp)
- buf.Write(uint32ToBytes(len(t)))
- buf.Write(t)
- fmt.Println(token.Domain)
- t = []byte(token.Domain)
- buf.Write(uint32ToBytes(len(t)))
- buf.Write(t)
- bytess := buf.Bytes()
- buf = bytes.NewBuffer([]byte{}) // 这里用reset是错误的
- tokenstrbytes := []byte(token.AccessToken)
- buf.Write(uint32ToBytes(len(tokenstrbytes)))
- buf.Write(tokenstrbytes)
- buf.Write(uint32ToBytes(len(bytess)))
- buf.Write(bytess)
- return buf.Bytes()
- }
// sendVerify writes the client half of the handshake to conn:
//   - the current Unix time in nanoseconds, as a big-endian uint64;
//   - the length-prefixed hex MD5 of that timestamp's decimal string
//     concatenated with __KEY.
//
// Write errors are deliberately not acted on. In Start, a broken connection
// surfaces when the verify reply is read.
func sendVerify(conn net.Conn) {
	timestamp := time.Now().UnixNano()
	// FormatInt keeps all 64 bits on every platform. strconv.Itoa(int(...))
	// would truncate on 32-bit builds, so the hashed seed would no longer
	// match the transmitted uint64.
	seed := strconv.FormatInt(timestamp, 10) + __KEY
	hashVal := hash(seed)
	writeUInt64(conn, uint64(timestamp))
	writeString(conn, hashVal)
}
// hash returns the lowercase hexadecimal MD5 digest of str.
// NOTE(review): MD5 is weak; it is presumably kept because the auth server
// computes the same digest. Confirm before changing.
func hash(str string) string {
	digest := md5.Sum([]byte(str))
	return hex.EncodeToString(digest[:])
}
|