protocol/package.go

282 lines
8.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package protocol
import (
"bytes"
"errors"
"fmt"
"git.viry.cc/gomod/glog"
"hash/crc32"
"sync"
)
// Protocol identity constants.
const (
	// PREFIX is the first byte of every package. Parsers use it to detect and
	// resynchronize on package boundaries.
	PREFIX uint8 = 0b1001_0101 // 0x95
	// VERSION is the protocol version written into every package; packages
	// carrying any other version are rejected by the parser.
	VERSION uint8 = 3
)
// Bit flags carried in a package's flag byte; several may be OR-ed together.
const (
	// Heartbeat: used to check that the link between peers is alive.

	// FlagHeartbeatRequest sends a heartbeat and asks the peer for a heartbeat response.
	FlagHeartbeatRequest uint8 = 1 << 0
	// FlagHeartbeatResponse sends a heartbeat in reply to a request.
	FlagHeartbeatResponse uint8 = 1 << 1

	// Integrity checks: when set, the header carries a CRC-32 checksum that the
	// receiver verifies.

	// FlagHeaderHashCheck: the header carries a checksum of its own fields.
	FlagHeaderHashCheck uint8 = 1 << 2
	// FlagBodyHashCheck: the header carries a checksum of the body.
	FlagBodyHashCheck uint8 = 1 << 3

	// Delivery confirmation: the receiver of a package carrying
	// FlagPackageConfirmRequest is expected to send a confirmation back to the
	// sender within some time limit.

	// FlagPackageConfirmRequest asks the receiver to confirm receipt.
	FlagPackageConfirmRequest uint8 = 1 << 4
	// FlagPackageConfirmResponse is the confirmation itself.
	FlagPackageConfirmResponse uint8 = 1 << 5
)
// Body encryption methods, stored in a package's encryptor byte.
const (
	EncryptorNone uint8 = 0 // no encryption
	EncryptorXor  uint8 = 1 // XOR encryption
	EncryptorAes  uint8 = 2 // AES encryption
)
// errorPackageIncomplete is returned by parsePackage when the buffer does not
// yet hold a complete package. Presumably callers retry once more data has
// arrived — confirm against the reading loop.
var errorPackageIncomplete = errors.New("package incomplete")
// Data lengths, in bytes.
const (
	// HeaderSize is the size of the fixed header:
	// prefix(1) + version(1) + headerHash(4) + flag(1) + encryptor(1) +
	// value(4) + bodySize(4) + bodyHash(4).
	HeaderSize = 1 + 1 + 4 + 1 + 1 + 4 + 4 + 4
	// BodyMaxSize is the maximum size of a package body.
	BodyMaxSize = 1 << 14
)
// protocolPackage is the in-memory form of one protocol package. Fields are
// listed in wire order: bytesBuffer serializes them in this order and
// parsePackage decodes them from the matching offsets (multi-byte integers
// are big-endian on the wire).
type protocolPackage struct {
prefix uint8 // 1 byte, PREFIX (0x95)
version uint8 // 1 byte, protocol version
headerHash uint32 // 4 bytes, CRC-32 of the checked header fields (BigEndian); 0 when FlagHeaderHashCheck is not set
flag uint8 // 1 byte, bit set of Flag* constants
encryptor uint8 // 1 byte, Encryptor* method applied to body
value uint32 // 4 bytes. In a regular package: the package sequence number, used for confirmation. In a heartbeat: the heartbeat sequence number and heartbeat timeout
bodySize uint32 // 4 bytes, size of body (BigEndian)
bodyHash uint32 // 4 bytes, CRC-32 of body (BigEndian); 0 when FlagBodyHashCheck is not set
body []byte // variable length, bodySize bytes
}
// Object pools used to reduce allocations when building and parsing packages.
var (
	// packagePool holds *protocolPackage values. newPackage and parsePackage
	// overwrite every field after Get, so pooled values may hold stale data.
	packagePool = sync.Pool{New: func() any { return new(protocolPackage) }}
	// bytesBufferPool holds *bytes.Buffer values used for serialization and
	// header hashing.
	bytesBufferPool = sync.Pool{New: func() any { return new(bytes.Buffer) }}
	// parsePackageBufferPool holds []byte slices of length HeaderSize.
	parsePackageBufferPool = sync.Pool{New: func() any { return make([]byte, HeaderSize) }}
)
// newPackage builds an outgoing package. Every new package must be created
// through this function.
//
// Processing order matters:
//  1. the body is encrypted when encryptor is not EncryptorNone;
//  2. the body hash is generated when flag contains FlagBodyHashCheck;
//  3. the header hash is generated when flag contains FlagHeaderHashCheck.
//
// The header hash covers bodyHash, so it has to be computed last.
//
// The returned *protocolPackage comes from packagePool and must be returned
// to it by the caller once it is no longer needed.
func newPackage(flag uint8, encryptor uint8, value uint32, body []byte) *protocolPackage {
	pkg := packagePool.Get().(*protocolPackage)
	// Overwrite every field at once: pooled values may carry stale data.
	// encryptor starts as EncryptorNone; presumably pkg.encrypt records the
	// method actually applied — confirm in the xor/aes helpers.
	*pkg = protocolPackage{
		prefix:    PREFIX,
		version:   VERSION,
		flag:      flag,
		encryptor: EncryptorNone,
		value:     value,
		bodySize:  uint32(len(body)),
		body:      body,
	}

	if encryptor != EncryptorNone {
		pkg.encrypt(encryptor)
	}

	wantBodyHash := flag&FlagBodyHashCheck != 0
	wantHeaderHash := flag&FlagHeaderHashCheck != 0
	if wantBodyHash {
		pkg.generateBodyHash()
	}
	if wantHeaderHash {
		pkg.generateHeaderHash()
	}
	return pkg
}
// bytesBuffer serializes the package into its wire form, in this order:
// prefix, version, header hash, flag, encryptor, value, body size, body hash
// (the uint32 fields are written via writeUint32Bytes), followed by the body.
//
// The returned *bytes.Buffer comes from bytesBufferPool; the caller must
// return it to the pool once finished.
func (p *protocolPackage) bytesBuffer() *bytes.Buffer {
	buf := bytesBufferPool.Get().(*bytes.Buffer)
	// A pooled buffer may have been put back without Reset; start empty so
	// stale bytes can never be prepended to the serialized package.
	buf.Reset()

	buf.WriteByte(p.prefix)
	buf.WriteByte(p.version)
	writeUint32Bytes(buf, p.headerHash)
	buf.WriteByte(p.flag)
	buf.WriteByte(p.encryptor)
	writeUint32Bytes(buf, p.value)
	writeUint32Bytes(buf, p.bodySize)
	writeUint32Bytes(buf, p.bodyHash)
	buf.Write(p.body)
	return buf
}
// headerNeedCheck writes the header fields covered by the header hash into a
// pooled buffer: flag, encryptor, value, bodySize and bodyHash. These are the
// same bytes parsePackage hashes at wire offsets [6:HeaderSize].
//
// The returned *bytes.Buffer comes from bytesBufferPool; the caller must
// return it to the pool once finished.
func (p *protocolPackage) headerNeedCheck() *bytes.Buffer {
	buf := bytesBufferPool.Get().(*bytes.Buffer)
	// Start empty even if a pooled buffer was put back without Reset: any
	// leftover bytes would silently change the computed header hash.
	buf.Reset()

	buf.WriteByte(p.flag)
	buf.WriteByte(p.encryptor)
	writeUint32Bytes(buf, p.value)
	writeUint32Bytes(buf, p.bodySize)
	writeUint32Bytes(buf, p.bodyHash)
	return buf
}
// generateHeaderHash stores in p.headerHash the CRC-32 (IEEE) checksum of the
// header fields returned by headerNeedCheck. bodyHash is one of those fields,
// so it must already be final when this runs.
func (p *protocolPackage) generateHeaderHash() {
	fields := p.headerNeedCheck()
	p.headerHash = crc32.ChecksumIEEE(fields.Bytes())
	// Hand the scratch buffer back to the pool in a clean state.
	fields.Reset()
	bytesBufferPool.Put(fields)
}
// checkHeader reports whether p.headerHash equals the CRC-32 (IEEE) checksum
// of the header fields returned by headerNeedCheck.
//
// NOTE(review): unlike checkBody, this does not look at FlagHeaderHashCheck.
// A package built without that flag has headerHash == 0 and will normally
// fail this check — confirm callers only invoke it when the flag is set.
func (p *protocolPackage) checkHeader() bool {
	fields := p.headerNeedCheck()
	matches := crc32.ChecksumIEEE(fields.Bytes()) == p.headerHash
	fields.Reset()
	bytesBufferPool.Put(fields)
	return matches
}
// generateBodyHash sets p.bodyHash to the CRC-32 (IEEE polynomial) checksum
// of the current body. An empty or nil body hashes to 0.
func (p *protocolPackage) generateBodyHash() {
	sum := crc32.ChecksumIEEE(p.body)
	p.bodyHash = sum
}
// checkBody reports whether the body is consistent with the header.
//
// When FlagBodyHashCheck is not set the package carries no body hash, so
// there is nothing to verify and the body is accepted. Otherwise the declared
// bodySize must equal the actual body length, and bodyHash must equal the
// CRC-32 (IEEE) checksum of the body.
func (p *protocolPackage) checkBody() bool {
	// Gate on the body-hash flag. Checking the header-hash flag here would
	// both reject valid header-only packages (bodyHash is 0) and skip
	// verification of packages that do carry a body hash.
	if p.flag&FlagBodyHashCheck == 0 {
		return true
	}
	if int(p.bodySize) != len(p.body) {
		return false
	}
	return p.bodyHash == crc32.ChecksumIEEE(p.body)
}
// encrypt dispatches to the helper for the requested encryption method.
//
// EncryptorNone is a no-op, matching decrypt's handling of unencrypted
// packages. Any other unknown method only logs a warning and leaves the
// package unchanged.
func (p *protocolPackage) encrypt(encryptor uint8) {
	switch encryptor {
	case EncryptorNone:
		// Nothing to encrypt; not an error.
	case EncryptorXor:
		p.xorEncrypt()
	case EncryptorAes:
		p.aesEncrypt()
	default:
		glog.Warning("[protocol_package] unknown encrypt method")
	}
}
// decrypt reverses the encryption recorded in p.encryptor by dispatching to
// the matching helper. Packages marked EncryptorNone are left untouched;
// unknown methods only log a warning.
func (p *protocolPackage) decrypt() {
	if p.encryptor == EncryptorNone {
		return
	}
	switch p.encryptor {
	case EncryptorAes:
		p.aesDecrypt()
	case EncryptorXor:
		p.xorDecrypt()
	default:
		glog.Warning("[protocol_package] unknown encrypt method")
	}
}
// parsePackage extracts one package from the front of buf.
//
// Wire layout (multi-byte integers are big-endian):
//
//	[0]       prefix (PREFIX)
//	[1]       version (VERSION)
//	[2:6]     header hash: CRC-32 of bytes [6:HeaderSize]
//	[6]       flag
//	[7]       encryptor
//	[8:12]    value
//	[12:16]   body size
//	[16:20]   body hash: CRC-32 of the body
//	[20:...]  body
//
// Outcomes:
//   - buf does not yet hold a complete package: returns errorPackageIncomplete
//     and leaves buf untouched so more data can be appended.
//   - prefix, version, or an enabled hash check does not match: drops the first
//     byte of buf plus everything up to the next PREFIX byte (nextPackageHead)
//     and returns a descriptive error. Dropped data is not put back into buf.
//   - success: consumes the package from buf and returns it. The returned
//     *protocolPackage comes from packagePool and should be returned to it by
//     the caller when no longer needed.
func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
	// be32 decodes a big-endian uint32 from the first four bytes of b.
	be32 := func(b []byte) uint32 {
		return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
	}

	data := buf.Bytes()
	if len(data) == 0 {
		return nil, errorPackageIncomplete
	}
	// Unknown leading data: drop it and resynchronize on the next PREFIX byte.
	if prefix := data[0]; prefix != PREFIX {
		nextPackageHead(buf)
		return nil, fmt.Errorf("prefix does not match, expected %v got %v", PREFIX, prefix)
	}
	if len(data) < 2 {
		return nil, errorPackageIncomplete
	}
	if version := data[1]; version != VERSION {
		nextPackageHead(buf)
		return nil, fmt.Errorf("unsupported version expected %d got %d", VERSION, version)
	}
	// Every header field must be present before any of them is decoded.
	if len(data) < HeaderSize {
		return nil, errorPackageIncomplete
	}
	header := data[:HeaderSize]
	flag := header[6]

	// Verify the header before trusting bodySize: otherwise a corrupted size
	// field would make the parser wait indefinitely for a body that never comes.
	var headerHash uint32
	if flag&FlagHeaderHashCheck != 0 {
		headerHash = be32(header[2:6])
		if got := crc32.ChecksumIEEE(header[6:]); got != headerHash {
			nextPackageHead(buf)
			return nil, fmt.Errorf("wrong header hash, expected %d got %d", headerHash, got)
		}
	}

	bodySize := be32(header[12:16])
	// Compare in uint64 so a huge bodySize cannot wrap the sum around.
	if uint64(len(data)) < HeaderSize+uint64(bodySize) {
		return nil, errorPackageIncomplete
	}
	// Only this package's body: buf may already hold following packages.
	body := data[HeaderSize : HeaderSize+int(bodySize)]

	var bodyHash uint32
	if flag&FlagBodyHashCheck != 0 {
		bodyHash = be32(header[16:20])
		if got := crc32.ChecksumIEEE(body); got != bodyHash {
			nextPackageHead(buf)
			return nil, fmt.Errorf("wrong body hash, expected %d got %d", bodyHash, got)
		}
	}

	pkg := packagePool.Get().(*protocolPackage)
	pkg.prefix = header[0]
	pkg.version = header[1]
	pkg.headerHash = headerHash
	pkg.flag = flag
	pkg.encryptor = header[7]
	pkg.value = be32(header[8:12])
	pkg.bodySize = bodySize
	pkg.bodyHash = bodyHash
	// Copy the body out: data aliases buf's storage, which later writes reuse.
	pkg.body = make([]byte, bodySize)
	copy(pkg.body, body)

	// Consume the package without allocating scratch space for the header.
	buf.Next(HeaderSize + int(bodySize))
	return pkg, nil
}
// nextPackageHead discards the first byte of buf, then skips forward to the
// next PREFIX byte so that buf starts at a candidate package header. If no
// further PREFIX byte exists, buf is emptied.
func nextPackageHead(buf *bytes.Buffer) {
	data := buf.Bytes()
	if len(data) == 0 {
		buf.Reset()
		return
	}
	// Search only past the first byte: that byte is always discarded, even
	// when it is itself PREFIX (e.g. after a version or hash mismatch).
	idx := bytes.IndexByte(data[1:], PREFIX)
	if idx < 0 {
		buf.Reset()
		return
	}
	// Advance in place; ReadBytes would allocate a copy of the skipped bytes.
	buf.Next(1 + idx)
}