v2
Test / testing (1.19.13, ubuntu-latest) (push) Failing after 4m30s Details
Test / testing (>=1.20, ubuntu-latest) (push) Failing after 1m19s Details

This commit is contained in:
Akvicor 2024-03-15 19:14:58 +08:00
parent 1c9129c1e3
commit f8fe8f5288
11 changed files with 506 additions and 295 deletions

View File

@ -7,6 +7,16 @@
[![License: GPL v2](https://img.shields.io/badge/License-GPL%20v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/git.viry.cc/gomod/protocol?tab=doc)
# TODO
- 心跳服务发送单独的Package, 如果value不为0, 则修改心跳超时时间
- 完全关闭心跳, 即双方都不发送心跳信号
- 单/双方间隔发送HeartbeatResponse
- 单/双方间隔发送HeartbeatRequest, 接收方收到Request需要响应Response
- 心跳信号不管是Request还是Response, 都可以携带Confirm信息
- Confirm信息的键由Heartbeat携带
- 更新glog包使得glog能够接受外部File的Paper, 保证使用glog的mod和使用mod的项目公用一个log文件
```go
import "git.viry.cc/gomod/protocol"
```

60
encryptor_aes.go Normal file
View File

@ -0,0 +1,60 @@
package protocol
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
)
var aesEncryptorEnabled = false
var aesCipherBlock cipher.Block = nil
var aesCipherBlockModeEncrypter cipher.BlockMode = nil
var aesCipherBlockModeDecrypter cipher.BlockMode = nil
var aesCipherBlockSize int
var aesCipherIV []byte
func SetEncryptorAesKey(key []byte) (err error) {
// 分组秘钥
aesCipherBlock, err = aes.NewCipher(key)
if err != nil {
return fmt.Errorf("key 长度必须 16/24/32长度: %s", err.Error())
}
aesCipherBlockSize = aesCipherBlock.BlockSize()
aesCipherIV = key[:aesCipherBlockSize]
aesCipherBlockModeEncrypter = cipher.NewCBCEncrypter(aesCipherBlock, aesCipherIV)
aesCipherBlockModeDecrypter = cipher.NewCBCDecrypter(aesCipherBlock, aesCipherIV)
aesEncryptorEnabled = true
return nil
}
func (p *protocolPackage) aesEncrypt() {
if !aesEncryptorEnabled {
return
}
// PKCS7Padding 补码
padding := aesCipherBlockSize - int(p.bodySize)%aesCipherBlockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
p.bodySize += uint32(padding)
p.body = append(p.body, padtext...)
// 创建数组
cryted := make([]byte, p.bodySize)
// 加密
aesCipherBlockModeEncrypter.CryptBlocks(cryted, p.body)
p.body = cryted
p.encryptor = EncryptorAes
}
func (p *protocolPackage) aesDecrypt() {
if !aesEncryptorEnabled {
return
}
// 创建数组
orig := make([]byte, p.bodySize)
// 解密
aesCipherBlockModeDecrypter.CryptBlocks(orig, p.body)
// PKCS7UnPadding 去码
p.bodySize = p.bodySize - uint32(orig[p.bodySize-1])
p.body = orig[:p.bodySize]
p.encryptor = EncryptorNone
}

54
encryptor_aes_test.go Normal file
View File

@ -0,0 +1,54 @@
package protocol
import (
"bytes"
"fmt"
"git.viry.cc/gomod/util"
"testing"
)
func TestEncryptorAes(t *testing.T) {
var ori, data, key []byte
var err error
for i := 1; i < 1024; i++ {
ori = util.RandomKey(i).Bytes()
data = make([]byte, len(ori))
copy(data, ori)
for j := 0; j < 3; j++ {
key = util.RandomKey(16 + j*8).Bytes()
// check
err = SetEncryptorAesKey(key)
if err != nil {
t.Errorf("set aes key failed, %v", err)
}
pkg := newPackage(0, 0, 0, data)
// fmt.Printf("body %d", pkg.bodySize)
pkg.aesEncrypt()
// fmt.Printf(" -> %d", pkg.bodySize)
pkg.aesDecrypt()
// fmt.Printf(" -> %d\n", pkg.bodySize)
if uint32(len(pkg.body)) != pkg.bodySize || pkg.bodySize != uint32(len(ori)) {
t.Errorf("expected [%d] got [%d, %d]\n", uint32(len(ori)), pkg.bodySize, uint32(len(pkg.body)))
}
if !bytes.Equal(pkg.body, ori) {
t.Errorf("expected [%0#2v] got [%0#2v]\n", ori, data)
}
}
}
}
func BenchmarkEncryptorAesEncode(b *testing.B) {
data := util.RandomKey(256 * 1024 * 1024).Bytes()
key := util.RandomKey(16).Bytes()
pkg := newPackage(0, 0, 0, data)
err := SetEncryptorAesKey(key)
if err != nil {
panic(fmt.Sprintf("set aes key failed, %v", err))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
pkg.aesEncrypt()
pkg.aesDecrypt()
}
}

31
encryptor_xor.go Normal file
View File

@ -0,0 +1,31 @@
package protocol
var xorEncryptorEnabled = false
var xorKey []byte
var xorKeyLen int
func SetEncryptorXorKey(key []byte) {
xorKey = key
xorKeyLen = len(key)
xorEncryptorEnabled = true
}
func (p *protocolPackage) xorEncrypt() {
if !xorEncryptorEnabled {
return
}
for i := range p.body {
p.body[i] ^= xorKey[i%xorKeyLen]
}
p.encryptor = EncryptorXor
}
func (p *protocolPackage) xorDecrypt() {
if !xorEncryptorEnabled {
return
}
for i := range p.body {
p.body[i] ^= xorKey[i%xorKeyLen]
}
p.encryptor = EncryptorNone
}

52
encryptor_xor_test.go Normal file
View File

@ -0,0 +1,52 @@
package protocol
import (
"bytes"
"git.viry.cc/gomod/util"
"testing"
)
func TestEncryptorXor(t *testing.T) {
var ori, data, key, ans []byte
for i := 1; i < 1024; i++ {
ori = util.RandomKey(i).Bytes()
data = make([]byte, len(ori))
copy(data, ori)
for j := 1; j < 128; j++ {
key = util.RandomKey(j).Bytes()
// build ans
ans = make([]byte, len(ori))
copy(ans, ori)
for k := range ans {
ans[k] ^= key[k%j]
}
// check
SetEncryptorXorKey(key)
pkg := newPackage(0, 0, 0, data)
pkg.xorEncrypt()
if !bytes.Equal(pkg.body, ans) {
t.Errorf("expected [%0#2v] got [%0#2v]\n", ans, data)
}
pkg.xorDecrypt()
if uint32(len(pkg.body)) != pkg.bodySize || pkg.bodySize != uint32(len(ori)) {
t.Errorf("expected [%d] got [%d, %d]\n", uint32(len(ori)), pkg.bodySize, uint32(len(pkg.body)))
}
if !bytes.Equal(pkg.body, ori) {
t.Errorf("expected [%0#2v] got [%0#2v]\n", ori, data)
}
}
}
}
func BenchmarkEncryptorXorEncode(b *testing.B) {
data := util.RandomKey(256 * 1024 * 1024).Bytes()
key := util.RandomKey(16).Bytes()
pkg := newPackage(0, 0, 0, data)
SetEncryptorXorKey(key)
b.ResetTimer()
for i := 0; i < b.N; i++ {
pkg.xorEncrypt()
pkg.xorDecrypt()
}
}

55
log.go
View File

@ -5,50 +5,28 @@ import (
)
const (
MaskUNKNOWN = glog.MaskUNKNOWN
MaskDEBUG = glog.MaskDEBUG
MaskTRACE = glog.MaskTRACE
MaskINFO = glog.MaskINFO
MaskWARNING = glog.MaskWARNING
MaskERROR = glog.MaskERROR
MaskFATAL = glog.MaskFATAL
MaskStd = glog.MaskStd
MaskAll = glog.MaskAll
MaskDev = MaskFATAL | MaskERROR | MaskWARNING | MaskINFO | MaskTRACE | MaskDEBUG | MaskUNKNOWN
MaskProd = MaskFATAL | MaskERROR | MaskWARNING
LogMaskDEBUG uint32 = glog.MaskUNKNOWN | glog.MaskDEBUG | glog.MaskTRACE | glog.MaskINFO | glog.MaskWARNING | glog.MaskERROR | glog.MaskFATAL
LogFlagDEBUG uint32 = glog.FlagDate | glog.FlagTime | glog.FlagShortFile | glog.FlagPrefix | glog.FlagSuffix
)
const (
FlagDate = glog.FlagDate
FlagTime = glog.FlagTime
FlagLongFile = glog.FlagLongFile
FlagShortFile = glog.FlagShortFile
FlagFunc = glog.FlagFunc
FlagPrefix = glog.FlagPrefix
FlagSuffix = glog.FlagSuffix
LogMaskINFO uint32 = glog.MaskINFO | glog.MaskWARNING | glog.MaskERROR | glog.MaskFATAL
LogFlagINFO uint32 = glog.FlagDate | glog.FlagTime | glog.FlagPrefix
)
FlagStd = glog.FlagStd
FlagAll = glog.FlagAll
const (
LogMaskWARNING uint32 = glog.MaskWARNING | glog.MaskERROR | glog.MaskFATAL
LogFlagWARNING uint32 = glog.FlagDate | glog.FlagTime | glog.FlagPrefix
)
FlagDev = FlagDate | FlagTime | FlagShortFile | FlagFunc | FlagPrefix | FlagSuffix
FlagProd = FlagDate | FlagTime | FlagPrefix
const (
LogMaskNONE uint32 = 0
LogFlagNONE uint32 = glog.FlagDate | glog.FlagTime | glog.FlagPrefix
)
func init() {
glog.SetMask(MaskStd)
glog.SetFlag(FlagStd)
}
func SetLogProd(isProd bool) {
if isProd {
glog.SetMask(MaskProd)
glog.SetFlag(FlagProd)
} else {
glog.SetMask(MaskStd)
glog.SetFlag(FlagStd)
}
glog.SetMask(LogMaskNONE)
glog.SetFlag(LogFlagNONE)
}
func SetLogMask(mask uint32) {
@ -58,8 +36,3 @@ func SetLogMask(mask uint32) {
func SetLogFlag(f uint32) {
glog.SetFlag(f)
}
func init() {
glog.SetMask(MaskProd)
glog.SetFlag(FlagProd)
}

View File

@ -4,268 +4,278 @@ import (
"bytes"
"errors"
"fmt"
"git.viry.cc/gomod/glog"
"git.viry.cc/gomod/util"
"hash/crc32"
"sync"
)
// package的起始标志
var prefix = [headLengthPrefix]uint8{0xff, 0x07, 0x55, 0x00}
// PREFIX package的起始标志
const PREFIX uint8 = 0x95 // 1001 0101
var ErrorPackageIncomplete = errors.New("package incomplete")
// head中各部分的长度
const (
headLengthPrefix = 4
headLengthVersion = 1
headLengthCRC32Checksum = 4
headLengthFlag = 1
headLengthEncryptMethod = 1
headLengthCustomValue = 1
headLengthDataSize = 4
headLengthDataCrc32 = 4
)
// head中各部分的偏移
const (
headOffsetPrefix = 0
headOffsetVersion = headOffsetPrefix + headLengthPrefix
headOffsetCRC32Checksum = headOffsetVersion + headLengthVersion
headOffsetFlag = headOffsetCRC32Checksum + headLengthCRC32Checksum
headOffsetEncryptMethod = headOffsetFlag + headLengthFlag
headOffsetCustomValue = headOffsetEncryptMethod + headLengthEncryptMethod
headOffsetDataSize = headOffsetCustomValue + headLengthCustomValue
headOffsetDataCrc32 = headOffsetDataSize + headLengthDataSize
headOffsetData = headOffsetDataCrc32 + headLengthDataCrc32
)
// 计算head的crc32时起始偏移
const headOffsetNeedCheck = headOffsetCRC32Checksum + headLengthCRC32Checksum
// data的加密方式
const (
encryptNone uint8 = iota
)
// VERSION package的版本
const VERSION uint8 = 3
// flag标志位
const (
// 普通心跳信号,心跳响应信号
flagHeartbeat uint8 = 1 << iota
// 心跳请求信号接收方必须回复flagHeartbeat
flagHeartbeatRequest
/*
Heartbeat机制, 用来检查链路是否正常
*/
FlagHeartbeatRequest uint8 = 1 << iota // 心跳请求信号, 发送心跳信号并请求心跳响应
FlagHeartbeatResponse // 心跳响应信号, 发送心跳信号
/*
Package校验机制, 如果携带此标志,
1. Header中携带了Header的哈希值, 用来确认Header完整性
2. Header中携带了Body的哈希值, 用来确认Body完整性
*/
FlagHeaderHashCheck // Header拥有校验值
FlagBodyHashCheck // Body拥有校验值
/*
Package确认机制, 接收方收到携带此标志的包裹后,需要在一定时间内向发送方发送确认信号
*/
FlagPackageConfirmRequest // Package请求确认信号, 接收后需要响应确认信号
FlagPackageConfirmResponse // Package确认信号,发送确认信号
)
// package的head的大小 (byte)
const packageHeadSize = headOffsetData
// data的加密方式
const (
EncryptorNone uint8 = iota // 不加密
EncryptorXor // 异或加密
EncryptorAes // AES加密
)
// package的最大size (byte)
const packageMaxSize = 4096
// 错误
var (
errorPackageIncomplete = errors.New("package incomplete")
)
// package的data的最大size (byte)
const dataMaxSize = packageMaxSize - packageHeadSize
// 数据长度
const (
HeaderSize = 20 // header的大小
BodyMaxSize = 16384 // body的最大大小
)
// 协议包结构
type protocolPackage struct {
prefix [headLengthPrefix]byte // 4 byte 0xff 0x55
version uint8 // 1 byte protocol version
crc32 uint32 // 4 byte head crc32 checksum (BigEndian)
flag uint8 // 1 byte flag
encryptMethod uint8 // 1 byte encrypted method
value uint8 // 1 byte custom value (for heartbeat)
dataSize uint32 // 4 byte curSize of data (BigEndian)
dataCrc32 uint32 // 4 byte crc32 of data (BigEndian)
data []byte // ? byte data
prefix uint8 // 1 byte 0x95
version uint8 // 1 byte protocol version
headerHash uint32 // 4 byte head crc32 checksum (BigEndian)
flag uint8 // 1 byte flag
encryptor uint8 // 1 byte encrypt method
value uint32 // 4 byte 在普通Package中保存Package的序号, 用以Confirm. 在Heartbeat中存储心跳序号和心跳超时时间
bodySize uint32 // 4 byte size of body (BigEndian)
bodyHash uint32 // 4 byte headerHash of body (BigEndian)
body []byte // ? byte body
}
// 缓存池
var (
// *protocolPackage
packagePool = sync.Pool{New: func() any { return &protocolPackage{} }}
// *bytes.Buffer
bytesBufferPool = sync.Pool{New: func() any { return &bytes.Buffer{} }}
// make([]byte, HeaderSize)
parsePackageBufferPool = sync.Pool{New: func() any { return make([]byte, HeaderSize) }}
)
// 创建新package, 所有新package的创建都必须通过此方法
func newPackage(flag uint8, encrypt uint8, value uint8, data []byte) *protocolPackage {
pkg := &protocolPackage{
prefix: prefix,
version: VERSION,
crc32: 0,
flag: flag,
encryptMethod: encryptNone,
value: value,
dataSize: uint32(len(data)),
dataCrc32: 0,
data: data,
//
// 返回的 *protocolPackage 使用结束后需要手动回收到 packagePool
func newPackage(flag uint8, encryptor uint8, value uint32, body []byte) *protocolPackage {
pkg := packagePool.Get().(*protocolPackage)
pkg.prefix = PREFIX
pkg.version = VERSION
pkg.headerHash = 0
pkg.flag = flag
pkg.encryptor = EncryptorNone
pkg.value = value
pkg.bodySize = uint32(len(body))
pkg.bodyHash = 0
pkg.body = body
if encryptor != EncryptorNone {
pkg.encrypt(encryptor)
}
if (flag & FlagBodyHashCheck) != 0 {
pkg.generateBodyHash()
}
if (flag & FlagHeaderHashCheck) != 0 {
pkg.generateHeaderHash()
}
pkg.generateDataCheck()
pkg.generateHeadCheck()
pkg.encrypt(encrypt)
return pkg
}
func (p *protocolPackage) Bytes() *bytes.Buffer {
buf := &bytes.Buffer{}
// 返回的 *bytes.Buffer 使用结束后需要手动回收到 bytesBufferPool
func (p *protocolPackage) bytesBuffer() *bytes.Buffer {
buf := bytesBufferPool.Get().(*bytes.Buffer)
// prefix
buf.Write(p.prefix[:])
buf.WriteByte(p.prefix)
// version
buf.WriteByte(p.version)
// crc32
buf.Write(util.UInt32ToBytesSlice(p.crc32))
// head hash
writeUint32Bytes(buf, p.headerHash)
// flag
buf.WriteByte(p.flag)
// encrypt method
buf.WriteByte(p.encryptMethod)
buf.WriteByte(p.encryptor)
// value
buf.WriteByte(p.value)
// data curSize
buf.Write(util.UInt32ToBytesSlice(p.dataSize))
// data crc32
buf.Write(util.UInt32ToBytesSlice(p.dataCrc32))
// data
buf.Write(p.data)
writeUint32Bytes(buf, p.value)
// body curSize
writeUint32Bytes(buf, p.bodySize)
// body hash
writeUint32Bytes(buf, p.bodyHash)
// body
buf.Write(p.body)
return buf
}
// 将head中需要校验的数据拼接起来
func (p *protocolPackage) headNeedCheckBytes() *bytes.Buffer {
buf := &bytes.Buffer{}
//
// 返回的 *bytes.Buffer 使用结束后需要手动回收到 bytesBufferPool
func (p *protocolPackage) headerNeedCheck() *bytes.Buffer {
buf := bytesBufferPool.Get().(*bytes.Buffer)
buf.WriteByte(p.flag)
buf.WriteByte(p.encryptMethod)
buf.WriteByte(p.value)
buf.Write(util.UInt32ToBytesSlice(p.dataSize))
buf.Write(util.UInt32ToBytesSlice(p.dataCrc32))
buf.WriteByte(p.encryptor)
writeUint32Bytes(buf, p.value)
writeUint32Bytes(buf, p.bodySize)
writeUint32Bytes(buf, p.bodyHash)
return buf
}
// 生成head的crc32
func (p *protocolPackage) generateHeadCheck() {
p.crc32 = util.NewCRC32().FromBytes(p.headNeedCheckBytes().Bytes()).Value()
glog.Trace("[protocol_package] head crc32 is %d", p.crc32)
// 生成header hash
func (p *protocolPackage) generateHeaderHash() {
buf := p.headerNeedCheck()
defer func() {
buf.Reset()
bytesBufferPool.Put(buf)
}()
p.headerHash = crc32.ChecksumIEEE(buf.Bytes())
}
// 校验head的crc32
func (p *protocolPackage) checkHead() bool {
return p.crc32 == util.NewCRC32().FromBytes(p.headNeedCheckBytes().Bytes()).Value()
// 校验header
func (p *protocolPackage) checkHeader() bool {
buf := p.headerNeedCheck()
defer func() {
buf.Reset()
bytesBufferPool.Put(buf)
}()
return p.headerHash == crc32.ChecksumIEEE(buf.Bytes())
}
// 生成data的crc32
func (p *protocolPackage) generateDataCheck() {
p.dataCrc32 = util.NewCRC32().FromBytes(p.data).Value()
glog.Trace("[protocol_package] data crc32 is %d", p.dataCrc32)
func (p *protocolPackage) generateBodyHash() {
p.bodyHash = crc32.ChecksumIEEE(p.body)
}
// 校验data的crc32
func (p *protocolPackage) checkData() bool {
if int(p.dataSize) != len(p.data) {
glog.Trace("[protocol_package] pkg.dataSize != len(pkg.data)")
// 校验body
func (p *protocolPackage) checkBody() bool {
if (p.flag & FlagHeaderHashCheck) == 0 {
return true
}
if int(p.bodySize) != len(p.body) {
return false
}
return p.dataCrc32 == util.NewCRC32().FromBytes(p.data).Value()
return p.bodyHash == crc32.ChecksumIEEE(p.body)
}
// encrypt 加密data
func (p *protocolPackage) encrypt(method uint8) {
if p.encryptMethod == method {
glog.Trace("[protocol_package] is already encrypted [%d]", method)
return // 已经加密
func (p *protocolPackage) encrypt(encryptor uint8) {
switch encryptor {
case EncryptorXor:
p.xorEncrypt()
case EncryptorAes:
p.aesEncrypt()
default:
glog.Warning("[protocol_package] unknown encrypt method")
}
if p.encryptMethod != encryptNone {
glog.Trace("[protocol_package] encrypt with other method got [%d] need encryptNone[%d]", p.encryptMethod, encryptNone)
return // 已经通过其他方式加密
}
glog.Warning("[protocol_package] unknown encrypt method")
}
// decrypt 解密data
func (p *protocolPackage) decrypt() {
if !p.isEncrypted() {
glog.Trace("[protocol_package] is not encrypted")
switch p.encryptor {
case EncryptorNone:
return
case EncryptorXor:
p.xorDecrypt()
case EncryptorAes:
p.aesDecrypt()
default:
glog.Warning("[protocol_package] unknown encrypt method")
}
}
// isEncrypted 是否已经加密
func (p *protocolPackage) isEncrypted() bool {
return p.encryptMethod != encryptNone
}
// parsePackage 从buf中读取一个package
//
// 如果协议头标志(prefix)不匹配删除buf中除第一个字符外下一个prefix1到buf开头的所有数据
// 如果协议头标志(prefix)匹配不断从buf中取出数据填充到package结构体
// 如果buf中的数据出错无法正确提取package, 则返回(nil,true), 且已从buf中提取的数据不会退回buf
func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
bufLen := uint32(buf.Len())
bufData := buf.Bytes()
// 协议头标志不匹配删除未知数据寻找下一个package起始位置
if !bytes.Equal(prefix[:], buf.Bytes()[headOffsetPrefix:headOffsetPrefix+headLengthPrefix]) {
glog.Trace("[protocol_package] prefix does not match, need %v got %v", prefix, buf.Bytes()[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
if bufData[0] != PREFIX {
nextPackageHead(buf)
return nil, fmt.Errorf("prefix does not match, need %v got %v", prefix, buf.Bytes()[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
return nil, fmt.Errorf("prefix does not match, expected %v got %v", PREFIX, bufData[0])
}
// 判断package的版本
// 暂时只处理VERSION版本的package
if buf.Len() < headOffsetVersion+headLengthVersion {
glog.Trace("[protocol_package] incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
return nil, ErrorPackageIncomplete
}
if buf.Bytes()[headOffsetVersion] != VERSION {
glog.Trace("[protocol_package] unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
if bufData[1] != VERSION {
nextPackageHead(buf)
return nil, fmt.Errorf("unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
return nil, fmt.Errorf("unsupported version expected %d got %d", VERSION, bufData[1])
}
// 开始判断是否为package并提取package
if buf.Len() < packageHeadSize {
glog.Trace("[protocol_package] incomplete head, need %d got %d", packageHeadSize, buf.Len())
return nil, ErrorPackageIncomplete
// 获取body长度
bodySize := uint32(bufData[HeaderSize-5]) | uint32(bufData[HeaderSize-6])<<8 | uint32(bufData[HeaderSize-7])<<16 | uint32(bufData[HeaderSize-8])<<24
// 判断Package是否接收完整
if bufLen < HeaderSize+bodySize {
return nil, errorPackageIncomplete
}
head := make([]byte, packageHeadSize)
copy(head, buf.Bytes()[:packageHeadSize])
// 检查head是否完整删除未知数据寻找下一个package起始位置
headChecksum := util.BytesSliceToUInt32(head[headOffsetCRC32Checksum : headOffsetCRC32Checksum+headLengthCRC32Checksum])
headCrc32 := util.NewCRC32().FromBytes(head[headOffsetNeedCheck:]).Value()
if headChecksum != headCrc32 {
glog.Trace("[protocol_package] head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
nextPackageHead(buf)
return nil, fmt.Errorf("head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
// 校验Header
var headerHashExpected uint32
if (bufData[6] & FlagHeaderHashCheck) != 0 {
var headerHashGot uint32
headerHashGot = crc32.ChecksumIEEE(bufData[6:HeaderSize])
headerHashExpected = uint32(bufData[5]) | uint32(bufData[4])<<8 | uint32(bufData[3])<<16 | uint32(bufData[2])<<24
if headerHashExpected != headerHashGot {
nextPackageHead(buf)
return nil, fmt.Errorf("wrong header hash, expected %d got %d", headerHashExpected, headerHashGot)
}
}
// 检查package是否完整不完整则等待
packageDataSize := util.BytesSliceToUInt32(head[headOffsetDataSize : headOffsetDataSize+headLengthDataSize])
if packageHeadSize+int(packageDataSize) > buf.Len() {
glog.Trace("[protocol_package] incomplete data, need %d got %d", packageHeadSize+packageDataSize, buf.Len())
return nil, ErrorPackageIncomplete
}
// package完整
pkg := &protocolPackage{}
_, _ = buf.Read(make([]byte, packageHeadSize))
// prefix
copy(pkg.prefix[:], head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
// crc32
pkg.version = head[headOffsetVersion]
// crc32
pkg.crc32 = headCrc32
// flag
pkg.flag = head[headOffsetFlag]
// encrypt method
pkg.encryptMethod = head[headOffsetEncryptMethod]
// value
pkg.value = head[headOffsetCustomValue]
// data curSize
pkg.dataSize = packageDataSize
// data crc32
pkg.dataCrc32 = util.BytesSliceToUInt32(head[headOffsetDataCrc32 : headOffsetDataCrc32+headLengthDataCrc32])
// data
pkg.data = make([]byte, pkg.dataSize)
_, _ = buf.Read(pkg.data)
dataCrc32 := util.NewCRC32().FromBytes(pkg.data).Value()
if pkg.dataCrc32 != dataCrc32 {
glog.Trace("[protocol_package] data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
nextPackageHead(buf)
return nil, fmt.Errorf("data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
// 校验Body
var bodyHashExpected uint32
if (bufData[6] & FlagBodyHashCheck) != 0 {
var bodyHashGot uint32
bodyHashGot = crc32.ChecksumIEEE(bufData[HeaderSize:])
bodyHashExpected = uint32(bufData[HeaderSize-1]) | uint32(bufData[HeaderSize-2])<<8 | uint32(bufData[HeaderSize-3])<<16 | uint32(bufData[HeaderSize-4])<<24
if bodyHashExpected != bodyHashGot {
nextPackageHead(buf)
return nil, fmt.Errorf("wrong header hash, expected %d got %d", bodyHashExpected, bodyHashGot)
}
}
pkg := packagePool.Get().(*protocolPackage)
pkg.prefix = bufData[0]
pkg.version = bufData[1]
pkg.headerHash = headerHashExpected
pkg.flag = bufData[6]
pkg.encryptor = bufData[7]
pkg.value = uint32(bufData[11]) | uint32(bufData[10])<<8 | uint32(bufData[9])<<16 | uint32(bufData[8])<<24
pkg.bodySize = bodySize
pkg.bodyHash = bodyHashExpected
pkg.body = make([]byte, bodySize)
hd := parsePackageBufferPool.Get().([]byte)
buf.Read(hd)
buf.Read(pkg.body)
return pkg, nil
}
// 删除掉buf中第一个byte, 并将buf中的起始位置调整到与prefix[0]相同的下一个元素的位置
// 删除掉buf中第一个byte, 并将buf中的起始位置调整到与PREFIX相同的下一个元素的位置
func nextPackageHead(buf *bytes.Buffer) {
var err error
_, _ = buf.ReadByte()
_, err = buf.ReadBytes(prefix[0]) // 只搜索与prefix[0]相同元素防止prefix[0]出现在buf末尾
if err == nil { // 找到下一个协议头标志把删掉的prefix回退到buffer
_, err = buf.ReadBytes(PREFIX) // 只搜索与PREFIX相同元素防止PREFIX出现在buf末尾
if err == nil { // 找到下一个协议头标志把删掉的prefix回退到buffer
_ = buf.UnreadByte()
glog.Trace("[protocol_package] prefix does not match, prefix[0] found, trim buf, buf length [%d]", buf.Len())
} else { // 找不到下一个协议头标志清空buffer
buf.Reset()
glog.Trace("[protocol_package] prefix does not match, prefix[0] not found, reset buf, buf length [%d]", buf.Len())
}
}

View File

@ -3,15 +3,12 @@ package protocol
import (
"bytes"
"errors"
"git.viry.cc/gomod/glog"
"io"
"sync/atomic"
"time"
"git.viry.cc/gomod/glog"
)
const VERSION uint8 = 1
var ErrorReadCallbackIsNil = errors.New("read callback is nil")
var ErrorReaderIsNil = errors.New("reader is nil")
var ErrorWriterIsNil = errors.New("writer is nil")
@ -20,7 +17,7 @@ var ErrorWriterIsKilled = errors.New("writer is killed")
var ErrorWriterQueueIsNil = errors.New("writer queue is nil")
var ErrorHeartbeatIsKilled = errors.New("heartbeat is killed")
var ErrorHeartbeatCallbackIsNil = errors.New("heartbeat callback is nil")
var ErrorDataSizeExceedsLimit = errors.New("data size exceeds limit")
var ErrorDataSizeExceedsLimit = errors.New("body size exceeds limit")
var ErrorTimeout = errors.New("timeout")
const (
@ -44,9 +41,9 @@ type Protocol struct {
runningRoutines int32
// 心跳信号,同时也是心跳响应信号
heartbeatSig chan uint8
heartbeatSig chan uint32
// 心跳请求信号,收到此信号必须回复对方
heartbeatSigReq chan uint8
heartbeatSigReq chan uint32
// 发送心跳请求的间隔
heartbeatInterval uint32
// 接收心跳请求的超时时间
@ -114,8 +111,8 @@ func New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback
readCallback: readCallback,
writeQueue: newQueue(writeQueueSize),
runningRoutines: 0,
heartbeatSig: make(chan uint8, 1),
heartbeatSigReq: make(chan uint8, 1),
heartbeatSig: make(chan uint32, 1),
heartbeatSigReq: make(chan uint32, 1),
heartbeatInterval: 15,
heartbeatTimeout: 40,
heartbeatTimeoutCallback: heartbeatTimeoutCallback,
@ -144,32 +141,32 @@ func (p *Protocol) handlePackage(pkg *protocolPackage) {
glog.Trace("[protocol.%s] package is nil", p.tag)
return
}
if pkg.isEncrypted() {
glog.Trace("[protocol.%s] package is encrypted, decrypt package", p.tag)
pkg.decrypt()
}
if !pkg.checkHead() {
defer packagePool.Put(pkg)
if (pkg.flag&FlagHeaderHashCheck) != 0 && !pkg.checkHeader() {
glog.Trace("[protocol.%s] package head broken", p.tag)
return
}
if (pkg.flag & flagHeartbeat) != 0 {
if pkg.encryptor != EncryptorNone {
pkg.decrypt()
}
if (pkg.flag & FlagHeartbeatResponse) != 0 {
glog.Info("[protocol.%s] heartbeat signal in package", p.tag)
p.heartbeatSig <- pkg.value
}
if (pkg.flag & flagHeartbeatRequest) != 0 {
if (pkg.flag & FlagHeartbeatRequest) != 0 {
glog.Info("[protocol.%s] heartbeat request signal in package", p.tag)
p.heartbeatSigReq <- pkg.value
}
if !pkg.checkData() {
glog.Trace("[protocol.%s] package data broken", p.tag)
if !pkg.checkBody() {
glog.Trace("[protocol.%s] package body broken", p.tag)
return
}
if pkg.dataSize == 0 {
glog.Trace("[protocol.%s] package data empty", p.tag)
if pkg.bodySize == 0 {
glog.Trace("[protocol.%s] package body empty", p.tag)
return
}
glog.Info("[protocol.%s] handle package successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
p.readCallback(pkg.data)
glog.Info("[protocol.%s] handle package successful, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
p.readCallback(pkg.body)
}
// Reader 阻塞接收数据并提交给readCallback
@ -183,7 +180,7 @@ func (p *Protocol) reader() {
defer p.decRunningRoutine()
buffer := &bytes.Buffer{}
buf := make([]byte, packageMaxSize)
buf := make([]byte, HeaderSize+BodyMaxSize)
var err error
var n int
// 监听并接收数据
@ -214,11 +211,11 @@ func (p *Protocol) reader() {
}
n, err = buffer.Write(buf[:n])
glog.Trace("[protocol.%s] write %d bytes, buffer already %d bytes, error is %v", p.tag, n, buffer.Len(), err)
for buffer.Len() >= packageHeadSize {
for buffer.Len() >= HeaderSize {
glog.Trace("[protocol.%s] complete package, buffer length %d", p.tag, buffer.Len())
pkg, err := parsePackage(buffer)
if err != nil {
if errors.Is(err, ErrorPackageIncomplete) {
if errors.Is(err, errorPackageIncomplete) {
glog.Trace("[protocol.%s] incomplete package, buffer length %d", p.tag, buffer.Len())
break
}
@ -226,7 +223,7 @@ func (p *Protocol) reader() {
}
if pkg != nil {
glog.Info("[protocol.%s] receive new package, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
glog.Info("[protocol.%s] receive new package, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
go p.handlePackage(pkg)
}
}
@ -261,7 +258,7 @@ func (p *Protocol) writer() {
p.setFuncBeforeWrite()
}
glog.Trace("[protocol.%s] writer wait write", p.tag)
n, err = p.w.Write(pkg.Bytes().Bytes())
n, err = p.w.Write(pkg.bytesBuffer().Bytes())
if p.setFuncAfterWrite != nil {
glog.Trace("[protocol.%s] writer func after write", p.tag)
p.setFuncAfterWrite(err)
@ -277,25 +274,26 @@ func (p *Protocol) writer() {
}
}
}
glog.Info("[protocol.%s] send package successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
packagePool.Put(pkg)
glog.Info("[protocol.%s] send package successful, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
}
}
// Write 发送数据
func (p *Protocol) Write(data []byte) error {
func (p *Protocol) Write(flag, encryptor uint8, data []byte) error {
glog.Trace("[protocol.%s] write", p.tag)
if len(data) > dataMaxSize {
glog.Info("[protocol.%s] maximum supported data size exceeded", p.tag)
if len(data) > BodyMaxSize {
glog.Info("[protocol.%s] maximum supported body size exceeded", p.tag)
return ErrorDataSizeExceedsLimit
}
pkg := newPackage(0, encryptNone, 0, data)
pkg := newPackage(flag, encryptor, 0, data)
for {
if p.getStatus() == statusKilled {
glog.Info("[protocol.%s] protocol is killed", p.tag)
return ErrorWriterIsKilled
}
if p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
glog.Info("[protocol.%s] write successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
glog.Info("[protocol.%s] write successful, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
return nil
}
}
@ -364,9 +362,9 @@ func (p *Protocol) sendHeartbeatSignal(isReq bool) {
glog.Trace("[protocol.%s] send heartbeat signal", p.tag)
var pkg *protocolPackage
if isReq {
pkg = newPackage(flagHeartbeatRequest, encryptNone, 0, nil)
pkg = newPackage(FlagHeartbeatRequest, EncryptorNone, 0, nil)
} else {
pkg = newPackage(flagHeartbeat, encryptNone, 0, nil)
pkg = newPackage(FlagHeartbeatResponse, EncryptorNone, 0, nil)
}
for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
if p.getStatus() == statusKilled {
@ -390,28 +388,28 @@ func (p *Protocol) getStatus() int32 {
return atomic.LoadInt32(&p.status)
}
func (p *Protocol) SetHeartbeatInterval(interval uint8) {
func (p *Protocol) SetHeartbeatInterval(interval uint32) {
if interval < 3 {
glog.Trace("[protocol.%s] heartbeatInterval is < 3, use 3", p.tag)
interval = 3
}
atomic.StoreUint32(&p.heartbeatInterval, uint32(interval))
atomic.StoreUint32(&p.heartbeatInterval, interval)
}
func (p *Protocol) GetHeartbeatInterval() uint8 {
return uint8(atomic.LoadUint32(&p.heartbeatInterval))
func (p *Protocol) GetHeartbeatInterval() uint32 {
return atomic.LoadUint32(&p.heartbeatInterval)
}
func (p *Protocol) SetHeartbeatTimeout(timeout uint8) {
func (p *Protocol) SetHeartbeatTimeout(timeout uint32) {
if timeout < 6 {
glog.Trace("[protocol.%s] heartbeatTimeout is < 6, use 6", p.tag)
timeout = 6
}
atomic.StoreUint32(&p.heartbeatTimeout, uint32(timeout))
atomic.StoreUint32(&p.heartbeatTimeout, timeout)
}
func (p *Protocol) GetHeartbeatTimeout() uint8 {
return uint8(atomic.LoadUint32(&p.heartbeatTimeout))
func (p *Protocol) GetHeartbeatTimeout() uint32 {
return atomic.LoadUint32(&p.heartbeatTimeout)
}
func (p *Protocol) setHeartbeatLastReceived() {
@ -481,14 +479,14 @@ func defaultKillCallback() {
glog.Trace("[protocol] default kill callback")
}
func GetDataMaxSize() int {
return dataMaxSize
}
func CalculateTheNumberOfPackages(size int64) int64 {
res := size / dataMaxSize
if size%dataMaxSize != 0 {
res += 1
}
return res
}
// func GetDataMaxSize() int {
// return BodyMaxSize
// }
//
// func CalculateTheNumberOfPackages(size int64) int64 {
// res := size / BodyMaxSize
// if size%BodyMaxSize != 0 {
// res += 1
// }
// return res
// }

View File

@ -14,7 +14,10 @@ var serverClosed = make(chan bool, 1)
var clientClosed = make(chan bool, 1)
func TestProtocol(t *testing.T) {
SetLogMask(MaskFATAL | MaskERROR | MaskWARNING | MaskINFO)
// SetLogMask(LogMaskDEBUG)
// SetLogFlag(LogFlagDEBUG)
SetLogMask(LogMaskINFO)
SetLogFlag(LogFlagINFO)
go testServer(t)
time.Sleep(time.Second)
go testClient(t)
@ -73,7 +76,7 @@ func testServer(t *testing.T) {
i++
msg := fmt.Sprintf("serv msg %d", i)
fmt.Printf("[server] send [%s]\n", msg)
err = protServer.Write([]byte(msg))
err = protServer.Write(FlagHeaderHashCheck|FlagBodyHashCheck, EncryptorAes, []byte(msg))
if err != nil {
if !errors.Is(err, ErrorWriterIsKilled) {
glog.Warning("[server] failed to write %v", err)
@ -136,7 +139,7 @@ func testClient(t *testing.T) {
i++
msg := fmt.Sprintf("client msg %d", i)
fmt.Printf("[client] send [%s]\n", msg)
err = protClient.Write([]byte(msg))
err = protClient.Write(FlagHeaderHashCheck|FlagBodyHashCheck, EncryptorXor, []byte(msg))
if err != nil {
if !errors.Is(err, ErrorWriterIsKilled) {
glog.Warning("[client] failed to write %v", err)

View File

@ -24,8 +24,8 @@ func TestQueue(t *testing.T) {
if que.size() != 0 || !que.isEmpty() {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
if pkg11.value != 1 || len(pkg11.data) != 1 || pkg11.data[0] != 1 {
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
if pkg11.value != 1 || len(pkg11.body) != 1 || pkg11.body[0] != 1 {
t.Errorf("value:%d body:%v\n", pkg11.value, pkg11.body)
}
que.push(pkg2, 0)
if que.size() != 1 || que.isEmpty() {
@ -33,8 +33,8 @@ func TestQueue(t *testing.T) {
}
pkg21 := que.pop(0)
if pkg21.value != 2 || len(pkg21.data) != 2 || pkg21.data[0] != 1 || pkg21.data[1] != 2 {
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
if pkg21.value != 2 || len(pkg21.body) != 2 || pkg21.body[0] != 1 || pkg21.body[1] != 2 {
t.Errorf("value:%d body:%v\n", pkg21.value, pkg21.body)
}
pkg1 = newPackage(0, 0, 1, nil)
@ -67,8 +67,8 @@ func TestQueue(t *testing.T) {
time.Sleep(3 * time.Second)
fmt.Println("pop")
pkg11 := que2.pop(0)
if pkg11.value != 1 || pkg11.data != nil {
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
if pkg11.value != 1 || pkg11.body != nil {
t.Errorf("value:%d body:%v\n", pkg11.value, pkg11.body)
}
}()
@ -79,18 +79,18 @@ func TestQueue(t *testing.T) {
}
pkg21 = que2.pop(0)
if pkg21.value != 2 || pkg21.data != nil {
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
if pkg21.value != 2 || pkg21.body != nil {
t.Errorf("value:%d body:%v\n", pkg21.value, pkg21.body)
}
pkg31 := que2.pop(0)
if pkg31.value != 3 || pkg31.data != nil {
t.Errorf("value:%d data:%v\n", pkg31.value, pkg31.data)
if pkg31.value != 3 || pkg31.body != nil {
t.Errorf("value:%d body:%v\n", pkg31.value, pkg31.body)
}
pkg41 := que2.pop(0)
if pkg41.value != 4 || pkg31.data != nil {
t.Errorf("value:%d data:%v\n", pkg41.value, pkg41.data)
if pkg41.value != 4 || pkg31.body != nil {
t.Errorf("value:%d body:%v\n", pkg41.value, pkg41.body)
}
go func() {
@ -101,8 +101,8 @@ func TestQueue(t *testing.T) {
fmt.Println("wait push")
pkg51 := que2.pop(0)
if pkg51.value != 5 || len(pkg51.data) != 1 || pkg51.data[0] != 55 {
t.Errorf("value:%d data:%v\n", pkg51.value, pkg51.data)
if pkg51.value != 5 || len(pkg51.body) != 1 || pkg51.body[0] != 55 {
t.Errorf("value:%d body:%v\n", pkg51.value, pkg51.body)
}
fmt.Println("size")

20
util.go Normal file
View File

@ -0,0 +1,20 @@
package protocol
import (
"bytes"
"encoding/binary"
"sync"
)
var uint32BytesPool = sync.Pool{
New: func() interface{} {
return make([]byte, 4)
},
}
func writeUint32Bytes(b *bytes.Buffer, i uint32) {
buf := uint32BytesPool.Get().([]byte)
binary.BigEndian.PutUint32(buf, i)
b.Write(buf)
uint32BytesPool.Put(buf)
}