v2
This commit is contained in:
parent
1c9129c1e3
commit
f8fe8f5288
10
README.md
10
README.md
|
@ -7,6 +7,16 @@
|
|||
[![License: GPL v2](https://img.shields.io/badge/License-GPL%20v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html)
|
||||
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/git.viry.cc/gomod/protocol?tab=doc)
|
||||
|
||||
# TODO
|
||||
|
||||
- 心跳服务发送单独的Package, 如果value不为0, 则修改心跳超时时间
|
||||
- 完全关闭心跳, 即双方都不发送心跳信号
|
||||
- 单/双方间隔发送HeartbeatResponse
|
||||
- 单/双方间隔发送HeartbeatRequest, 接收方收到Request需要响应Response
|
||||
- 心跳信号不管是Request还是Response, 都可以携带Confirm信息
|
||||
- Confirm信息的键由Heartbeat携带
|
||||
- 更新glog包,使得glog能够接受外部File的Paper, 保证使用glog的mod和使用mod的项目公用一个log文件
|
||||
|
||||
```go
|
||||
import "git.viry.cc/gomod/protocol"
|
||||
```
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var aesEncryptorEnabled = false
|
||||
var aesCipherBlock cipher.Block = nil
|
||||
var aesCipherBlockModeEncrypter cipher.BlockMode = nil
|
||||
var aesCipherBlockModeDecrypter cipher.BlockMode = nil
|
||||
var aesCipherBlockSize int
|
||||
var aesCipherIV []byte
|
||||
|
||||
func SetEncryptorAesKey(key []byte) (err error) {
|
||||
// 分组秘钥
|
||||
aesCipherBlock, err = aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key 长度必须 16/24/32长度: %s", err.Error())
|
||||
}
|
||||
aesCipherBlockSize = aesCipherBlock.BlockSize()
|
||||
aesCipherIV = key[:aesCipherBlockSize]
|
||||
aesCipherBlockModeEncrypter = cipher.NewCBCEncrypter(aesCipherBlock, aesCipherIV)
|
||||
aesCipherBlockModeDecrypter = cipher.NewCBCDecrypter(aesCipherBlock, aesCipherIV)
|
||||
aesEncryptorEnabled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *protocolPackage) aesEncrypt() {
|
||||
if !aesEncryptorEnabled {
|
||||
return
|
||||
}
|
||||
// PKCS7Padding 补码
|
||||
padding := aesCipherBlockSize - int(p.bodySize)%aesCipherBlockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
p.bodySize += uint32(padding)
|
||||
p.body = append(p.body, padtext...)
|
||||
// 创建数组
|
||||
cryted := make([]byte, p.bodySize)
|
||||
// 加密
|
||||
aesCipherBlockModeEncrypter.CryptBlocks(cryted, p.body)
|
||||
p.body = cryted
|
||||
p.encryptor = EncryptorAes
|
||||
}
|
||||
|
||||
func (p *protocolPackage) aesDecrypt() {
|
||||
if !aesEncryptorEnabled {
|
||||
return
|
||||
}
|
||||
// 创建数组
|
||||
orig := make([]byte, p.bodySize)
|
||||
// 解密
|
||||
aesCipherBlockModeDecrypter.CryptBlocks(orig, p.body)
|
||||
// PKCS7UnPadding 去码
|
||||
p.bodySize = p.bodySize - uint32(orig[p.bodySize-1])
|
||||
p.body = orig[:p.bodySize]
|
||||
p.encryptor = EncryptorNone
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.viry.cc/gomod/util"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptorAes(t *testing.T) {
|
||||
var ori, data, key []byte
|
||||
var err error
|
||||
|
||||
for i := 1; i < 1024; i++ {
|
||||
ori = util.RandomKey(i).Bytes()
|
||||
data = make([]byte, len(ori))
|
||||
copy(data, ori)
|
||||
for j := 0; j < 3; j++ {
|
||||
key = util.RandomKey(16 + j*8).Bytes()
|
||||
// check
|
||||
err = SetEncryptorAesKey(key)
|
||||
if err != nil {
|
||||
t.Errorf("set aes key failed, %v", err)
|
||||
}
|
||||
pkg := newPackage(0, 0, 0, data)
|
||||
// fmt.Printf("body %d", pkg.bodySize)
|
||||
pkg.aesEncrypt()
|
||||
// fmt.Printf(" -> %d", pkg.bodySize)
|
||||
pkg.aesDecrypt()
|
||||
// fmt.Printf(" -> %d\n", pkg.bodySize)
|
||||
if uint32(len(pkg.body)) != pkg.bodySize || pkg.bodySize != uint32(len(ori)) {
|
||||
t.Errorf("expected [%d] got [%d, %d]\n", uint32(len(ori)), pkg.bodySize, uint32(len(pkg.body)))
|
||||
}
|
||||
if !bytes.Equal(pkg.body, ori) {
|
||||
t.Errorf("expected [%0#2v] got [%0#2v]\n", ori, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncryptorAesEncode(b *testing.B) {
|
||||
data := util.RandomKey(256 * 1024 * 1024).Bytes()
|
||||
key := util.RandomKey(16).Bytes()
|
||||
pkg := newPackage(0, 0, 0, data)
|
||||
err := SetEncryptorAesKey(key)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("set aes key failed, %v", err))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pkg.aesEncrypt()
|
||||
pkg.aesDecrypt()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package protocol
|
||||
|
||||
var xorEncryptorEnabled = false
|
||||
var xorKey []byte
|
||||
var xorKeyLen int
|
||||
|
||||
func SetEncryptorXorKey(key []byte) {
|
||||
xorKey = key
|
||||
xorKeyLen = len(key)
|
||||
xorEncryptorEnabled = true
|
||||
}
|
||||
|
||||
func (p *protocolPackage) xorEncrypt() {
|
||||
if !xorEncryptorEnabled {
|
||||
return
|
||||
}
|
||||
for i := range p.body {
|
||||
p.body[i] ^= xorKey[i%xorKeyLen]
|
||||
}
|
||||
p.encryptor = EncryptorXor
|
||||
}
|
||||
|
||||
func (p *protocolPackage) xorDecrypt() {
|
||||
if !xorEncryptorEnabled {
|
||||
return
|
||||
}
|
||||
for i := range p.body {
|
||||
p.body[i] ^= xorKey[i%xorKeyLen]
|
||||
}
|
||||
p.encryptor = EncryptorNone
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"git.viry.cc/gomod/util"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptorXor(t *testing.T) {
|
||||
var ori, data, key, ans []byte
|
||||
|
||||
for i := 1; i < 1024; i++ {
|
||||
ori = util.RandomKey(i).Bytes()
|
||||
data = make([]byte, len(ori))
|
||||
copy(data, ori)
|
||||
for j := 1; j < 128; j++ {
|
||||
key = util.RandomKey(j).Bytes()
|
||||
// build ans
|
||||
ans = make([]byte, len(ori))
|
||||
copy(ans, ori)
|
||||
for k := range ans {
|
||||
ans[k] ^= key[k%j]
|
||||
}
|
||||
// check
|
||||
SetEncryptorXorKey(key)
|
||||
pkg := newPackage(0, 0, 0, data)
|
||||
pkg.xorEncrypt()
|
||||
if !bytes.Equal(pkg.body, ans) {
|
||||
t.Errorf("expected [%0#2v] got [%0#2v]\n", ans, data)
|
||||
}
|
||||
pkg.xorDecrypt()
|
||||
if uint32(len(pkg.body)) != pkg.bodySize || pkg.bodySize != uint32(len(ori)) {
|
||||
t.Errorf("expected [%d] got [%d, %d]\n", uint32(len(ori)), pkg.bodySize, uint32(len(pkg.body)))
|
||||
}
|
||||
if !bytes.Equal(pkg.body, ori) {
|
||||
t.Errorf("expected [%0#2v] got [%0#2v]\n", ori, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncryptorXorEncode(b *testing.B) {
|
||||
data := util.RandomKey(256 * 1024 * 1024).Bytes()
|
||||
key := util.RandomKey(16).Bytes()
|
||||
pkg := newPackage(0, 0, 0, data)
|
||||
SetEncryptorXorKey(key)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pkg.xorEncrypt()
|
||||
pkg.xorDecrypt()
|
||||
}
|
||||
}
|
55
log.go
55
log.go
|
@ -5,50 +5,28 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
MaskUNKNOWN = glog.MaskUNKNOWN
|
||||
MaskDEBUG = glog.MaskDEBUG
|
||||
MaskTRACE = glog.MaskTRACE
|
||||
MaskINFO = glog.MaskINFO
|
||||
MaskWARNING = glog.MaskWARNING
|
||||
MaskERROR = glog.MaskERROR
|
||||
MaskFATAL = glog.MaskFATAL
|
||||
|
||||
MaskStd = glog.MaskStd
|
||||
MaskAll = glog.MaskAll
|
||||
|
||||
MaskDev = MaskFATAL | MaskERROR | MaskWARNING | MaskINFO | MaskTRACE | MaskDEBUG | MaskUNKNOWN
|
||||
MaskProd = MaskFATAL | MaskERROR | MaskWARNING
|
||||
LogMaskDEBUG uint32 = glog.MaskUNKNOWN | glog.MaskDEBUG | glog.MaskTRACE | glog.MaskINFO | glog.MaskWARNING | glog.MaskERROR | glog.MaskFATAL
|
||||
LogFlagDEBUG uint32 = glog.FlagDate | glog.FlagTime | glog.FlagShortFile | glog.FlagPrefix | glog.FlagSuffix
|
||||
)
|
||||
|
||||
const (
|
||||
FlagDate = glog.FlagDate
|
||||
FlagTime = glog.FlagTime
|
||||
FlagLongFile = glog.FlagLongFile
|
||||
FlagShortFile = glog.FlagShortFile
|
||||
FlagFunc = glog.FlagFunc
|
||||
FlagPrefix = glog.FlagPrefix
|
||||
FlagSuffix = glog.FlagSuffix
|
||||
LogMaskINFO uint32 = glog.MaskINFO | glog.MaskWARNING | glog.MaskERROR | glog.MaskFATAL
|
||||
LogFlagINFO uint32 = glog.FlagDate | glog.FlagTime | glog.FlagPrefix
|
||||
)
|
||||
|
||||
FlagStd = glog.FlagStd
|
||||
FlagAll = glog.FlagAll
|
||||
const (
|
||||
LogMaskWARNING uint32 = glog.MaskWARNING | glog.MaskERROR | glog.MaskFATAL
|
||||
LogFlagWARNING uint32 = glog.FlagDate | glog.FlagTime | glog.FlagPrefix
|
||||
)
|
||||
|
||||
FlagDev = FlagDate | FlagTime | FlagShortFile | FlagFunc | FlagPrefix | FlagSuffix
|
||||
FlagProd = FlagDate | FlagTime | FlagPrefix
|
||||
const (
|
||||
LogMaskNONE uint32 = 0
|
||||
LogFlagNONE uint32 = glog.FlagDate | glog.FlagTime | glog.FlagPrefix
|
||||
)
|
||||
|
||||
func init() {
|
||||
glog.SetMask(MaskStd)
|
||||
glog.SetFlag(FlagStd)
|
||||
}
|
||||
|
||||
func SetLogProd(isProd bool) {
|
||||
if isProd {
|
||||
glog.SetMask(MaskProd)
|
||||
glog.SetFlag(FlagProd)
|
||||
} else {
|
||||
glog.SetMask(MaskStd)
|
||||
glog.SetFlag(FlagStd)
|
||||
}
|
||||
glog.SetMask(LogMaskNONE)
|
||||
glog.SetFlag(LogFlagNONE)
|
||||
}
|
||||
|
||||
func SetLogMask(mask uint32) {
|
||||
|
@ -58,8 +36,3 @@ func SetLogMask(mask uint32) {
|
|||
func SetLogFlag(f uint32) {
|
||||
glog.SetFlag(f)
|
||||
}
|
||||
|
||||
func init() {
|
||||
glog.SetMask(MaskProd)
|
||||
glog.SetFlag(FlagProd)
|
||||
}
|
||||
|
|
376
package.go
376
package.go
|
@ -4,268 +4,278 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"git.viry.cc/gomod/glog"
|
||||
"git.viry.cc/gomod/util"
|
||||
"hash/crc32"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// package的起始标志
|
||||
var prefix = [headLengthPrefix]uint8{0xff, 0x07, 0x55, 0x00}
|
||||
// PREFIX package的起始标志
|
||||
const PREFIX uint8 = 0x95 // 1001 0101
|
||||
|
||||
var ErrorPackageIncomplete = errors.New("package incomplete")
|
||||
|
||||
// head中各部分的长度
|
||||
const (
|
||||
headLengthPrefix = 4
|
||||
headLengthVersion = 1
|
||||
headLengthCRC32Checksum = 4
|
||||
headLengthFlag = 1
|
||||
headLengthEncryptMethod = 1
|
||||
headLengthCustomValue = 1
|
||||
headLengthDataSize = 4
|
||||
headLengthDataCrc32 = 4
|
||||
)
|
||||
|
||||
// head中各部分的偏移
|
||||
const (
|
||||
headOffsetPrefix = 0
|
||||
headOffsetVersion = headOffsetPrefix + headLengthPrefix
|
||||
headOffsetCRC32Checksum = headOffsetVersion + headLengthVersion
|
||||
headOffsetFlag = headOffsetCRC32Checksum + headLengthCRC32Checksum
|
||||
headOffsetEncryptMethod = headOffsetFlag + headLengthFlag
|
||||
headOffsetCustomValue = headOffsetEncryptMethod + headLengthEncryptMethod
|
||||
headOffsetDataSize = headOffsetCustomValue + headLengthCustomValue
|
||||
headOffsetDataCrc32 = headOffsetDataSize + headLengthDataSize
|
||||
headOffsetData = headOffsetDataCrc32 + headLengthDataCrc32
|
||||
)
|
||||
|
||||
// 计算head的crc32时起始偏移
|
||||
const headOffsetNeedCheck = headOffsetCRC32Checksum + headLengthCRC32Checksum
|
||||
|
||||
// data的加密方式
|
||||
const (
|
||||
encryptNone uint8 = iota
|
||||
)
|
||||
// VERSION package的版本
|
||||
const VERSION uint8 = 3
|
||||
|
||||
// flag标志位
|
||||
const (
|
||||
// 普通心跳信号,心跳响应信号
|
||||
flagHeartbeat uint8 = 1 << iota
|
||||
// 心跳请求信号,接收方必须回复flagHeartbeat
|
||||
flagHeartbeatRequest
|
||||
/*
|
||||
Heartbeat机制, 用来检查链路是否正常
|
||||
*/
|
||||
FlagHeartbeatRequest uint8 = 1 << iota // 心跳请求信号, 发送心跳信号并请求心跳响应
|
||||
FlagHeartbeatResponse // 心跳响应信号, 发送心跳信号
|
||||
|
||||
/*
|
||||
Package校验机制, 如果携带此标志, 则
|
||||
1. Header中携带了Header的哈希值, 用来确认Header完整性
|
||||
2. Header中携带了Body的哈希值, 用来确认Body完整性
|
||||
*/
|
||||
FlagHeaderHashCheck // Header拥有校验值
|
||||
FlagBodyHashCheck // Body拥有校验值
|
||||
|
||||
/*
|
||||
Package确认机制, 接收方收到携带此标志的包裹后,需要在一定时间内向发送方发送确认信号
|
||||
*/
|
||||
FlagPackageConfirmRequest // Package请求确认信号, 接收后需要响应确认信号
|
||||
FlagPackageConfirmResponse // Package确认信号,发送确认信号
|
||||
)
|
||||
|
||||
// package的head的大小 (byte)
|
||||
const packageHeadSize = headOffsetData
|
||||
// data的加密方式
|
||||
const (
|
||||
EncryptorNone uint8 = iota // 不加密
|
||||
EncryptorXor // 异或加密
|
||||
EncryptorAes // AES加密
|
||||
)
|
||||
|
||||
// package的最大size (byte)
|
||||
const packageMaxSize = 4096
|
||||
// 错误
|
||||
var (
|
||||
errorPackageIncomplete = errors.New("package incomplete")
|
||||
)
|
||||
|
||||
// package的data的最大size (byte)
|
||||
const dataMaxSize = packageMaxSize - packageHeadSize
|
||||
// 数据长度
|
||||
const (
|
||||
HeaderSize = 20 // header的大小
|
||||
BodyMaxSize = 16384 // body的最大大小
|
||||
)
|
||||
|
||||
// 协议包结构
|
||||
type protocolPackage struct {
|
||||
prefix [headLengthPrefix]byte // 4 byte 0xff 0x55
|
||||
version uint8 // 1 byte protocol version
|
||||
crc32 uint32 // 4 byte head crc32 checksum (BigEndian)
|
||||
flag uint8 // 1 byte flag
|
||||
encryptMethod uint8 // 1 byte encrypted method
|
||||
value uint8 // 1 byte custom value (for heartbeat)
|
||||
dataSize uint32 // 4 byte curSize of data (BigEndian)
|
||||
dataCrc32 uint32 // 4 byte crc32 of data (BigEndian)
|
||||
data []byte // ? byte data
|
||||
prefix uint8 // 1 byte 0x95
|
||||
version uint8 // 1 byte protocol version
|
||||
headerHash uint32 // 4 byte head crc32 checksum (BigEndian)
|
||||
flag uint8 // 1 byte flag
|
||||
encryptor uint8 // 1 byte encrypt method
|
||||
value uint32 // 4 byte 在普通Package中保存Package的序号, 用以Confirm. 在Heartbeat中存储心跳序号和心跳超时时间
|
||||
bodySize uint32 // 4 byte size of body (BigEndian)
|
||||
bodyHash uint32 // 4 byte headerHash of body (BigEndian)
|
||||
body []byte // ? byte body
|
||||
}
|
||||
|
||||
// 缓存池
|
||||
var (
|
||||
// *protocolPackage
|
||||
packagePool = sync.Pool{New: func() any { return &protocolPackage{} }}
|
||||
// *bytes.Buffer
|
||||
bytesBufferPool = sync.Pool{New: func() any { return &bytes.Buffer{} }}
|
||||
// make([]byte, HeaderSize)
|
||||
parsePackageBufferPool = sync.Pool{New: func() any { return make([]byte, HeaderSize) }}
|
||||
)
|
||||
|
||||
// 创建新package, 所有新package的创建都必须通过此方法
|
||||
func newPackage(flag uint8, encrypt uint8, value uint8, data []byte) *protocolPackage {
|
||||
pkg := &protocolPackage{
|
||||
prefix: prefix,
|
||||
version: VERSION,
|
||||
crc32: 0,
|
||||
flag: flag,
|
||||
encryptMethod: encryptNone,
|
||||
value: value,
|
||||
dataSize: uint32(len(data)),
|
||||
dataCrc32: 0,
|
||||
data: data,
|
||||
//
|
||||
// 返回的 *protocolPackage 使用结束后需要手动回收到 packagePool
|
||||
func newPackage(flag uint8, encryptor uint8, value uint32, body []byte) *protocolPackage {
|
||||
pkg := packagePool.Get().(*protocolPackage)
|
||||
pkg.prefix = PREFIX
|
||||
pkg.version = VERSION
|
||||
pkg.headerHash = 0
|
||||
pkg.flag = flag
|
||||
pkg.encryptor = EncryptorNone
|
||||
pkg.value = value
|
||||
pkg.bodySize = uint32(len(body))
|
||||
pkg.bodyHash = 0
|
||||
pkg.body = body
|
||||
if encryptor != EncryptorNone {
|
||||
pkg.encrypt(encryptor)
|
||||
}
|
||||
if (flag & FlagBodyHashCheck) != 0 {
|
||||
pkg.generateBodyHash()
|
||||
}
|
||||
if (flag & FlagHeaderHashCheck) != 0 {
|
||||
pkg.generateHeaderHash()
|
||||
}
|
||||
pkg.generateDataCheck()
|
||||
pkg.generateHeadCheck()
|
||||
pkg.encrypt(encrypt)
|
||||
return pkg
|
||||
}
|
||||
|
||||
func (p *protocolPackage) Bytes() *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
// 返回的 *bytes.Buffer 使用结束后需要手动回收到 bytesBufferPool
|
||||
func (p *protocolPackage) bytesBuffer() *bytes.Buffer {
|
||||
buf := bytesBufferPool.Get().(*bytes.Buffer)
|
||||
// prefix
|
||||
buf.Write(p.prefix[:])
|
||||
buf.WriteByte(p.prefix)
|
||||
// version
|
||||
buf.WriteByte(p.version)
|
||||
// crc32
|
||||
buf.Write(util.UInt32ToBytesSlice(p.crc32))
|
||||
// head hash
|
||||
writeUint32Bytes(buf, p.headerHash)
|
||||
// flag
|
||||
buf.WriteByte(p.flag)
|
||||
// encrypt method
|
||||
buf.WriteByte(p.encryptMethod)
|
||||
buf.WriteByte(p.encryptor)
|
||||
// value
|
||||
buf.WriteByte(p.value)
|
||||
// data curSize
|
||||
buf.Write(util.UInt32ToBytesSlice(p.dataSize))
|
||||
// data crc32
|
||||
buf.Write(util.UInt32ToBytesSlice(p.dataCrc32))
|
||||
// data
|
||||
buf.Write(p.data)
|
||||
writeUint32Bytes(buf, p.value)
|
||||
// body curSize
|
||||
writeUint32Bytes(buf, p.bodySize)
|
||||
// body hash
|
||||
writeUint32Bytes(buf, p.bodyHash)
|
||||
// body
|
||||
buf.Write(p.body)
|
||||
return buf
|
||||
}
|
||||
|
||||
// 将head中需要校验的数据拼接起来
|
||||
func (p *protocolPackage) headNeedCheckBytes() *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
//
|
||||
// 返回的 *bytes.Buffer 使用结束后需要手动回收到 bytesBufferPool
|
||||
func (p *protocolPackage) headerNeedCheck() *bytes.Buffer {
|
||||
buf := bytesBufferPool.Get().(*bytes.Buffer)
|
||||
buf.WriteByte(p.flag)
|
||||
buf.WriteByte(p.encryptMethod)
|
||||
buf.WriteByte(p.value)
|
||||
buf.Write(util.UInt32ToBytesSlice(p.dataSize))
|
||||
buf.Write(util.UInt32ToBytesSlice(p.dataCrc32))
|
||||
buf.WriteByte(p.encryptor)
|
||||
writeUint32Bytes(buf, p.value)
|
||||
writeUint32Bytes(buf, p.bodySize)
|
||||
writeUint32Bytes(buf, p.bodyHash)
|
||||
return buf
|
||||
}
|
||||
|
||||
// 生成head的crc32
|
||||
func (p *protocolPackage) generateHeadCheck() {
|
||||
p.crc32 = util.NewCRC32().FromBytes(p.headNeedCheckBytes().Bytes()).Value()
|
||||
glog.Trace("[protocol_package] head crc32 is %d", p.crc32)
|
||||
// 生成header hash
|
||||
func (p *protocolPackage) generateHeaderHash() {
|
||||
buf := p.headerNeedCheck()
|
||||
defer func() {
|
||||
buf.Reset()
|
||||
bytesBufferPool.Put(buf)
|
||||
}()
|
||||
p.headerHash = crc32.ChecksumIEEE(buf.Bytes())
|
||||
}
|
||||
|
||||
// 校验head的crc32
|
||||
func (p *protocolPackage) checkHead() bool {
|
||||
return p.crc32 == util.NewCRC32().FromBytes(p.headNeedCheckBytes().Bytes()).Value()
|
||||
// 校验header
|
||||
func (p *protocolPackage) checkHeader() bool {
|
||||
buf := p.headerNeedCheck()
|
||||
defer func() {
|
||||
buf.Reset()
|
||||
bytesBufferPool.Put(buf)
|
||||
}()
|
||||
return p.headerHash == crc32.ChecksumIEEE(buf.Bytes())
|
||||
}
|
||||
|
||||
// 生成data的crc32
|
||||
func (p *protocolPackage) generateDataCheck() {
|
||||
p.dataCrc32 = util.NewCRC32().FromBytes(p.data).Value()
|
||||
glog.Trace("[protocol_package] data crc32 is %d", p.dataCrc32)
|
||||
func (p *protocolPackage) generateBodyHash() {
|
||||
p.bodyHash = crc32.ChecksumIEEE(p.body)
|
||||
}
|
||||
|
||||
// 校验data的crc32
|
||||
func (p *protocolPackage) checkData() bool {
|
||||
if int(p.dataSize) != len(p.data) {
|
||||
glog.Trace("[protocol_package] pkg.dataSize != len(pkg.data)")
|
||||
// 校验body
|
||||
func (p *protocolPackage) checkBody() bool {
|
||||
if (p.flag & FlagHeaderHashCheck) == 0 {
|
||||
return true
|
||||
}
|
||||
if int(p.bodySize) != len(p.body) {
|
||||
return false
|
||||
}
|
||||
return p.dataCrc32 == util.NewCRC32().FromBytes(p.data).Value()
|
||||
return p.bodyHash == crc32.ChecksumIEEE(p.body)
|
||||
}
|
||||
|
||||
// encrypt 加密data
|
||||
func (p *protocolPackage) encrypt(method uint8) {
|
||||
if p.encryptMethod == method {
|
||||
glog.Trace("[protocol_package] is already encrypted [%d]", method)
|
||||
return // 已经加密
|
||||
func (p *protocolPackage) encrypt(encryptor uint8) {
|
||||
switch encryptor {
|
||||
case EncryptorXor:
|
||||
p.xorEncrypt()
|
||||
case EncryptorAes:
|
||||
p.aesEncrypt()
|
||||
default:
|
||||
glog.Warning("[protocol_package] unknown encrypt method")
|
||||
}
|
||||
if p.encryptMethod != encryptNone {
|
||||
glog.Trace("[protocol_package] encrypt with other method got [%d] need encryptNone[%d]", p.encryptMethod, encryptNone)
|
||||
return // 已经通过其他方式加密
|
||||
}
|
||||
glog.Warning("[protocol_package] unknown encrypt method")
|
||||
}
|
||||
|
||||
// decrypt 解密data
|
||||
func (p *protocolPackage) decrypt() {
|
||||
if !p.isEncrypted() {
|
||||
glog.Trace("[protocol_package] is not encrypted")
|
||||
switch p.encryptor {
|
||||
case EncryptorNone:
|
||||
return
|
||||
case EncryptorXor:
|
||||
p.xorDecrypt()
|
||||
case EncryptorAes:
|
||||
p.aesDecrypt()
|
||||
default:
|
||||
glog.Warning("[protocol_package] unknown encrypt method")
|
||||
}
|
||||
}
|
||||
|
||||
// isEncrypted 是否已经加密
|
||||
func (p *protocolPackage) isEncrypted() bool {
|
||||
return p.encryptMethod != encryptNone
|
||||
}
|
||||
|
||||
// parsePackage 从buf中读取一个package
|
||||
//
|
||||
// 如果协议头标志(prefix)不匹配,删除buf中除第一个字符外,下一个prefix1到buf开头的所有数据
|
||||
// 如果协议头标志(prefix)匹配,不断从buf中取出数据,填充到package结构体
|
||||
// 如果buf中的数据出错,无法正确提取package, 则返回(nil,true), 且已从buf中提取的数据不会退回buf
|
||||
func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
|
||||
bufLen := uint32(buf.Len())
|
||||
bufData := buf.Bytes()
|
||||
// 协议头标志不匹配,删除未知数据,寻找下一个package起始位置
|
||||
if !bytes.Equal(prefix[:], buf.Bytes()[headOffsetPrefix:headOffsetPrefix+headLengthPrefix]) {
|
||||
glog.Trace("[protocol_package] prefix does not match, need %v got %v", prefix, buf.Bytes()[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
|
||||
if bufData[0] != PREFIX {
|
||||
nextPackageHead(buf)
|
||||
return nil, fmt.Errorf("prefix does not match, need %v got %v", prefix, buf.Bytes()[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
|
||||
return nil, fmt.Errorf("prefix does not match, expected %v got %v", PREFIX, bufData[0])
|
||||
}
|
||||
// 判断package的版本
|
||||
// 暂时只处理VERSION版本的package
|
||||
if buf.Len() < headOffsetVersion+headLengthVersion {
|
||||
glog.Trace("[protocol_package] incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
if buf.Bytes()[headOffsetVersion] != VERSION {
|
||||
glog.Trace("[protocol_package] unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
|
||||
if bufData[1] != VERSION {
|
||||
nextPackageHead(buf)
|
||||
return nil, fmt.Errorf("unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
|
||||
return nil, fmt.Errorf("unsupported version expected %d got %d", VERSION, bufData[1])
|
||||
}
|
||||
// 开始判断是否为package并提取package
|
||||
if buf.Len() < packageHeadSize {
|
||||
glog.Trace("[protocol_package] incomplete head, need %d got %d", packageHeadSize, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
// 获取body长度
|
||||
bodySize := uint32(bufData[HeaderSize-5]) | uint32(bufData[HeaderSize-6])<<8 | uint32(bufData[HeaderSize-7])<<16 | uint32(bufData[HeaderSize-8])<<24
|
||||
// 判断Package是否接收完整
|
||||
if bufLen < HeaderSize+bodySize {
|
||||
return nil, errorPackageIncomplete
|
||||
}
|
||||
head := make([]byte, packageHeadSize)
|
||||
copy(head, buf.Bytes()[:packageHeadSize])
|
||||
// 检查head是否完整,删除未知数据,寻找下一个package起始位置
|
||||
headChecksum := util.BytesSliceToUInt32(head[headOffsetCRC32Checksum : headOffsetCRC32Checksum+headLengthCRC32Checksum])
|
||||
headCrc32 := util.NewCRC32().FromBytes(head[headOffsetNeedCheck:]).Value()
|
||||
if headChecksum != headCrc32 {
|
||||
glog.Trace("[protocol_package] head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
|
||||
nextPackageHead(buf)
|
||||
return nil, fmt.Errorf("head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
|
||||
// 校验Header
|
||||
var headerHashExpected uint32
|
||||
if (bufData[6] & FlagHeaderHashCheck) != 0 {
|
||||
var headerHashGot uint32
|
||||
headerHashGot = crc32.ChecksumIEEE(bufData[6:HeaderSize])
|
||||
headerHashExpected = uint32(bufData[5]) | uint32(bufData[4])<<8 | uint32(bufData[3])<<16 | uint32(bufData[2])<<24
|
||||
if headerHashExpected != headerHashGot {
|
||||
nextPackageHead(buf)
|
||||
return nil, fmt.Errorf("wrong header hash, expected %d got %d", headerHashExpected, headerHashGot)
|
||||
}
|
||||
}
|
||||
// 检查package是否完整,不完整则等待
|
||||
packageDataSize := util.BytesSliceToUInt32(head[headOffsetDataSize : headOffsetDataSize+headLengthDataSize])
|
||||
if packageHeadSize+int(packageDataSize) > buf.Len() {
|
||||
glog.Trace("[protocol_package] incomplete data, need %d got %d", packageHeadSize+packageDataSize, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
// package完整
|
||||
pkg := &protocolPackage{}
|
||||
_, _ = buf.Read(make([]byte, packageHeadSize))
|
||||
// prefix
|
||||
copy(pkg.prefix[:], head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
|
||||
// crc32
|
||||
pkg.version = head[headOffsetVersion]
|
||||
// crc32
|
||||
pkg.crc32 = headCrc32
|
||||
// flag
|
||||
pkg.flag = head[headOffsetFlag]
|
||||
// encrypt method
|
||||
pkg.encryptMethod = head[headOffsetEncryptMethod]
|
||||
// value
|
||||
pkg.value = head[headOffsetCustomValue]
|
||||
// data curSize
|
||||
pkg.dataSize = packageDataSize
|
||||
// data crc32
|
||||
pkg.dataCrc32 = util.BytesSliceToUInt32(head[headOffsetDataCrc32 : headOffsetDataCrc32+headLengthDataCrc32])
|
||||
// data
|
||||
pkg.data = make([]byte, pkg.dataSize)
|
||||
_, _ = buf.Read(pkg.data)
|
||||
dataCrc32 := util.NewCRC32().FromBytes(pkg.data).Value()
|
||||
if pkg.dataCrc32 != dataCrc32 {
|
||||
glog.Trace("[protocol_package] data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
|
||||
nextPackageHead(buf)
|
||||
return nil, fmt.Errorf("data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
|
||||
// 校验Body
|
||||
var bodyHashExpected uint32
|
||||
if (bufData[6] & FlagBodyHashCheck) != 0 {
|
||||
var bodyHashGot uint32
|
||||
bodyHashGot = crc32.ChecksumIEEE(bufData[HeaderSize:])
|
||||
bodyHashExpected = uint32(bufData[HeaderSize-1]) | uint32(bufData[HeaderSize-2])<<8 | uint32(bufData[HeaderSize-3])<<16 | uint32(bufData[HeaderSize-4])<<24
|
||||
if bodyHashExpected != bodyHashGot {
|
||||
nextPackageHead(buf)
|
||||
return nil, fmt.Errorf("wrong header hash, expected %d got %d", bodyHashExpected, bodyHashGot)
|
||||
}
|
||||
}
|
||||
|
||||
pkg := packagePool.Get().(*protocolPackage)
|
||||
pkg.prefix = bufData[0]
|
||||
pkg.version = bufData[1]
|
||||
pkg.headerHash = headerHashExpected
|
||||
pkg.flag = bufData[6]
|
||||
pkg.encryptor = bufData[7]
|
||||
pkg.value = uint32(bufData[11]) | uint32(bufData[10])<<8 | uint32(bufData[9])<<16 | uint32(bufData[8])<<24
|
||||
pkg.bodySize = bodySize
|
||||
pkg.bodyHash = bodyHashExpected
|
||||
pkg.body = make([]byte, bodySize)
|
||||
hd := parsePackageBufferPool.Get().([]byte)
|
||||
buf.Read(hd)
|
||||
buf.Read(pkg.body)
|
||||
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// 删除掉buf中第一个byte, 并将buf中的起始位置调整到与prefix[0]相同的下一个元素的位置
|
||||
// 删除掉buf中第一个byte, 并将buf中的起始位置调整到与PREFIX相同的下一个元素的位置
|
||||
func nextPackageHead(buf *bytes.Buffer) {
|
||||
var err error
|
||||
_, _ = buf.ReadByte()
|
||||
_, err = buf.ReadBytes(prefix[0]) // 只搜索与prefix[0]相同元素,防止prefix[0]出现在buf末尾
|
||||
if err == nil { // 找到下一个协议头标志,把删掉的prefix回退到buffer
|
||||
_, err = buf.ReadBytes(PREFIX) // 只搜索与PREFIX相同元素,防止PREFIX出现在buf末尾
|
||||
if err == nil { // 找到下一个协议头标志,把删掉的prefix回退到buffer
|
||||
_ = buf.UnreadByte()
|
||||
glog.Trace("[protocol_package] prefix does not match, prefix[0] found, trim buf, buf length [%d]", buf.Len())
|
||||
} else { // 找不到下一个协议头标志,清空buffer
|
||||
buf.Reset()
|
||||
glog.Trace("[protocol_package] prefix does not match, prefix[0] not found, reset buf, buf length [%d]", buf.Len())
|
||||
}
|
||||
}
|
||||
|
|
106
protocol.go
106
protocol.go
|
@ -3,15 +3,12 @@ package protocol
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"git.viry.cc/gomod/glog"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.viry.cc/gomod/glog"
|
||||
)
|
||||
|
||||
const VERSION uint8 = 1
|
||||
|
||||
var ErrorReadCallbackIsNil = errors.New("read callback is nil")
|
||||
var ErrorReaderIsNil = errors.New("reader is nil")
|
||||
var ErrorWriterIsNil = errors.New("writer is nil")
|
||||
|
@ -20,7 +17,7 @@ var ErrorWriterIsKilled = errors.New("writer is killed")
|
|||
var ErrorWriterQueueIsNil = errors.New("writer queue is nil")
|
||||
var ErrorHeartbeatIsKilled = errors.New("heartbeat is killed")
|
||||
var ErrorHeartbeatCallbackIsNil = errors.New("heartbeat callback is nil")
|
||||
var ErrorDataSizeExceedsLimit = errors.New("data size exceeds limit")
|
||||
var ErrorDataSizeExceedsLimit = errors.New("body size exceeds limit")
|
||||
var ErrorTimeout = errors.New("timeout")
|
||||
|
||||
const (
|
||||
|
@ -44,9 +41,9 @@ type Protocol struct {
|
|||
runningRoutines int32
|
||||
|
||||
// 心跳信号,同时也是心跳响应信号
|
||||
heartbeatSig chan uint8
|
||||
heartbeatSig chan uint32
|
||||
// 心跳请求信号,收到此信号必须回复对方
|
||||
heartbeatSigReq chan uint8
|
||||
heartbeatSigReq chan uint32
|
||||
// 发送心跳请求的间隔
|
||||
heartbeatInterval uint32
|
||||
// 接收心跳请求的超时时间
|
||||
|
@ -114,8 +111,8 @@ func New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback
|
|||
readCallback: readCallback,
|
||||
writeQueue: newQueue(writeQueueSize),
|
||||
runningRoutines: 0,
|
||||
heartbeatSig: make(chan uint8, 1),
|
||||
heartbeatSigReq: make(chan uint8, 1),
|
||||
heartbeatSig: make(chan uint32, 1),
|
||||
heartbeatSigReq: make(chan uint32, 1),
|
||||
heartbeatInterval: 15,
|
||||
heartbeatTimeout: 40,
|
||||
heartbeatTimeoutCallback: heartbeatTimeoutCallback,
|
||||
|
@ -144,32 +141,32 @@ func (p *Protocol) handlePackage(pkg *protocolPackage) {
|
|||
glog.Trace("[protocol.%s] package is nil", p.tag)
|
||||
return
|
||||
}
|
||||
if pkg.isEncrypted() {
|
||||
glog.Trace("[protocol.%s] package is encrypted, decrypt package", p.tag)
|
||||
pkg.decrypt()
|
||||
}
|
||||
if !pkg.checkHead() {
|
||||
defer packagePool.Put(pkg)
|
||||
if (pkg.flag&FlagHeaderHashCheck) != 0 && !pkg.checkHeader() {
|
||||
glog.Trace("[protocol.%s] package head broken", p.tag)
|
||||
return
|
||||
}
|
||||
if (pkg.flag & flagHeartbeat) != 0 {
|
||||
if pkg.encryptor != EncryptorNone {
|
||||
pkg.decrypt()
|
||||
}
|
||||
if (pkg.flag & FlagHeartbeatResponse) != 0 {
|
||||
glog.Info("[protocol.%s] heartbeat signal in package", p.tag)
|
||||
p.heartbeatSig <- pkg.value
|
||||
}
|
||||
if (pkg.flag & flagHeartbeatRequest) != 0 {
|
||||
if (pkg.flag & FlagHeartbeatRequest) != 0 {
|
||||
glog.Info("[protocol.%s] heartbeat request signal in package", p.tag)
|
||||
p.heartbeatSigReq <- pkg.value
|
||||
}
|
||||
if !pkg.checkData() {
|
||||
glog.Trace("[protocol.%s] package data broken", p.tag)
|
||||
if !pkg.checkBody() {
|
||||
glog.Trace("[protocol.%s] package body broken", p.tag)
|
||||
return
|
||||
}
|
||||
if pkg.dataSize == 0 {
|
||||
glog.Trace("[protocol.%s] package data empty", p.tag)
|
||||
if pkg.bodySize == 0 {
|
||||
glog.Trace("[protocol.%s] package body empty", p.tag)
|
||||
return
|
||||
}
|
||||
glog.Info("[protocol.%s] handle package successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
|
||||
p.readCallback(pkg.data)
|
||||
glog.Info("[protocol.%s] handle package successful, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
|
||||
p.readCallback(pkg.body)
|
||||
}
|
||||
|
||||
// Reader 阻塞接收数据并提交给readCallback
|
||||
|
@ -183,7 +180,7 @@ func (p *Protocol) reader() {
|
|||
defer p.decRunningRoutine()
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
buf := make([]byte, packageMaxSize)
|
||||
buf := make([]byte, HeaderSize+BodyMaxSize)
|
||||
var err error
|
||||
var n int
|
||||
// 监听并接收数据
|
||||
|
@ -214,11 +211,11 @@ func (p *Protocol) reader() {
|
|||
}
|
||||
n, err = buffer.Write(buf[:n])
|
||||
glog.Trace("[protocol.%s] write %d bytes, buffer already %d bytes, error is %v", p.tag, n, buffer.Len(), err)
|
||||
for buffer.Len() >= packageHeadSize {
|
||||
for buffer.Len() >= HeaderSize {
|
||||
glog.Trace("[protocol.%s] complete package, buffer length %d", p.tag, buffer.Len())
|
||||
pkg, err := parsePackage(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrorPackageIncomplete) {
|
||||
if errors.Is(err, errorPackageIncomplete) {
|
||||
glog.Trace("[protocol.%s] incomplete package, buffer length %d", p.tag, buffer.Len())
|
||||
break
|
||||
}
|
||||
|
@ -226,7 +223,7 @@ func (p *Protocol) reader() {
|
|||
}
|
||||
|
||||
if pkg != nil {
|
||||
glog.Info("[protocol.%s] receive new package, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
|
||||
glog.Info("[protocol.%s] receive new package, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
|
||||
go p.handlePackage(pkg)
|
||||
}
|
||||
}
|
||||
|
@ -261,7 +258,7 @@ func (p *Protocol) writer() {
|
|||
p.setFuncBeforeWrite()
|
||||
}
|
||||
glog.Trace("[protocol.%s] writer wait write", p.tag)
|
||||
n, err = p.w.Write(pkg.Bytes().Bytes())
|
||||
n, err = p.w.Write(pkg.bytesBuffer().Bytes())
|
||||
if p.setFuncAfterWrite != nil {
|
||||
glog.Trace("[protocol.%s] writer func after write", p.tag)
|
||||
p.setFuncAfterWrite(err)
|
||||
|
@ -277,25 +274,26 @@ func (p *Protocol) writer() {
|
|||
}
|
||||
}
|
||||
}
|
||||
glog.Info("[protocol.%s] send package successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
|
||||
packagePool.Put(pkg)
|
||||
glog.Info("[protocol.%s] send package successful, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
|
||||
}
|
||||
}
|
||||
|
||||
// Write 发送数据
|
||||
func (p *Protocol) Write(data []byte) error {
|
||||
func (p *Protocol) Write(flag, encryptor uint8, data []byte) error {
|
||||
glog.Trace("[protocol.%s] write", p.tag)
|
||||
if len(data) > dataMaxSize {
|
||||
glog.Info("[protocol.%s] maximum supported data size exceeded", p.tag)
|
||||
if len(data) > BodyMaxSize {
|
||||
glog.Info("[protocol.%s] maximum supported body size exceeded", p.tag)
|
||||
return ErrorDataSizeExceedsLimit
|
||||
}
|
||||
pkg := newPackage(0, encryptNone, 0, data)
|
||||
pkg := newPackage(flag, encryptor, 0, data)
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Info("[protocol.%s] protocol is killed", p.tag)
|
||||
return ErrorWriterIsKilled
|
||||
}
|
||||
if p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
|
||||
glog.Info("[protocol.%s] write successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize)
|
||||
glog.Info("[protocol.%s] write successful, crc32:[%d] flag:[%x] bodySize:[%d]", p.tag, pkg.headerHash, pkg.flag, pkg.bodySize)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -364,9 +362,9 @@ func (p *Protocol) sendHeartbeatSignal(isReq bool) {
|
|||
glog.Trace("[protocol.%s] send heartbeat signal", p.tag)
|
||||
var pkg *protocolPackage
|
||||
if isReq {
|
||||
pkg = newPackage(flagHeartbeatRequest, encryptNone, 0, nil)
|
||||
pkg = newPackage(FlagHeartbeatRequest, EncryptorNone, 0, nil)
|
||||
} else {
|
||||
pkg = newPackage(flagHeartbeat, encryptNone, 0, nil)
|
||||
pkg = newPackage(FlagHeartbeatResponse, EncryptorNone, 0, nil)
|
||||
}
|
||||
for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
|
||||
if p.getStatus() == statusKilled {
|
||||
|
@ -390,28 +388,28 @@ func (p *Protocol) getStatus() int32 {
|
|||
return atomic.LoadInt32(&p.status)
|
||||
}
|
||||
|
||||
func (p *Protocol) SetHeartbeatInterval(interval uint8) {
|
||||
func (p *Protocol) SetHeartbeatInterval(interval uint32) {
|
||||
if interval < 3 {
|
||||
glog.Trace("[protocol.%s] heartbeatInterval is < 3, use 3", p.tag)
|
||||
interval = 3
|
||||
}
|
||||
atomic.StoreUint32(&p.heartbeatInterval, uint32(interval))
|
||||
atomic.StoreUint32(&p.heartbeatInterval, interval)
|
||||
}
|
||||
|
||||
func (p *Protocol) GetHeartbeatInterval() uint8 {
|
||||
return uint8(atomic.LoadUint32(&p.heartbeatInterval))
|
||||
func (p *Protocol) GetHeartbeatInterval() uint32 {
|
||||
return atomic.LoadUint32(&p.heartbeatInterval)
|
||||
}
|
||||
|
||||
func (p *Protocol) SetHeartbeatTimeout(timeout uint8) {
|
||||
func (p *Protocol) SetHeartbeatTimeout(timeout uint32) {
|
||||
if timeout < 6 {
|
||||
glog.Trace("[protocol.%s] heartbeatTimeout is < 6, use 6", p.tag)
|
||||
timeout = 6
|
||||
}
|
||||
atomic.StoreUint32(&p.heartbeatTimeout, uint32(timeout))
|
||||
atomic.StoreUint32(&p.heartbeatTimeout, timeout)
|
||||
}
|
||||
|
||||
func (p *Protocol) GetHeartbeatTimeout() uint8 {
|
||||
return uint8(atomic.LoadUint32(&p.heartbeatTimeout))
|
||||
func (p *Protocol) GetHeartbeatTimeout() uint32 {
|
||||
return atomic.LoadUint32(&p.heartbeatTimeout)
|
||||
}
|
||||
|
||||
func (p *Protocol) setHeartbeatLastReceived() {
|
||||
|
@ -481,14 +479,14 @@ func defaultKillCallback() {
|
|||
glog.Trace("[protocol] default kill callback")
|
||||
}
|
||||
|
||||
func GetDataMaxSize() int {
|
||||
return dataMaxSize
|
||||
}
|
||||
|
||||
func CalculateTheNumberOfPackages(size int64) int64 {
|
||||
res := size / dataMaxSize
|
||||
if size%dataMaxSize != 0 {
|
||||
res += 1
|
||||
}
|
||||
return res
|
||||
}
|
||||
// func GetDataMaxSize() int {
|
||||
// return BodyMaxSize
|
||||
// }
|
||||
//
|
||||
// func CalculateTheNumberOfPackages(size int64) int64 {
|
||||
// res := size / BodyMaxSize
|
||||
// if size%BodyMaxSize != 0 {
|
||||
// res += 1
|
||||
// }
|
||||
// return res
|
||||
// }
|
||||
|
|
|
@ -14,7 +14,10 @@ var serverClosed = make(chan bool, 1)
|
|||
var clientClosed = make(chan bool, 1)
|
||||
|
||||
func TestProtocol(t *testing.T) {
|
||||
SetLogMask(MaskFATAL | MaskERROR | MaskWARNING | MaskINFO)
|
||||
// SetLogMask(LogMaskDEBUG)
|
||||
// SetLogFlag(LogFlagDEBUG)
|
||||
SetLogMask(LogMaskINFO)
|
||||
SetLogFlag(LogFlagINFO)
|
||||
go testServer(t)
|
||||
time.Sleep(time.Second)
|
||||
go testClient(t)
|
||||
|
@ -73,7 +76,7 @@ func testServer(t *testing.T) {
|
|||
i++
|
||||
msg := fmt.Sprintf("serv msg %d", i)
|
||||
fmt.Printf("[server] send [%s]\n", msg)
|
||||
err = protServer.Write([]byte(msg))
|
||||
err = protServer.Write(FlagHeaderHashCheck|FlagBodyHashCheck, EncryptorAes, []byte(msg))
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrorWriterIsKilled) {
|
||||
glog.Warning("[server] failed to write %v", err)
|
||||
|
@ -136,7 +139,7 @@ func testClient(t *testing.T) {
|
|||
i++
|
||||
msg := fmt.Sprintf("client msg %d", i)
|
||||
fmt.Printf("[client] send [%s]\n", msg)
|
||||
err = protClient.Write([]byte(msg))
|
||||
err = protClient.Write(FlagHeaderHashCheck|FlagBodyHashCheck, EncryptorXor, []byte(msg))
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrorWriterIsKilled) {
|
||||
glog.Warning("[client] failed to write %v", err)
|
||||
|
|
|
@ -24,8 +24,8 @@ func TestQueue(t *testing.T) {
|
|||
if que.size() != 0 || !que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
if pkg11.value != 1 || len(pkg11.data) != 1 || pkg11.data[0] != 1 {
|
||||
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
|
||||
if pkg11.value != 1 || len(pkg11.body) != 1 || pkg11.body[0] != 1 {
|
||||
t.Errorf("value:%d body:%v\n", pkg11.value, pkg11.body)
|
||||
}
|
||||
que.push(pkg2, 0)
|
||||
if que.size() != 1 || que.isEmpty() {
|
||||
|
@ -33,8 +33,8 @@ func TestQueue(t *testing.T) {
|
|||
}
|
||||
|
||||
pkg21 := que.pop(0)
|
||||
if pkg21.value != 2 || len(pkg21.data) != 2 || pkg21.data[0] != 1 || pkg21.data[1] != 2 {
|
||||
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
|
||||
if pkg21.value != 2 || len(pkg21.body) != 2 || pkg21.body[0] != 1 || pkg21.body[1] != 2 {
|
||||
t.Errorf("value:%d body:%v\n", pkg21.value, pkg21.body)
|
||||
}
|
||||
|
||||
pkg1 = newPackage(0, 0, 1, nil)
|
||||
|
@ -67,8 +67,8 @@ func TestQueue(t *testing.T) {
|
|||
time.Sleep(3 * time.Second)
|
||||
fmt.Println("pop")
|
||||
pkg11 := que2.pop(0)
|
||||
if pkg11.value != 1 || pkg11.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
|
||||
if pkg11.value != 1 || pkg11.body != nil {
|
||||
t.Errorf("value:%d body:%v\n", pkg11.value, pkg11.body)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -79,18 +79,18 @@ func TestQueue(t *testing.T) {
|
|||
}
|
||||
|
||||
pkg21 = que2.pop(0)
|
||||
if pkg21.value != 2 || pkg21.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
|
||||
if pkg21.value != 2 || pkg21.body != nil {
|
||||
t.Errorf("value:%d body:%v\n", pkg21.value, pkg21.body)
|
||||
}
|
||||
|
||||
pkg31 := que2.pop(0)
|
||||
if pkg31.value != 3 || pkg31.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg31.value, pkg31.data)
|
||||
if pkg31.value != 3 || pkg31.body != nil {
|
||||
t.Errorf("value:%d body:%v\n", pkg31.value, pkg31.body)
|
||||
}
|
||||
|
||||
pkg41 := que2.pop(0)
|
||||
if pkg41.value != 4 || pkg31.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg41.value, pkg41.data)
|
||||
if pkg41.value != 4 || pkg31.body != nil {
|
||||
t.Errorf("value:%d body:%v\n", pkg41.value, pkg41.body)
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
@ -101,8 +101,8 @@ func TestQueue(t *testing.T) {
|
|||
|
||||
fmt.Println("wait push")
|
||||
pkg51 := que2.pop(0)
|
||||
if pkg51.value != 5 || len(pkg51.data) != 1 || pkg51.data[0] != 55 {
|
||||
t.Errorf("value:%d data:%v\n", pkg51.value, pkg51.data)
|
||||
if pkg51.value != 5 || len(pkg51.body) != 1 || pkg51.body[0] != 55 {
|
||||
t.Errorf("value:%d body:%v\n", pkg51.value, pkg51.body)
|
||||
}
|
||||
|
||||
fmt.Println("size")
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var uint32BytesPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 4)
|
||||
},
|
||||
}
|
||||
|
||||
func writeUint32Bytes(b *bytes.Buffer, i uint32) {
|
||||
buf := uint32BytesPool.Get().([]byte)
|
||||
binary.BigEndian.PutUint32(buf, i)
|
||||
b.Write(buf)
|
||||
uint32BytesPool.Put(buf)
|
||||
}
|
Loading…
Reference in New Issue