282 lines
8.2 KiB
Go
282 lines
8.2 KiB
Go
package protocol
|
||
|
||
import (
|
||
"bytes"
|
||
"errors"
|
||
"fmt"
|
||
"git.viry.cc/gomod/glog"
|
||
"hash/crc32"
|
||
"sync"
|
||
)
|
||
|
||
// PREFIX package的起始标志
|
||
const PREFIX uint8 = 0x95 // 1001 0101
|
||
|
||
// VERSION package的版本
|
||
const VERSION uint8 = 3
|
||
|
||
// flag标志位
|
||
const (
|
||
/*
|
||
Heartbeat机制, 用来检查链路是否正常
|
||
*/
|
||
FlagHeartbeatRequest uint8 = 1 << iota // 心跳请求信号, 发送心跳信号并请求心跳响应
|
||
FlagHeartbeatResponse // 心跳响应信号, 发送心跳信号
|
||
|
||
/*
|
||
Package校验机制, 如果携带此标志, 则
|
||
1. Header中携带了Header的哈希值, 用来确认Header完整性
|
||
2. Header中携带了Body的哈希值, 用来确认Body完整性
|
||
*/
|
||
FlagHeaderHashCheck // Header拥有校验值
|
||
FlagBodyHashCheck // Body拥有校验值
|
||
|
||
/*
|
||
Package确认机制, 接收方收到携带此标志的包裹后,需要在一定时间内向发送方发送确认信号
|
||
*/
|
||
FlagPackageConfirmRequest // Package请求确认信号, 接收后需要响应确认信号
|
||
FlagPackageConfirmResponse // Package确认信号,发送确认信号
|
||
)
|
||
|
||
// data的加密方式
|
||
const (
|
||
EncryptorNone uint8 = iota // 不加密
|
||
EncryptorXor // 异或加密
|
||
EncryptorAes // AES加密
|
||
)
|
||
|
||
// 错误
|
||
var (
|
||
errorPackageIncomplete = errors.New("package incomplete")
|
||
)
|
||
|
||
// 数据长度
|
||
const (
|
||
HeaderSize = 20 // header的大小
|
||
BodyMaxSize = 16384 // body的最大大小
|
||
)
|
||
|
||
// 协议包结构
|
||
type protocolPackage struct {
|
||
prefix uint8 // 1 byte 0x95
|
||
version uint8 // 1 byte protocol version
|
||
headerHash uint32 // 4 byte head crc32 checksum (BigEndian)
|
||
flag uint8 // 1 byte flag
|
||
encryptor uint8 // 1 byte encrypt method
|
||
value uint32 // 4 byte 在普通Package中保存Package的序号, 用以Confirm. 在Heartbeat中存储心跳序号和心跳超时时间
|
||
bodySize uint32 // 4 byte size of body (BigEndian)
|
||
bodyHash uint32 // 4 byte headerHash of body (BigEndian)
|
||
body []byte // ? byte body
|
||
}
|
||
|
||
// 缓存池
|
||
var (
|
||
// *protocolPackage
|
||
packagePool = sync.Pool{New: func() any { return &protocolPackage{} }}
|
||
// *bytes.Buffer
|
||
bytesBufferPool = sync.Pool{New: func() any { return &bytes.Buffer{} }}
|
||
// make([]byte, HeaderSize)
|
||
parsePackageBufferPool = sync.Pool{New: func() any { return make([]byte, HeaderSize) }}
|
||
)
|
||
|
||
// 创建新package, 所有新package的创建都必须通过此方法
|
||
//
|
||
// 返回的 *protocolPackage 使用结束后需要手动回收到 packagePool
|
||
func newPackage(flag uint8, encryptor uint8, value uint32, body []byte) *protocolPackage {
|
||
pkg := packagePool.Get().(*protocolPackage)
|
||
pkg.prefix = PREFIX
|
||
pkg.version = VERSION
|
||
pkg.headerHash = 0
|
||
pkg.flag = flag
|
||
pkg.encryptor = EncryptorNone
|
||
pkg.value = value
|
||
pkg.bodySize = uint32(len(body))
|
||
pkg.bodyHash = 0
|
||
pkg.body = body
|
||
if encryptor != EncryptorNone {
|
||
pkg.encrypt(encryptor)
|
||
}
|
||
if (flag & FlagBodyHashCheck) != 0 {
|
||
pkg.generateBodyHash()
|
||
}
|
||
if (flag & FlagHeaderHashCheck) != 0 {
|
||
pkg.generateHeaderHash()
|
||
}
|
||
return pkg
|
||
}
|
||
|
||
// 返回的 *bytes.Buffer 使用结束后需要手动回收到 bytesBufferPool
|
||
func (p *protocolPackage) bytesBuffer() *bytes.Buffer {
|
||
buf := bytesBufferPool.Get().(*bytes.Buffer)
|
||
// prefix
|
||
buf.WriteByte(p.prefix)
|
||
// version
|
||
buf.WriteByte(p.version)
|
||
// head hash
|
||
writeUint32Bytes(buf, p.headerHash)
|
||
// flag
|
||
buf.WriteByte(p.flag)
|
||
// encrypt method
|
||
buf.WriteByte(p.encryptor)
|
||
// value
|
||
writeUint32Bytes(buf, p.value)
|
||
// body curSize
|
||
writeUint32Bytes(buf, p.bodySize)
|
||
// body hash
|
||
writeUint32Bytes(buf, p.bodyHash)
|
||
// body
|
||
buf.Write(p.body)
|
||
return buf
|
||
}
|
||
|
||
// 将head中需要校验的数据拼接起来
|
||
//
|
||
// 返回的 *bytes.Buffer 使用结束后需要手动回收到 bytesBufferPool
|
||
func (p *protocolPackage) headerNeedCheck() *bytes.Buffer {
|
||
buf := bytesBufferPool.Get().(*bytes.Buffer)
|
||
buf.WriteByte(p.flag)
|
||
buf.WriteByte(p.encryptor)
|
||
writeUint32Bytes(buf, p.value)
|
||
writeUint32Bytes(buf, p.bodySize)
|
||
writeUint32Bytes(buf, p.bodyHash)
|
||
return buf
|
||
}
|
||
|
||
// 生成header hash
|
||
func (p *protocolPackage) generateHeaderHash() {
|
||
buf := p.headerNeedCheck()
|
||
defer func() {
|
||
buf.Reset()
|
||
bytesBufferPool.Put(buf)
|
||
}()
|
||
p.headerHash = crc32.ChecksumIEEE(buf.Bytes())
|
||
}
|
||
|
||
// 校验header
|
||
func (p *protocolPackage) checkHeader() bool {
|
||
buf := p.headerNeedCheck()
|
||
defer func() {
|
||
buf.Reset()
|
||
bytesBufferPool.Put(buf)
|
||
}()
|
||
return p.headerHash == crc32.ChecksumIEEE(buf.Bytes())
|
||
}
|
||
|
||
// 生成data的crc32
|
||
func (p *protocolPackage) generateBodyHash() {
|
||
p.bodyHash = crc32.ChecksumIEEE(p.body)
|
||
}
|
||
|
||
// 校验body
|
||
func (p *protocolPackage) checkBody() bool {
|
||
if (p.flag & FlagHeaderHashCheck) == 0 {
|
||
return true
|
||
}
|
||
if int(p.bodySize) != len(p.body) {
|
||
return false
|
||
}
|
||
return p.bodyHash == crc32.ChecksumIEEE(p.body)
|
||
}
|
||
|
||
// encrypt 加密data
|
||
func (p *protocolPackage) encrypt(encryptor uint8) {
|
||
switch encryptor {
|
||
case EncryptorXor:
|
||
p.xorEncrypt()
|
||
case EncryptorAes:
|
||
p.aesEncrypt()
|
||
default:
|
||
glog.Warning("[protocol_package] unknown encrypt method")
|
||
}
|
||
}
|
||
|
||
// decrypt 解密data
|
||
func (p *protocolPackage) decrypt() {
|
||
switch p.encryptor {
|
||
case EncryptorNone:
|
||
return
|
||
case EncryptorXor:
|
||
p.xorDecrypt()
|
||
case EncryptorAes:
|
||
p.aesDecrypt()
|
||
default:
|
||
glog.Warning("[protocol_package] unknown encrypt method")
|
||
}
|
||
}
|
||
|
||
// parsePackage 从buf中读取一个package
|
||
//
|
||
// 如果协议头标志(prefix)不匹配,删除buf中除第一个字符外,下一个prefix1到buf开头的所有数据
|
||
// 如果协议头标志(prefix)匹配,不断从buf中取出数据,填充到package结构体
|
||
// 如果buf中的数据出错,无法正确提取package, 则返回(nil,true), 且已从buf中提取的数据不会退回buf
|
||
func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
|
||
bufLen := uint32(buf.Len())
|
||
bufData := buf.Bytes()
|
||
// 协议头标志不匹配,删除未知数据,寻找下一个package起始位置
|
||
if bufData[0] != PREFIX {
|
||
nextPackageHead(buf)
|
||
return nil, fmt.Errorf("prefix does not match, expected %v got %v", PREFIX, bufData[0])
|
||
}
|
||
// 判断package的版本
|
||
if bufData[1] != VERSION {
|
||
nextPackageHead(buf)
|
||
return nil, fmt.Errorf("unsupported version expected %d got %d", VERSION, bufData[1])
|
||
}
|
||
// 获取body长度
|
||
bodySize := uint32(bufData[HeaderSize-5]) | uint32(bufData[HeaderSize-6])<<8 | uint32(bufData[HeaderSize-7])<<16 | uint32(bufData[HeaderSize-8])<<24
|
||
// 判断Package是否接收完整
|
||
if bufLen < HeaderSize+bodySize {
|
||
return nil, errorPackageIncomplete
|
||
}
|
||
// 校验Header
|
||
var headerHashExpected uint32
|
||
if (bufData[6] & FlagHeaderHashCheck) != 0 {
|
||
var headerHashGot uint32
|
||
headerHashGot = crc32.ChecksumIEEE(bufData[6:HeaderSize])
|
||
headerHashExpected = uint32(bufData[5]) | uint32(bufData[4])<<8 | uint32(bufData[3])<<16 | uint32(bufData[2])<<24
|
||
if headerHashExpected != headerHashGot {
|
||
nextPackageHead(buf)
|
||
return nil, fmt.Errorf("wrong header hash, expected %d got %d", headerHashExpected, headerHashGot)
|
||
}
|
||
}
|
||
// 校验Body
|
||
var bodyHashExpected uint32
|
||
if (bufData[6] & FlagBodyHashCheck) != 0 {
|
||
var bodyHashGot uint32
|
||
bodyHashGot = crc32.ChecksumIEEE(bufData[HeaderSize:])
|
||
bodyHashExpected = uint32(bufData[HeaderSize-1]) | uint32(bufData[HeaderSize-2])<<8 | uint32(bufData[HeaderSize-3])<<16 | uint32(bufData[HeaderSize-4])<<24
|
||
if bodyHashExpected != bodyHashGot {
|
||
nextPackageHead(buf)
|
||
return nil, fmt.Errorf("wrong header hash, expected %d got %d", bodyHashExpected, bodyHashGot)
|
||
}
|
||
}
|
||
|
||
pkg := packagePool.Get().(*protocolPackage)
|
||
pkg.prefix = bufData[0]
|
||
pkg.version = bufData[1]
|
||
pkg.headerHash = headerHashExpected
|
||
pkg.flag = bufData[6]
|
||
pkg.encryptor = bufData[7]
|
||
pkg.value = uint32(bufData[11]) | uint32(bufData[10])<<8 | uint32(bufData[9])<<16 | uint32(bufData[8])<<24
|
||
pkg.bodySize = bodySize
|
||
pkg.bodyHash = bodyHashExpected
|
||
pkg.body = make([]byte, bodySize)
|
||
hd := parsePackageBufferPool.Get().([]byte)
|
||
buf.Read(hd)
|
||
buf.Read(pkg.body)
|
||
|
||
return pkg, nil
|
||
}
|
||
|
||
// 删除掉buf中第一个byte, 并将buf中的起始位置调整到与PREFIX相同的下一个元素的位置
|
||
func nextPackageHead(buf *bytes.Buffer) {
|
||
var err error
|
||
_, _ = buf.ReadByte()
|
||
_, err = buf.ReadBytes(PREFIX) // 只搜索与PREFIX相同元素,防止PREFIX出现在buf末尾
|
||
if err == nil { // 找到下一个协议头标志,把删掉的prefix回退到buffer
|
||
_ = buf.UnreadByte()
|
||
} else { // 找不到下一个协议头标志,清空buffer
|
||
buf.Reset()
|
||
}
|
||
}
|