0.0.1
This commit is contained in:
commit
df166ae23b
|
@ -0,0 +1 @@
|
|||
.idea
|
|
@ -0,0 +1,44 @@
|
|||
# Protocol
|
||||
|
||||
一个简易的数据传输协议,保证在流传输协议下,每次Write的数据对方能完整(不多&不少)的接收和处理
|
||||
|
||||
```go
|
||||
// 创建Protocol(只能通过New方法创建才能保证Protocol得到正确初始化)
|
||||
prot = protocol.New(r io.Reader, w io.Writer)
|
||||
// Reader方法为阻塞方式,监听接收数据,每次接收到完整Package后会调用callback来处理
|
||||
go func() {
|
||||
err := prot.Reader(callback func(data []byte))
|
||||
if err != nil {
|
||||
glog.Fatal("failed to enable reader %v", err)
|
||||
}
|
||||
}()
|
||||
// Writer方法为阻塞方法,从队列中取出数据并发送,需要传入队列大小来初始化写入队列
|
||||
go func() {
|
||||
err := prot.Writer(writeQueueSize int)
|
||||
if err != nil {
|
||||
glog.Fatal("failed to enable writer %v", err)
|
||||
}
|
||||
}()
|
||||
// Heartbeat方法为阻塞方法,
|
||||
go func() {
|
||||
err := prot.Heartbeat(sendInterval int, receiveTimeout int, failedCallback func() bool)
|
||||
if err != nil {
|
||||
glog.Fatal("failed to enable heartbeat %v", err)
|
||||
}
|
||||
}()
|
||||
// 发送数据(注意数据长度有限制
|
||||
prot.Write([]byte{0x11, 0x22, 0x33})
|
||||
// 获取上一次收到心跳的时间
|
||||
prot.GetHeartbeatLastReceived()
|
||||
// 获取上一次发送心跳的时间
|
||||
prot.GetHeartbeatLastSend()
|
||||
// 关闭Protocol
|
||||
prot.Kill()
|
||||
|
||||
// Protocol版本号,不同版本存在不兼容的可能性
|
||||
protocol.VERSION
|
||||
// 获取每次Write的最大数据长度
|
||||
protocol.GetDataMaxSize() int
|
||||
// 计算传入的size需要多少次Write才能发送
|
||||
protocol.CalculateTheNumberOfPackages(size int64) int64
|
||||
```
|
|
@ -0,0 +1,8 @@
|
|||
module git.viry.cc/gomod/protocol
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
git.viry.cc/gomod/glog v0.1.3
|
||||
git.viry.cc/gomod/util v1.6.1
|
||||
)
|
|
@ -0,0 +1,4 @@
|
|||
git.viry.cc/gomod/glog v0.1.3 h1:x1ldfyyjp9L2iutobj1c5i/eq3IFOn29OVhu+ELgKYg=
|
||||
git.viry.cc/gomod/glog v0.1.3/go.mod h1:e4ndIpsVbkUwjvf/t5Gs3LJIjuJCw70r91cDGLiodqo=
|
||||
git.viry.cc/gomod/util v1.6.1 h1:rlATwShd4w0ovbJd7hX8jyyF3Cy1/fPgQ0bJHLvEv00=
|
||||
git.viry.cc/gomod/util v1.6.1/go.mod h1:n1+pvIjf5b6F3dCQo552CxWGIfiM//gTVHJ6KKfB1YE=
|
|
@ -0,0 +1,262 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"git.viry.cc/gomod/glog"
|
||||
"git.viry.cc/gomod/util"
|
||||
)
|
||||
|
||||
// package的起始标志
|
||||
var prefix = [headLengthPrefix]uint8{0xff, 0x07, 0x55, 0x00}
|
||||
|
||||
var ErrorPackageEncrypted = errors.New("package is encrypted")
|
||||
var ErrorPackageIncomplete = errors.New("package incomplete")
|
||||
var ErrorUnsupportedVersion = errors.New("unsupported version")
|
||||
var ErrorWrongPrefix = errors.New("prefix does not match")
|
||||
var ErrorBrokenHead = errors.New("head crc32 checksum does not match")
|
||||
var ErrorBrokenData = errors.New("data crc32 checksum does not match")
|
||||
|
||||
// head中各部分的长度
|
||||
const (
|
||||
headLengthPrefix = 4
|
||||
headLengthVersion = 1
|
||||
headLengthCRC32Checksum = 4
|
||||
headLengthFlag = 1
|
||||
headLengthEncryptMethod = 1
|
||||
headLengthCustomValue = 1
|
||||
headLengthDataSize = 4
|
||||
headLengthDataCrc32 = 4
|
||||
)
|
||||
|
||||
// head中各部分的偏移
|
||||
const (
|
||||
headOffsetPrefix = 0
|
||||
headOffsetVersion = headOffsetPrefix + headLengthPrefix
|
||||
headOffsetCRC32Checksum = headOffsetVersion + headLengthVersion
|
||||
headOffsetFlag = headOffsetCRC32Checksum + headLengthCRC32Checksum
|
||||
headOffsetEncryptMethod = headOffsetFlag + headLengthFlag
|
||||
headOffsetCustomValue = headOffsetEncryptMethod + headLengthEncryptMethod
|
||||
headOffsetDataSize = headOffsetCustomValue + headLengthCustomValue
|
||||
headOffsetDataCrc32 = headOffsetDataSize + headLengthDataSize
|
||||
headOffsetData = headOffsetDataCrc32 + headLengthDataCrc32
|
||||
)
|
||||
|
||||
// 计算head的crc32时起始偏移
|
||||
const headOffsetNeedCheck = headOffsetCRC32Checksum + headLengthCRC32Checksum
|
||||
|
||||
// data的加密方式
|
||||
const (
|
||||
encryptNone uint8 = iota
|
||||
)
|
||||
|
||||
// flag标志位
|
||||
const (
|
||||
flagHeartbeat uint8 = 1 << iota
|
||||
)
|
||||
|
||||
// package的head的大小 (byte)
|
||||
const packageHeadSize = headOffsetData
|
||||
|
||||
// package的最大size (byte)
|
||||
const packageMaxSize = 4096
|
||||
|
||||
// package的data的最大size (byte)
|
||||
const dataMaxSize = packageMaxSize - packageHeadSize
|
||||
|
||||
type protocolPackage struct {
|
||||
prefix [headLengthPrefix]byte // 4 byte 0xff 0x55
|
||||
version uint8 // 1 byte protocol version
|
||||
crc32 uint32 // 4 byte head crc32 checksum (BigEndian)
|
||||
flag uint8 // 1 byte flag
|
||||
encryptMethod uint8 // 1 byte encrypted method
|
||||
value uint8 // 1 byte custom value (for heartbeat)
|
||||
dataSize uint32 // 4 byte curSize of data (BigEndian)
|
||||
dataCrc32 uint32 // 4 byte crc32 of data (BigEndian)
|
||||
data []byte // ? byte data
|
||||
}
|
||||
|
||||
// 创建新package, 所有新package的创建都必须通过此方法
|
||||
func newPackage(flag uint8, encrypt uint8, value uint8, data []byte) *protocolPackage {
|
||||
pkg := &protocolPackage{
|
||||
prefix: prefix,
|
||||
version: VERSION,
|
||||
crc32: 0,
|
||||
flag: flag,
|
||||
encryptMethod: encrypt,
|
||||
value: value,
|
||||
dataSize: uint32(len(data)),
|
||||
dataCrc32: 0,
|
||||
data: data,
|
||||
}
|
||||
pkg.generateDataCheck()
|
||||
pkg.generateHeadCheck()
|
||||
return pkg
|
||||
}
|
||||
|
||||
func (p *protocolPackage) Bytes() *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
// prefix
|
||||
buf.Write(p.prefix[:])
|
||||
// version
|
||||
buf.WriteByte(p.version)
|
||||
// crc32
|
||||
crc32 := util.UInt32ToBytes(p.crc32)
|
||||
buf.Write(crc32[:])
|
||||
// flag
|
||||
buf.WriteByte(p.flag)
|
||||
// encrypt method
|
||||
buf.WriteByte(p.encryptMethod)
|
||||
// value
|
||||
buf.WriteByte(p.value)
|
||||
// data curSize
|
||||
dataSize := util.UInt32ToBytes(p.dataSize)
|
||||
buf.Write(dataSize[:])
|
||||
// data crc32
|
||||
dataCrc32 := util.UInt32ToBytes(p.dataCrc32)
|
||||
buf.Write(dataCrc32[:])
|
||||
// data
|
||||
buf.Write(p.data)
|
||||
return buf
|
||||
}
|
||||
|
||||
// 将head中需要校验的数据拼接起来
|
||||
func (p *protocolPackage) headNeedCheckBytes() *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte(p.flag)
|
||||
buf.WriteByte(p.encryptMethod)
|
||||
buf.WriteByte(p.value)
|
||||
bDataSize := util.UInt32ToBytes(p.dataSize)
|
||||
buf.Write(bDataSize[:])
|
||||
bDataCrc32 := util.UInt32ToBytes(p.dataCrc32)
|
||||
buf.Write(bDataCrc32[:])
|
||||
return buf
|
||||
}
|
||||
|
||||
// 生成head的crc32
|
||||
func (p *protocolPackage) generateHeadCheck() {
|
||||
p.crc32 = util.CRC32Bytes(p.headNeedCheckBytes().Bytes())
|
||||
}
|
||||
|
||||
// 校验head的crc32
|
||||
func (p *protocolPackage) checkHead() bool {
|
||||
return p.crc32 == util.CRC32Bytes(p.headNeedCheckBytes().Bytes())
|
||||
}
|
||||
|
||||
// 生成data的crc32
|
||||
func (p *protocolPackage) generateDataCheck() {
|
||||
p.dataCrc32 = util.CRC32Bytes(p.data)
|
||||
}
|
||||
|
||||
// 校验data的crc32
|
||||
func (p *protocolPackage) checkData() bool {
|
||||
if int(p.dataSize) != len(p.data) {
|
||||
return false
|
||||
}
|
||||
dataCrc32 := util.CRC32Bytes(p.data)
|
||||
return p.dataCrc32 == dataCrc32
|
||||
}
|
||||
|
||||
// encrypt 加密data
|
||||
func (p *protocolPackage) encrypt(method uint8) error {
|
||||
if p.encryptMethod != encryptNone {
|
||||
return ErrorPackageEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decrypt 解密data
|
||||
func (p *protocolPackage) decrypt() {
|
||||
if p.encryptMethod == encryptNone {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// parsePackage 从buf中读取一个package
|
||||
//
|
||||
// 如果协议头标志(prefix)不匹配,删除buf中除第一个字符外,下一个prefix1到buf开头的所有数据
|
||||
// 如果协议头标志(prefix)匹配,不断从buf中取出数据,填充到package结构体
|
||||
// 如果buf中的数据出错,无法正确提取package, 则返回(nil,true), 且已从buf中提取的数据不会退回buf
|
||||
func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
|
||||
// 判断package的版本
|
||||
// 暂时只处理VERSION版本的package
|
||||
if buf.Len() < headOffsetVersion+headLengthVersion {
|
||||
glog.Trace("[protocol.parsePackage] incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
if buf.Bytes()[headOffsetVersion] != VERSION {
|
||||
glog.Trace("[protocol.parsePackage] unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorUnsupportedVersion
|
||||
}
|
||||
// 开始判断是否为package并提取package
|
||||
if buf.Len() < packageHeadSize {
|
||||
glog.Trace("[protocol.parsePackage] incomplete head, need %d got %d", packageHeadSize, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
head := make([]byte, packageHeadSize)
|
||||
copy(head, buf.Bytes()[:packageHeadSize])
|
||||
// 协议头标志不匹配,删除未知数据,寻找下一个package起始位置
|
||||
if !bytes.Equal(prefix[:], head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix]) {
|
||||
glog.Trace("[protocol.parsePackage] prefix does not match")
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorWrongPrefix
|
||||
}
|
||||
// 检查head是否完整,删除未知数据,寻找下一个package起始位置
|
||||
headCrc32 := util.CRC32Bytes(head[headOffsetNeedCheck:])
|
||||
headChecksum := util.BytesToUInt32(head[headOffsetCRC32Checksum : headOffsetCRC32Checksum+headLengthCRC32Checksum])
|
||||
if headChecksum != headCrc32 {
|
||||
glog.Trace("[protocol.parsePackage] head crc32 checksum does not match")
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorBrokenHead
|
||||
}
|
||||
// 检查package是否完整,不完整则等待
|
||||
packageDataSize := util.BytesToUInt32(head[headOffsetDataSize : headOffsetDataSize+headLengthDataSize])
|
||||
if packageHeadSize+int(packageDataSize) > buf.Len() {
|
||||
glog.Trace("[protocol.parsePackage] incomplete data, need %d got %d", packageDataSize, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
// package完整
|
||||
pkg := &protocolPackage{}
|
||||
_, _ = buf.Read(make([]byte, packageHeadSize))
|
||||
// prefix
|
||||
copy(pkg.prefix[:], head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
|
||||
// crc32
|
||||
pkg.version = head[headOffsetVersion]
|
||||
// crc32
|
||||
pkg.crc32 = headCrc32
|
||||
// flag
|
||||
pkg.flag = head[headOffsetFlag]
|
||||
// encrypt method
|
||||
pkg.encryptMethod = head[headOffsetEncryptMethod]
|
||||
// value
|
||||
pkg.value = head[headOffsetCustomValue]
|
||||
// data curSize
|
||||
pkg.dataSize = packageDataSize
|
||||
// data crc32
|
||||
pkg.dataCrc32 = util.BytesToUInt32(head[headOffsetDataCrc32 : headOffsetDataCrc32+headLengthDataCrc32])
|
||||
// data
|
||||
pkg.data = make([]byte, pkg.dataSize)
|
||||
_, _ = buf.Read(pkg.data)
|
||||
dataCrc32 := util.CRC32Bytes(pkg.data)
|
||||
if pkg.dataCrc32 != dataCrc32 {
|
||||
glog.Trace("[protocol.parsePackage] data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorBrokenData
|
||||
}
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// 删除掉buf中第一个byte, 并将buf中的起始位置调整到与prefix[0]相同的下一个元素的位置
|
||||
func nextPackageHead(buf *bytes.Buffer) {
|
||||
var err error
|
||||
_, _ = buf.ReadByte()
|
||||
_, err = buf.ReadBytes(prefix[0]) // 只搜索与prefix[0]相同元素,防止prefix[0]出现在buf末尾
|
||||
if err == nil { // 找到下一个协议头标志,把删掉的prefix回退到buffer
|
||||
glog.Trace("[protocol.nextPackageHead] prefix does not match, prefix[0] found")
|
||||
_ = buf.UnreadByte()
|
||||
} else { // 找不到下一个协议头标志,清空buffer
|
||||
glog.Trace("[protocol.nextPackageHead] prefix does not match, prefix[0] not found")
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,249 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.viry.cc/gomod/glog"
|
||||
)
|
||||
|
||||
const VERSION uint8 = 1
|
||||
|
||||
var ErrorReadCallbackIsNil = errors.New("read callback is nil")
|
||||
var ErrorReaderIsNil = errors.New("reader is nil")
|
||||
var ErrorWriterIsNil = errors.New("writer is nil")
|
||||
var ErrorReaderIsKilled = errors.New("reader is killed")
|
||||
var ErrorWriterIsKilled = errors.New("writer is killed")
|
||||
var ErrorWriterQueueIsNil = errors.New("writer queue is nil")
|
||||
var ErrorHeartbeatIsKilled = errors.New("heartbeat is killed")
|
||||
var ErrorHeartbeatCallbackIsNil = errors.New("heartbeat callback is nil")
|
||||
var ErrorDataSizeExceedsLimit = errors.New("data size exceeds limit")
|
||||
|
||||
const (
|
||||
statusRunning int32 = iota
|
||||
statusKilled
|
||||
)
|
||||
|
||||
type Protocol struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
|
||||
status int32
|
||||
readCallback func(data []byte)
|
||||
writeQueue *queue
|
||||
|
||||
heartbeatSig chan uint8
|
||||
heartbeatLastSend int64
|
||||
heartbeatLastReceived int64
|
||||
}
|
||||
|
||||
func New(r io.Reader, w io.Writer) *Protocol {
|
||||
return &Protocol{
|
||||
r: r,
|
||||
w: w,
|
||||
status: statusRunning,
|
||||
readCallback: nil,
|
||||
writeQueue: nil,
|
||||
heartbeatSig: make(chan uint8, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Protocol) setStatus(status int32) {
|
||||
atomic.StoreInt32(&p.status, status)
|
||||
}
|
||||
|
||||
func (p *Protocol) getStatus() int32 {
|
||||
return atomic.LoadInt32(&p.status)
|
||||
}
|
||||
|
||||
func (p *Protocol) handlePackage(pkg *protocolPackage) {
|
||||
glog.Trace("[protocol.handlePackage] new package")
|
||||
if pkg == nil {
|
||||
return
|
||||
}
|
||||
pkg.decrypt()
|
||||
if !pkg.checkHead() {
|
||||
glog.Trace("[protocol.handlePackage] broken head")
|
||||
return
|
||||
}
|
||||
if (pkg.flag & flagHeartbeat) != 0 {
|
||||
p.heartbeatSig <- pkg.value
|
||||
}
|
||||
if !pkg.checkData() {
|
||||
glog.Trace("[protocol.handlePackage] broken data")
|
||||
return
|
||||
}
|
||||
if pkg.dataSize == 0 {
|
||||
glog.Trace("[protocol.handlePackage] empty data")
|
||||
return
|
||||
}
|
||||
p.readCallback(pkg.data)
|
||||
}
|
||||
|
||||
// Reader 阻塞接收数据并提交给readCallback
|
||||
func (p *Protocol) Reader(callback func(data []byte)) error {
|
||||
if p.r == nil {
|
||||
glog.Warning("[protocol.Reader] protocol is not ready")
|
||||
return ErrorReaderIsNil
|
||||
}
|
||||
if callback == nil {
|
||||
glog.Warning("[protocol.Reader] protocol is not ready")
|
||||
return ErrorReadCallbackIsNil
|
||||
}
|
||||
p.readCallback = callback
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
buf := make([]byte, packageMaxSize)
|
||||
var err error
|
||||
var n int
|
||||
// 监听并接收数据
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Warning("[protocol.Reader] is killed")
|
||||
return ErrorReaderIsKilled
|
||||
}
|
||||
n, err = p.r.Read(buf)
|
||||
if err != nil {
|
||||
glog.Warning("[protocol.Reader] r.read err: %v", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
if n == 0 {
|
||||
glog.Warning("[protocol.Reader] r.read: zero length")
|
||||
continue
|
||||
}
|
||||
n, err = buffer.Write(buf[:n])
|
||||
glog.Trace("[protocol.Reader] buffer already %d bytes, write %d bytes, %v", buffer.Len(), n, err)
|
||||
for buffer.Len() >= packageHeadSize {
|
||||
glog.Trace("[protocol.Reader] complete buffer length %d", buffer.Len())
|
||||
pkg, err := parsePackage(buffer)
|
||||
if errors.Is(err, ErrorPackageIncomplete) {
|
||||
glog.Trace("[protocol.Reader] incomplete buffer length %d", buffer.Len())
|
||||
break
|
||||
}
|
||||
if pkg != nil {
|
||||
go p.handlePackage(pkg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Writer 创建发送队列并监听待发送数据
|
||||
func (p *Protocol) Writer(writeQueueSize int) error {
|
||||
if p.w == nil {
|
||||
glog.Warning("[protocol.Writer] protocol is not ready")
|
||||
return ErrorWriterIsNil
|
||||
}
|
||||
if writeQueueSize < 1 {
|
||||
writeQueueSize = 1
|
||||
}
|
||||
var err error
|
||||
p.writeQueue = newQueue(writeQueueSize)
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Warning("[protocol.Writer] is killed")
|
||||
return ErrorWriterIsKilled
|
||||
}
|
||||
pkg := p.writeQueue.pop()
|
||||
_, err = p.w.Write(pkg.Bytes().Bytes())
|
||||
if err != nil {
|
||||
go p.writeQueue.push(pkg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write 发送数据
|
||||
func (p *Protocol) Write(data []byte) error {
|
||||
if len(data) > dataMaxSize {
|
||||
return ErrorDataSizeExceedsLimit
|
||||
}
|
||||
if p.w == nil {
|
||||
glog.Warning("[protocol.Write] protocol is not ready")
|
||||
return ErrorWriterIsNil
|
||||
}
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Warning("[protocol.Write] is killed")
|
||||
return ErrorWriterIsKilled
|
||||
}
|
||||
if p.writeQueue == nil {
|
||||
glog.Warning("[protocol.Write] queue is nil")
|
||||
return ErrorWriterQueueIsNil
|
||||
}
|
||||
pkg := newPackage(0, encryptNone, 0, data)
|
||||
p.writeQueue.push(pkg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat 心跳服务
|
||||
//
|
||||
// sendInterval: 主动发送心跳信号的间隔时间(s),最小为3s,,传入参数小于3时使用默认值3
|
||||
// receiveTimeout: 被动接收心跳信号的超时时间(s),最小为3s,传入参数小于3时使用默认值3
|
||||
// failedCallback: 没有按时收到心跳信号时调用,返回true继续等待,返回false退出
|
||||
func (p *Protocol) Heartbeat(sendInterval int, receiveTimeout int, failedCallback func() bool) error {
|
||||
if receiveTimeout < 3 {
|
||||
receiveTimeout = 3
|
||||
}
|
||||
if sendInterval < 3 {
|
||||
sendInterval = 3
|
||||
}
|
||||
if failedCallback == nil {
|
||||
glog.Trace("[protocol.Heartbeat] failedCallback is nil")
|
||||
return ErrorHeartbeatCallbackIsNil
|
||||
}
|
||||
// 发送心跳信号
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Duration(sendInterval) * time.Second):
|
||||
if p.getStatus() == statusKilled {
|
||||
return
|
||||
}
|
||||
glog.Trace("[protocol.Heartbeat] sig send")
|
||||
atomic.StoreInt64(&p.heartbeatLastSend, time.Now().Unix())
|
||||
p.writeQueue.push(newPackage(flagHeartbeat, encryptNone, 0, nil))
|
||||
}
|
||||
}
|
||||
}()
|
||||
// 接收心跳信号
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Duration(receiveTimeout) * time.Second):
|
||||
glog.Trace("[protocol.Heartbeat] heartbeat failed")
|
||||
if !failedCallback() {
|
||||
glog.Trace("[protocol.Heartbeat] heartbeat killed")
|
||||
p.setStatus(statusKilled)
|
||||
return ErrorHeartbeatIsKilled
|
||||
}
|
||||
case <-p.heartbeatSig:
|
||||
glog.Trace("[protocol.Heartbeat] sig rev")
|
||||
atomic.StoreInt64(&p.heartbeatLastReceived, time.Now().Unix())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Protocol) GetHeartbeatLastReceived() int64 {
|
||||
return atomic.LoadInt64(&p.heartbeatLastReceived)
|
||||
}
|
||||
|
||||
func (p *Protocol) GetHeartbeatLastSend() int64 {
|
||||
return atomic.LoadInt64(&p.heartbeatLastSend)
|
||||
}
|
||||
|
||||
func (p *Protocol) Kill() {
|
||||
p.setStatus(statusKilled)
|
||||
}
|
||||
|
||||
func GetDataMaxSize() int {
|
||||
return dataMaxSize
|
||||
}
|
||||
|
||||
func CalculateTheNumberOfPackages(size int64) int64 {
|
||||
res := size / dataMaxSize
|
||||
if size%dataMaxSize != 0 {
|
||||
res += 1
|
||||
}
|
||||
return res
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type queue struct {
|
||||
poolStart chan bool
|
||||
poolEnd chan bool
|
||||
pushLock sync.Mutex
|
||||
popLock sync.Mutex
|
||||
maxSize int
|
||||
curSize int32
|
||||
wIndex int
|
||||
rIndex int
|
||||
queue []*protocolPackage
|
||||
}
|
||||
|
||||
func newQueue(size int) *queue {
|
||||
if size < 1 {
|
||||
size = 1
|
||||
}
|
||||
return &queue{
|
||||
// Start和End信号池用于保证push和pop操作不会互相干扰
|
||||
// 每次Push和Pop操作后,两个信号池中的信号数量都会保持一致
|
||||
poolStart: make(chan bool, size),
|
||||
poolEnd: make(chan bool, size),
|
||||
// 保证push操作完整性
|
||||
pushLock: sync.Mutex{},
|
||||
// 保证pop操作完整性
|
||||
popLock: sync.Mutex{},
|
||||
// 队列中元素最大数量
|
||||
maxSize: size,
|
||||
// 队列当前元素数量
|
||||
curSize: 0,
|
||||
// push指针
|
||||
wIndex: 0,
|
||||
// pop指针
|
||||
rIndex: 0,
|
||||
// 元素数组
|
||||
queue: make([]*protocolPackage, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queue) push(item *protocolPackage) {
|
||||
q.pushLock.Lock()
|
||||
defer func() {
|
||||
// push成功后队列大小+1
|
||||
atomic.AddInt32(&q.curSize, 1)
|
||||
q.pushLock.Unlock()
|
||||
// 操作必定成功,向End信号池发送一个信号,表示完成此次push
|
||||
q.poolEnd <- true
|
||||
}()
|
||||
// 操作成功代表队列不满,向Start信号池发送一个信号,表示开始push
|
||||
q.poolStart <- true
|
||||
|
||||
q.queue[q.wIndex] = item
|
||||
|
||||
q.wIndex++
|
||||
if q.wIndex >= q.maxSize {
|
||||
q.wIndex = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queue) pop() (item *protocolPackage) {
|
||||
q.popLock.Lock()
|
||||
defer func() {
|
||||
// pop成功后队列大小-1
|
||||
atomic.AddInt32(&q.curSize, -1)
|
||||
q.popLock.Unlock()
|
||||
// 操作必定成功,当前元素已经成功取出,释放当前位置
|
||||
<-q.poolStart
|
||||
}()
|
||||
// 操作成功代表队列非空,只有End信号池中有信号,才能保证有完整的元素在队列中
|
||||
<-q.poolEnd
|
||||
|
||||
item = q.queue[q.rIndex]
|
||||
|
||||
q.queue[q.rIndex] = nil
|
||||
|
||||
q.rIndex++
|
||||
if q.rIndex >= q.maxSize {
|
||||
q.rIndex = 0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (q *queue) size() int32 {
|
||||
return atomic.LoadInt32(&q.curSize)
|
||||
}
|
||||
|
||||
func (q *queue) isEmpty() bool {
|
||||
return atomic.LoadInt32(&q.curSize) == 0
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestQueue(t *testing.T) {
|
||||
pkg1 := newPackage(0, 0, 1, []byte{1})
|
||||
pkg2 := newPackage(0, 0, 2, []byte{1, 2})
|
||||
|
||||
que := newQueue(0)
|
||||
if que.size() != 0 || !que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que.push(pkg1)
|
||||
if que.size() != 1 || que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
pkg11 := que.pop()
|
||||
if que.size() != 0 || !que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
if pkg11.value != 1 || len(pkg11.data) != 1 || pkg11.data[0] != 1 {
|
||||
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
|
||||
}
|
||||
que.push(pkg2)
|
||||
if que.size() != 1 || que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
pkg21 := que.pop()
|
||||
if pkg21.value != 2 || len(pkg21.data) != 2 || pkg21.data[0] != 1 || pkg21.data[1] != 2 {
|
||||
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
|
||||
}
|
||||
|
||||
pkg1 = newPackage(0, 0, 1, nil)
|
||||
pkg2 = newPackage(0, 0, 2, nil)
|
||||
pkg3 := newPackage(0, 0, 3, nil)
|
||||
pkg4 := newPackage(0, 0, 4, nil)
|
||||
pkg5 := newPackage(0, 0, 5, []byte{55})
|
||||
|
||||
que2 := newQueue(3)
|
||||
if que2.size() != 0 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que2.push(pkg1)
|
||||
if que2.size() != 1 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que2.push(pkg2)
|
||||
if que2.size() != 2 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que2.push(pkg3)
|
||||
if que2.size() != 3 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
fmt.Println("pop")
|
||||
pkg11 := que2.pop()
|
||||
if pkg11.value != 1 || pkg11.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Println("wait pop")
|
||||
que2.push(pkg4)
|
||||
if que2.size() != 3 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
pkg21 = que2.pop()
|
||||
if pkg21.value != 2 || pkg21.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
|
||||
}
|
||||
|
||||
pkg31 := que2.pop()
|
||||
if pkg31.value != 3 || pkg31.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg31.value, pkg31.data)
|
||||
}
|
||||
|
||||
pkg41 := que2.pop()
|
||||
if pkg41.value != 4 || pkg31.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg41.value, pkg41.data)
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
fmt.Println("push")
|
||||
que2.push(pkg5)
|
||||
}()
|
||||
|
||||
fmt.Println("wait push")
|
||||
pkg51 := que2.pop()
|
||||
if pkg51.value != 5 || len(pkg51.data) != 1 || pkg51.data[0] != 55 {
|
||||
t.Errorf("value:%d data:%v\n", pkg51.value, pkg51.data)
|
||||
}
|
||||
|
||||
fmt.Println("size")
|
||||
if que2.size() != 0 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue