Ver código fonte

实现消息解密

wenzl 9 anos atrás
pai
commit
e713b4ffb2
7 arquivos alterados com 388 adições e 0 exclusões
  1. 39 0
      context/context.go
  2. 61 0
      log/log.go
  3. 8 0
      message/message.go
  4. 99 0
      server/server.go
  5. 113 0
      util/crypto.go
  6. 18 0
      util/signature.go
  7. 50 0
      wechat.go

+ 39 - 0
context/context.go

@@ -0,0 +1,39 @@
+package context
+
+import "net/http"
+
+//Context struct
+type Context struct {
+	AppID          string
+	AppSecret      string
+	Token          string
+	EncodingAESKey string
+
+	Writer  http.ResponseWriter
+	Request *http.Request
+}
+
+func (ctx *Context) getAccessToken() {
+
+}
+
+func (ctx *Context) String(str string) error {
+	ctx.Writer.WriteHeader(200)
+	_, err := ctx.Writer.Write([]byte(str))
+	return err
+}
+
+// Query returns the keyed url query value if it exists
+func (ctx *Context) Query(key string) string {
+	value, _ := ctx.GetQuery(key)
+	return value
+}
+
+// GetQuery is like Query(), it returns the keyed url query value
+func (ctx *Context) GetQuery(key string) (string, bool) {
+	req := ctx.Request
+	if values, ok := req.URL.Query()[key]; ok && len(values) > 0 {
+		return values[0], true
+	}
+	return "", false
+}

+ 61 - 0
log/log.go

@@ -0,0 +1,61 @@
+package log
+
+import "github.com/astaxie/beego/logs"
+
+const (
+	LevelEmergency = iota
+	LevelAlert
+	LevelCritical
+	LevelError
+	LevelWarning
+	LevelNotice
+	LevelInformational
+	LevelDebug
+)
+
+type Logger struct {
+	*logs.BeeLogger
+}
+
+func NewLogger(channelLen int64, adapterName string, config string, logLevel int) *Logger {
+	logger := logs.NewLogger(channelLen)
+	logger.SetLogger(adapterName, config)
+	logger.SetLevel(logLevel)
+	logger.EnableFuncCallDepth(true)
+	logger.SetLogFuncCallDepth(3)
+	return &Logger{logger}
+}
+
+func (logger *Logger) Printf(format string, v ...interface{}) {
+	logger.Trace(format, v...)
+}
+
+var l *Logger
+
+func InitLogger(channelLen int64, adapterName string, config string, logLevel int) {
+	l = NewLogger(channelLen, adapterName, config, logLevel)
+}
+
+func Criticalf(format string, v ...interface{}) {
+	l.Critical(format, v...)
+}
+
+func Errorf(format string, v ...interface{}) {
+	l.Error(format, v...)
+}
+
+func Warnf(format string, v ...interface{}) {
+	l.Warn(format, v...)
+}
+
+func Infof(format string, v ...interface{}) {
+	l.Info(format, v...)
+}
+
+func Tracef(format string, v ...interface{}) {
+	l.Trace(format, v...)
+}
+
+func Debugf(format string, v ...interface{}) {
+	l.Debug(format, v...)
+}

+ 8 - 0
message/message.go

@@ -0,0 +1,8 @@
+package message
+
+//EncryptedXMLMsg 安全模式下的消息体
+type EncryptedXMLMsg struct {
+	XMLName      struct{} `xml:"xml" json:"-"`
+	ToUserName   string   `xml:"ToUserName" json:"ToUserName"`
+	EncryptedMsg string   `xml:"Encrypt"    json:"Encrypt"`
+}

+ 99 - 0
server/server.go

@@ -0,0 +1,99 @@
+package server
+
+import (
+	"encoding/xml"
+	"fmt"
+	"io/ioutil"
+
+	"github.com/silenceper/wechat/context"
+	"github.com/silenceper/wechat/message"
+	"github.com/silenceper/wechat/util"
+)
+
+//Server struct
+type Server struct {
+	*context.Context
+	isSafeMode bool
+	rawXMLMsg  string
+}
+
+//NewServer init
+func NewServer(context *context.Context) *Server {
+	srv := new(Server)
+	srv.Context = context
+	return srv
+}
+
+//Serve 处理微信的请求消息
+func (srv *Server) Serve() error {
+
+	if !srv.Validate() {
+		return fmt.Errorf("请求校验失败")
+	}
+
+	echostr, exists := srv.GetQuery("echostr")
+	if exists {
+		return srv.String(echostr)
+	}
+
+	srv.handleRequest()
+
+	return nil
+}
+
+//Validate 校验请求是否合法
+func (srv *Server) Validate() bool {
+	timestamp := srv.Query("timestamp")
+	nonce := srv.Query("nonce")
+	signature := srv.Query("signature")
+	return signature == util.Signature(srv.Token, timestamp, nonce)
+}
+
+//HandleRequest 处理微信的请求
+func (srv *Server) handleRequest() {
+	srv.isSafeMode = false
+	encryptType := srv.Query("encrypt_type")
+	if encryptType == "aes" {
+		srv.isSafeMode = true
+	}
+
+	_, err := srv.getMessage()
+	if err != nil {
+		fmt.Printf("%v", err)
+	}
+}
+
+func (srv *Server) getMessage() (interface{}, error) {
+	var rawXMLMsgBytes []byte
+	var err error
+	if srv.isSafeMode {
+		var encryptedXMLMsg message.EncryptedXMLMsg
+		if err := xml.NewDecoder(srv.Request.Body).Decode(&encryptedXMLMsg); err != nil {
+			return nil, fmt.Errorf("从body中解析xml失败,err=%v", err)
+		}
+
+		//验证消息签名
+		timestamp := srv.Query("timestamp")
+		nonce := srv.Query("nonce")
+		msgSignature := srv.Query("msg_signature")
+		msgSignatureCreate := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg)
+		if msgSignature != msgSignatureCreate {
+			return nil, fmt.Errorf("消息不合法,验证签名失败")
+		}
+
+		//解密
+		rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey)
+		if err != nil {
+			return nil, fmt.Errorf("消息解密失败,err=%v", err)
+		}
+	} else {
+		rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body)
+		if err != nil {
+			return nil, fmt.Errorf("从body中解析xml失败,err=%v", err)
+		}
+	}
+
+	srv.rawXMLMsg = string(rawXMLMsgBytes)
+	fmt.Println(srv.rawXMLMsg)
+	return nil, nil
+}

+ 113 - 0
util/crypto.go

@@ -0,0 +1,113 @@
+package util
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"encoding/base64"
+	"errors"
+	"fmt"
+)
+
+//DecryptMsg 消息解密
+func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err error) {
+	var encryptedMsgBytes, key, getAppIDBytes []byte
+	encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg)
+	if err != nil {
+		return
+	}
+	key, err = aesKeyDecode(aesKey)
+	if err != nil {
+		return
+	}
+	_, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
+	if err != nil {
+		err = fmt.Errorf("消息解密失败,%v", err)
+		return
+	}
+	if appID != string(getAppIDBytes) {
+		err = fmt.Errorf("消息解密校验APPID失败")
+		return
+	}
+	return
+}
+
+func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
+	if len(encodedAESKey) != 43 {
+		err = errors.New("the length of encodedAESKey must be equal to 43")
+		return
+	}
+	key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=")
+	if err != nil {
+		return
+	}
+	if len(key) != 32 {
+		err = errors.New("encodingAESKey invalid")
+		return
+	}
+	return
+}
+
+// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
+func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) {
+	const (
+		BlockSize = 32            // PKCS#7
+		BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
+	)
+
+	if len(ciphertext) < BlockSize {
+		err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
+		return
+	}
+	if len(ciphertext)&BlockMask != 0 {
+		err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
+		return
+	}
+
+	plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE
+
+	// 解密
+	block, err := aes.NewCipher(aesKey)
+	if err != nil {
+		panic(err)
+	}
+	mode := cipher.NewCBCDecrypter(block, aesKey[:16])
+	mode.CryptBlocks(plaintext, ciphertext)
+
+	// PKCS#7 去除补位
+	amountToPad := int(plaintext[len(plaintext)-1])
+	if amountToPad < 1 || amountToPad > BlockSize {
+		err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
+		return
+	}
+	plaintext = plaintext[:len(plaintext)-amountToPad]
+
+	// 反拼接
+	// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
+	if len(plaintext) <= 20 {
+		err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
+		return
+	}
+	rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20]))
+	if rawXMLMsgLen < 0 {
+		err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen)
+		return
+	}
+	appIDOffset := 20 + rawXMLMsgLen
+	if len(plaintext) <= appIDOffset {
+		err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen)
+		return
+	}
+
+	random = plaintext[:16:20]
+	rawXMLMsg = plaintext[20:appIDOffset:appIDOffset]
+	appID = plaintext[appIDOffset:]
+	return
+}
+
+// 从 4 字节的网络字节序里解析出整数
+func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
+	return uint32(orderBytes[0])<<24 |
+		uint32(orderBytes[1])<<16 |
+		uint32(orderBytes[2])<<8 |
+		uint32(orderBytes[3])
+}

+ 18 - 0
util/signature.go

@@ -0,0 +1,18 @@
+package util
+
+import (
+	"crypto/sha1"
+	"fmt"
+	"io"
+	"sort"
+)
+
+//Signature sha1签名
+func Signature(params ...string) string {
+	sort.Strings(params)
+	h := sha1.New()
+	for _, s := range params {
+		io.WriteString(h, s)
+	}
+	return fmt.Sprintf("%x", h.Sum(nil))
+}

+ 50 - 0
wechat.go

@@ -0,0 +1,50 @@
+package wechat
+
+import (
+	"net/http"
+
+	"github.com/silenceper/wechat/context"
+	"github.com/silenceper/wechat/log"
+	"github.com/silenceper/wechat/server"
+)
+
+//Wechat struct
+type Wechat struct {
+	Context *context.Context
+}
+
+//Config for user
+type Config struct {
+	AppID          string
+	AppSecret      string
+	Token          string
+	EncodingAESKey string
+}
+
+//NewWechat init
+func NewWechat(cfg *Config) *Wechat {
+
+	channelLen := int64(10000)
+	adapterName := "console"
+	config := ""
+	logLevel := log.LevelDebug
+	log.InitLogger(channelLen, adapterName, config, logLevel)
+
+	context := new(context.Context)
+	copyConfigToContext(cfg, context)
+	return &Wechat{context}
+}
+
+func copyConfigToContext(cfg *Config, context *context.Context) {
+	context.AppID = cfg.AppID
+	context.AppSecret = cfg.AppSecret
+	context.Token = cfg.Token
+	context.EncodingAESKey = cfg.EncodingAESKey
+}
+
+//GetServer init
+func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *server.Server {
+	wc.Context.Request = req
+	wc.Context.Writer = writer
+	return server.NewServer(wc.Context)
+}