heartbeat
Test / testing (1.19.13, ubuntu-latest) (push) Failing after 1m39s Details
Test / testing (>=1.20, ubuntu-latest) (push) Failing after 1m59s Details

This commit is contained in:
Akvicor 2024-02-28 02:49:54 +08:00
parent 51d2857220
commit 86c89364cc
7 changed files with 355 additions and 264 deletions

114
README.md
View File

@ -11,107 +11,11 @@
import "git.viry.cc/gomod/protocol"
```
```go
func testServer(t *testing.T) {
listen, err := net.Listen("tcp", "0.0.0.0:9999")
if err != nil {
glog.Error("[S] Listen() failed, err: %s", err)
return
}
glog.Info("[S] Listen 0.0.0.0:9999")
for {
conn, err := listen.Accept() // 监听客户端的连接请求
if err != nil {
glog.Error("[S] Accept() failed, err: %s", err)
continue
}
glog.Info("[S] Accept %s %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
var Index = 0
prot := New(conn, conn, 8, func(data []byte) {
fmt.Printf("[S] received [%s]\n", string(data))
Index++
if fmt.Sprintf("client msg %d", Index) != string(data) {
t.Errorf("test client error need %s got %s", fmt.Sprintf("client msg %d", Index), string(data))
}
}, func() bool {
fmt.Println("[S] heartbeat timeout")
t.Error("heartbeat timeout")
return false
})
prot.Connect(true)
go func() {
time.Sleep(30 * time.Second)
if prot.GetHeartbeatLastSend() == 0 {
t.Error("GetHeartbeatLastSend is zero")
}
if prot.GetHeartbeatLastReceived() == 0 {
t.Error("GetHeartbeatLastReceived is zero")
}
prot.Kill()
}()
i := 0
for {
time.Sleep(5 * time.Second)
i++
msg := fmt.Sprintf("server msg %d", i)
fmt.Printf("[S] send [%s]\n", msg)
err := prot.Write([]byte(msg))
if err != nil {
glog.Warning("[S] failed to write %v", err)
break
}
}
}
}
func testClient(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9999")
if err != nil {
glog.Error("[C] Dial() failed, err: %s", err)
return
}
glog.Info("[C] Connected")
var Index = 0
prot := New(conn, conn, 8, func(data []byte) {
fmt.Printf("[C] received [%s]\n", string(data))
Index++
if fmt.Sprintf("server msg %d", Index) != string(data) {
t.Errorf("test client error need %s got %s", fmt.Sprintf("server msg %d", Index), string(data))
}
}, func() bool {
fmt.Println("[C] heartbeat timeout")
t.Error("heartbeat timeout")
return false
})
prot.Connect(false)
go func() {
time.Sleep(30 * time.Second)
if prot.GetHeartbeatLastSend() == 0 {
t.Error("GetHeartbeatLastSend is zero")
}
if prot.GetHeartbeatLastReceived() == 0 {
t.Error("GetHeartbeatLastReceived is zero")
}
prot.Kill()
}()
time.Sleep(1 * time.Second)
i := 0
for {
time.Sleep(5 * time.Second)
i++
msg := fmt.Sprintf("client msg %d", i)
fmt.Printf("[C] send [%s]\n", msg)
err = prot.Write([]byte(msg))
if err != nil {
glog.Warning("[C] failed to write %v", err)
}
break
}
}
```
样例请参照`protocol_test.go`中的示例示例中包含了对tcp的封装
```go
// 创建protocol封装
New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func() bool, setReadDeadline, setWriteDeadline, killCallback func()) *Protocol
// 启动传输,通过参数确定是否是心跳服务端(主动发出心跳信号一方)
// 如果传输双方均为发出方,不会影响正常服务,但会产生不必要的心跳
Connect(bool)
@ -127,9 +31,21 @@ GetHeartbeatTimeout() uint8
GetHeartbeatLastReceived()
// 获取上一次发送心跳的时间
GetHeartbeatLastSend()
// 获取正在运行的协程的数量
GetRunningRoutine() int32
// 等待Protocol关闭
WaitKilled(timeout int) error
// 获取tag
GetTag() string
SetTag(tag string)
// 关闭Protocol
Kill()
// 设置log模式
SetLogProd(isProd bool)
SetLogMask(mask uint32)
SetLogFlag(f uint32)
// Protocol版本号不同版本存在不兼容的可能性
protocol.VERSION
// 获取每次Write的最大数据长度

6
log.go
View File

@ -1,8 +1,8 @@
package protocol
import "git.viry.cc/gomod/glog"
const logPrefix = "[protocol]"
import (
"git.viry.cc/gomod/glog"
)
const (
MaskUNKNOWN = glog.MaskUNKNOWN

View File

@ -52,7 +52,9 @@ const (
// flag标志位
const (
// 普通心跳信号,心跳响应信号
flagHeartbeat uint8 = 1 << iota
// 心跳请求信号接收方必须回复flagHeartbeat
flagHeartbeatRequest
)
@ -133,7 +135,7 @@ func (p *protocolPackage) headNeedCheckBytes() *bytes.Buffer {
// 生成head的crc32
func (p *protocolPackage) generateHeadCheck() {
p.crc32 = util.NewCRC32().FromBytes(p.headNeedCheckBytes().Bytes()).Value()
glog.Trace("%shead crc32 is %d", logPrefix, p.crc32)
glog.Trace("[protocol_package] head crc32 is %d", p.crc32)
}
// 校验head的crc32
@ -144,13 +146,13 @@ func (p *protocolPackage) checkHead() bool {
// 生成data的crc32
func (p *protocolPackage) generateDataCheck() {
p.dataCrc32 = util.NewCRC32().FromBytes(p.data).Value()
glog.Trace("data crc32 is %d", p.dataCrc32)
glog.Trace("[protocol_package] data crc32 is %d", p.dataCrc32)
}
// 校验data的crc32
func (p *protocolPackage) checkData() bool {
if int(p.dataSize) != len(p.data) {
glog.Trace("pkg.dataSize != len(pkg.data)")
glog.Trace("[protocol_package] pkg.dataSize != len(pkg.data)")
return false
}
return p.dataCrc32 == util.NewCRC32().FromBytes(p.data).Value()
@ -159,20 +161,20 @@ func (p *protocolPackage) checkData() bool {
// encrypt 加密data
func (p *protocolPackage) encrypt(method uint8) {
if p.encryptMethod == method {
glog.Trace("is already encrypted [%d]", method)
glog.Trace("[protocol_package] is already encrypted [%d]", method)
return // 已经加密
}
if p.encryptMethod != encryptNone {
glog.Trace("encrypt with other method got [%d] need encryptNone[%d]", p.encryptMethod, encryptNone)
glog.Trace("[protocol_package] encrypt with other method got [%d] need encryptNone[%d]", p.encryptMethod, encryptNone)
return // 已经通过其他方式加密
}
glog.Warning("unknown encrypt method")
glog.Warning("[protocol_package] unknown encrypt method")
}
// decrypt 解密data
func (p *protocolPackage) decrypt() {
if !p.isEncrypted() {
glog.Trace("is not encrypted")
glog.Trace("[protocol_package] is not encrypted")
return
}
}
@ -191,24 +193,24 @@ func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
// 判断package的版本
// 暂时只处理VERSION版本的package
if buf.Len() < headOffsetVersion+headLengthVersion {
glog.Trace("incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
glog.Trace("[protocol_package] incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
return nil, ErrorPackageIncomplete
}
if buf.Bytes()[headOffsetVersion] != VERSION {
glog.Trace("unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
glog.Trace("[protocol_package] unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
nextPackageHead(buf)
return nil, ErrorUnsupportedVersion
}
// 开始判断是否为package并提取package
if buf.Len() < packageHeadSize {
glog.Trace("incomplete head, need %d got %d", packageHeadSize, buf.Len())
glog.Trace("[protocol_package] incomplete head, need %d got %d", packageHeadSize, buf.Len())
return nil, ErrorPackageIncomplete
}
head := make([]byte, packageHeadSize)
copy(head, buf.Bytes()[:packageHeadSize])
// 协议头标志不匹配删除未知数据寻找下一个package起始位置
if !bytes.Equal(prefix[:], head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix]) {
glog.Trace("prefix does not match, need %v got %v", prefix, head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
glog.Trace("[protocol_package] prefix does not match, need %v got %v", prefix, head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
nextPackageHead(buf)
return nil, ErrorWrongPrefix
}
@ -216,14 +218,14 @@ func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
headChecksum := util.BytesSliceToUInt32(head[headOffsetCRC32Checksum : headOffsetCRC32Checksum+headLengthCRC32Checksum])
headCrc32 := util.NewCRC32().FromBytes(head[headOffsetNeedCheck:]).Value()
if headChecksum != headCrc32 {
glog.Trace("head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
glog.Trace("[protocol_package] head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
nextPackageHead(buf)
return nil, ErrorBrokenHead
}
// 检查package是否完整不完整则等待
packageDataSize := util.BytesSliceToUInt32(head[headOffsetDataSize : headOffsetDataSize+headLengthDataSize])
if packageHeadSize+int(packageDataSize) > buf.Len() {
glog.Trace("incomplete data, need %d got %d", packageHeadSize+packageDataSize, buf.Len())
glog.Trace("[protocol_package] incomplete data, need %d got %d", packageHeadSize+packageDataSize, buf.Len())
return nil, ErrorPackageIncomplete
}
// package完整
@ -250,7 +252,7 @@ func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
_, _ = buf.Read(pkg.data)
dataCrc32 := util.NewCRC32().FromBytes(pkg.data).Value()
if pkg.dataCrc32 != dataCrc32 {
glog.Trace("data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
glog.Trace("[protocol_package] data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
nextPackageHead(buf)
return nil, ErrorBrokenData
}
@ -264,9 +266,9 @@ func nextPackageHead(buf *bytes.Buffer) {
_, err = buf.ReadBytes(prefix[0]) // 只搜索与prefix[0]相同元素防止prefix[0]出现在buf末尾
if err == nil { // 找到下一个协议头标志把删掉的prefix回退到buffer
_ = buf.UnreadByte()
glog.Trace("prefix does not match, prefix[0] found, trim buf, buf length [%d]", buf.Len())
glog.Trace("[protocol_package] prefix does not match, prefix[0] found, trim buf, buf length [%d]", buf.Len())
} else { // 找不到下一个协议头标志清空buffer
buf.Reset()
glog.Trace("prefix does not match, prefix[0] not found, reset buf, buf length [%d]", buf.Len())
glog.Trace("[protocol_package] prefix does not match, prefix[0] not found, reset buf, buf length [%d]", buf.Len())
}
}

View File

@ -21,6 +21,7 @@ var ErrorWriterQueueIsNil = errors.New("writer queue is nil")
var ErrorHeartbeatIsKilled = errors.New("heartbeat is killed")
var ErrorHeartbeatCallbackIsNil = errors.New("heartbeat callback is nil")
var ErrorDataSizeExceedsLimit = errors.New("data size exceeds limit")
var ErrorTimeout = errors.New("timeout")
const (
statusRunning int32 = iota
@ -28,49 +29,87 @@ const (
)
type Protocol struct {
r io.Reader
w io.Writer
// 标记protocol
tag string
r io.Reader
w io.Writer
status int32
// protocol的状态
status int32
// 用于处理获取到的数据每个package中的数据都会完整的保存在data中
readCallback func(data []byte)
writeQueue *queue
// 写入等待队列
writeQueue *queue
// 当前protocol正在运行的协程数量
runningRoutines int32
heartbeatSig chan uint8
heartbeatSigReq chan uint8
heartbeatInterval uint32
heartbeatTimeout uint32
// 心跳信号,同时也是心跳响应信号
heartbeatSig chan uint8
// 心跳请求信号,收到此信号必须回复对方
heartbeatSigReq chan uint8
// 发送心跳请求的间隔
heartbeatInterval uint32
// 接收心跳请求的超时时间
heartbeatTimeout uint32
// 心跳请求超时后的处理函数
heartbeatTimeoutCallback func() bool
heartbeatLastSend int64
heartbeatLastReceived int64
// 上次发送心跳的时间
heartbeatLastSend int64
// 上次收到心跳的时间
heartbeatLastReceived int64
// status被标记为statusKilled时执行可以用于关闭reader和writer
killCallback func()
// 在reader读取数据前设置reader的读取截止时间
setReadDeadline func()
// 在writer读取数据前设置writer的读取截止时间
setWriteDeadline func()
}
func New(r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func() bool) *Protocol {
// New 返回一个protocol实例
//
// tag: 标签用于区分protocol实例
// r: 数据流的reader
// w: 数据流的writer
// writeQueueSize: 发送等待队列长度
// readCallback: 用于处理获取到的数据每个package中的数据都会完整的保存在data中
// heartbeatTimeoutCallback: 心跳请求超时后的处理函数
// setReadDeadline: 在reader读取数据前设置reader的读取截止时间
// setWriteDeadline: 在writer读取数据前设置writer的读取截止时间
// killCallback: status被标记为statusKilled时执行可以用于关闭reader和writer
func New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func() bool, setReadDeadline, setWriteDeadline, killCallback func()) *Protocol {
if r == nil {
glog.Warning("%s reader is nil", logPrefix)
glog.Warning("[protocol.%s] reader is nil", tag)
return nil
}
if w == nil {
glog.Warning("%s writer is nil", logPrefix)
glog.Warning("[protocol.%s] writer is nil", tag)
return nil
}
if writeQueueSize < 1 {
glog.Trace("%s writeQueueSize is < 1, use 1", logPrefix)
glog.Trace("[protocol.%s] writeQueueSize is < 1, use 1", tag)
writeQueueSize = 1
}
if readCallback == nil {
glog.Trace("%s readCallback is nil, use defaultReadCallback", logPrefix)
glog.Trace("[protocol.%s] readCallback is nil, use defaultReadCallback", tag)
readCallback = defaultReadCallback
}
if heartbeatTimeoutCallback == nil {
glog.Trace("%s heartbeatTimeoutCallback is nil, use defaultHeartbeatTimeoutCallback", logPrefix)
glog.Trace("[protocol.%s] heartbeatTimeoutCallback is nil, use defaultHeartbeatTimeoutCallback", tag)
heartbeatTimeoutCallback = defaultHeartbeatTimeoutCallback
}
if killCallback == nil {
glog.Trace("[protocol.%s] killCallback is nil, use defaultKillCallback", tag)
killCallback = defaultKillCallback
}
return &Protocol{
tag: tag,
r: r,
w: w,
status: statusRunning,
readCallback: readCallback,
writeQueue: newQueue(writeQueueSize),
runningRoutines: 0,
heartbeatSig: make(chan uint8, 1),
heartbeatSigReq: make(chan uint8, 1),
heartbeatInterval: 15,
@ -78,6 +117,9 @@ func New(r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []
heartbeatTimeoutCallback: heartbeatTimeoutCallback,
heartbeatLastSend: 0,
heartbeatLastReceived: 0,
killCallback: killCallback,
setReadDeadline: setReadDeadline,
setWriteDeadline: setWriteDeadline,
}
}
@ -91,46 +133,48 @@ func (p *Protocol) Connect(activeHeartbeatSignalSender bool) {
}
func (p *Protocol) handlePackage(pkg *protocolPackage) {
glog.Trace("%s handle package", logPrefix)
glog.Trace("[protocol.%s] handle package", p.tag)
if pkg == nil {
glog.Trace("%s package is nil", logPrefix)
glog.Trace("[protocol.%s] package is nil", p.tag)
return
}
if pkg.isEncrypted() {
glog.Trace("%s package is encrypted, decrypt package", logPrefix)
glog.Trace("[protocol.%s] package is encrypted, decrypt package", p.tag)
pkg.decrypt()
}
if !pkg.checkHead() {
glog.Trace("%s package head broken", logPrefix)
glog.Trace("[protocol.%s] package head broken", p.tag)
return
}
if (pkg.flag & flagHeartbeat) != 0 {
glog.Trace("%s heartbeat signal in package", logPrefix)
glog.Trace("[protocol.%s] heartbeat signal in package", p.tag)
p.heartbeatSig <- pkg.value
}
if (pkg.flag & flagHeartbeatRequest) != 0 {
glog.Trace("%s heartbeat request signal in package", logPrefix)
glog.Trace("[protocol.%s] heartbeat request signal in package", p.tag)
p.heartbeatSigReq <- pkg.value
}
if !pkg.checkData() {
glog.Trace("%s package data broken", logPrefix)
glog.Trace("[protocol.%s] package data broken", p.tag)
return
}
if pkg.dataSize == 0 {
glog.Trace("%s package data empty", logPrefix)
glog.Trace("[protocol.%s] package data empty", p.tag)
return
}
glog.Trace("%s handle package successful", logPrefix)
glog.Trace("[protocol.%s] handle package successful", p.tag)
p.readCallback(pkg.data)
}
// Reader 阻塞接收数据并提交给readCallback
func (p *Protocol) reader() {
glog.Trace("%s reader enable", logPrefix)
glog.Trace("[protocol.%s] reader enable", p.tag)
if p.r == nil {
glog.Warning("%s reader is not ready", logPrefix)
glog.Warning("[protocol.%s] reader is not ready", p.tag)
return
}
p.incRunningRoutine()
defer p.decRunningRoutine()
buffer := &bytes.Buffer{}
buf := make([]byte, packageMaxSize)
@ -139,31 +183,36 @@ func (p *Protocol) reader() {
// 监听并接收数据
for {
if p.getStatus() == statusKilled {
glog.Trace("%s reader is killed", logPrefix)
glog.Trace("[protocol.%s] reader is killed", p.tag)
return
}
if p.setReadDeadline != nil {
glog.Trace("[protocol.%s] reader set deadline", p.tag)
p.setReadDeadline()
}
glog.Trace("[protocol.%s] reader wait read", p.tag)
n, err = p.r.Read(buf)
if err != nil {
glog.Trace("%s read error %v", logPrefix, err)
glog.Trace("[protocol.%s] read error %v", p.tag, err)
time.Sleep(500 * time.Millisecond)
continue
}
if n == 0 {
glog.Trace("%s read empty", logPrefix)
glog.Trace("[protocol.%s] read empty", p.tag)
time.Sleep(500 * time.Millisecond)
continue
}
n, err = buffer.Write(buf[:n])
glog.Trace("%s write %d bytes, buffer already %d bytes, error is %v", logPrefix, n, buffer.Len(), err)
glog.Trace("[protocol.%s] write %d bytes, buffer already %d bytes, error is %v", p.tag, n, buffer.Len(), err)
for buffer.Len() >= packageHeadSize {
glog.Trace("%s complete package, buffer length %d", logPrefix, buffer.Len())
glog.Trace("[protocol.%s] complete package, buffer length %d", p.tag, buffer.Len())
pkg, err := parsePackage(buffer)
if errors.Is(err, ErrorPackageIncomplete) {
glog.Trace("%s incomplete package, buffer length %d", logPrefix, buffer.Len())
glog.Trace("[protocol.%s] incomplete package, buffer length %d", p.tag, buffer.Len())
break
}
if pkg != nil {
glog.Trace("%s receive new package", logPrefix)
glog.Trace("[protocol.%s] receive new package", p.tag)
go p.handlePackage(pkg)
}
}
@ -172,50 +221,64 @@ func (p *Protocol) reader() {
// Writer 创建发送队列并监听待发送数据
func (p *Protocol) writer() {
glog.Trace("%s writer enable", logPrefix)
glog.Trace("[protocol.%s] writer enable", p.tag)
if p.w == nil {
glog.Warning("%s writer is not ready", logPrefix)
glog.Warning("[protocol.%s] writer is not ready", p.tag)
return
}
p.incRunningRoutine()
defer p.decRunningRoutine()
var err error
var n int
for {
if p.getStatus() == statusKilled {
glog.Trace("%s writer is killed", logPrefix)
glog.Trace("[protocol.%s] writer is killed", p.tag)
return
}
pkg := p.writeQueue.pop()
glog.Trace("[protocol.%s] writer wait pop", p.tag)
pkg := p.writeQueue.pop(int(p.GetHeartbeatInterval()))
if pkg == nil {
glog.Trace("[protocol.%s] writer pop timeout", p.tag)
continue
}
if p.setWriteDeadline != nil {
glog.Trace("[protocol.%s] writer set deadline", p.tag)
p.setWriteDeadline()
}
glog.Trace("[protocol.%s] writer wait write", p.tag)
n, err = p.w.Write(pkg.Bytes().Bytes())
glog.Trace("%s write %d bytes, error is %v", logPrefix, n, err)
glog.Trace("[protocol.%s] write %d bytes, error is %v", p.tag, n, err)
if err != nil {
glog.Trace("%s send package failed, re-push package", logPrefix)
go p.writeQueue.push(pkg)
glog.Trace("[protocol.%s] send package failed, re-push package", p.tag)
time.Sleep(time.Second)
for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
if p.getStatus() == statusKilled {
glog.Trace("[protocol.%s] writer is killed", p.tag)
return
}
}
}
}
}
// Write 发送数据
func (p *Protocol) Write(data []byte) error {
glog.Trace("%s write", logPrefix)
glog.Trace("[protocol.%s] write", p.tag)
if len(data) > dataMaxSize {
glog.Warning("%s maximum supported data size exceeded", logPrefix)
glog.Info("[protocol.%s] maximum supported data size exceeded", p.tag)
return ErrorDataSizeExceedsLimit
}
if p.w == nil {
glog.Warning("%s writer is not ready", logPrefix)
return ErrorWriterIsNil
}
if p.writeQueue == nil {
glog.Warning("%s protocol is not ready", logPrefix)
return ErrorWriterQueueIsNil
}
if p.getStatus() == statusKilled {
glog.Warning("%s protocol is killed", logPrefix)
return ErrorWriterIsKilled
}
pkg := newPackage(0, encryptNone, 0, data)
p.writeQueue.push(pkg)
return nil
for {
if p.getStatus() == statusKilled {
glog.Info("[protocol.%s] protocol is killed", p.tag)
return ErrorWriterIsKilled
}
if p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
return nil
}
}
}
// heartbeat 心跳服务
@ -223,32 +286,39 @@ func (p *Protocol) Write(data []byte) error {
// heartbeatTimeout: 被动接收心跳信号的超时时间(s)最小为3s传入参数小于3时使用默认值30
// heartbeatTimeoutCallback: 没有按时收到心跳信号时调用返回true继续等待返回false退出
func (p *Protocol) heartbeat() {
glog.Trace("%s heartbeat enable", logPrefix)
glog.Trace("[protocol.%s] heartbeat enable", p.tag)
p.incRunningRoutine()
defer p.decRunningRoutine()
for {
select {
case <-time.After(time.Duration(p.GetHeartbeatTimeout()) * time.Second):
glog.Trace("%s heartbeat timeout", logPrefix)
glog.Trace("[protocol.%s] heartbeat timeout", p.tag)
if p.getStatus() == statusKilled {
glog.Trace("[protocol.%s] heartbeat is killed", p.tag)
return
}
if !p.heartbeatTimeoutCallback() {
glog.Trace("%s heartbeat is killed, set status killed", logPrefix)
glog.Trace("[protocol.%s] heartbeat is killed, set status killed", p.tag)
p.setStatus(statusKilled)
return
}
case val := <-p.heartbeatSigReq:
glog.Trace("%s heartbeat request signal received", logPrefix)
glog.Trace("[protocol.%s] heartbeat request signal received", p.tag)
p.setHeartbeatLastReceived()
p.sendHeartbeatSignal(false)
if val != 0 {
p.SetHeartbeatTimeout(val)
}
case val := <-p.heartbeatSig:
glog.Trace("%s heartbeat signal received", logPrefix)
glog.Trace("[protocol.%s] heartbeat signal received", p.tag)
p.setHeartbeatLastReceived()
if val != 0 {
p.SetHeartbeatTimeout(val)
}
}
if p.getStatus() == statusKilled {
glog.Trace("%s heartbeat is killed", logPrefix)
glog.Trace("[protocol.%s] heartbeat is killed", p.tag)
return
}
}
@ -258,9 +328,11 @@ func (p *Protocol) heartbeat() {
//
// heartbeatInterval: 主动发送心跳信号的间隔时间(s)最小为3s传入参数小于3时使用默认值3
func (p *Protocol) heartbeatSignalSender() {
p.incRunningRoutine()
defer p.decRunningRoutine()
for {
if p.getStatus() == statusKilled {
glog.Trace("%s heartbeat signal sender is killed", logPrefix)
glog.Trace("[protocol.%s] heartbeat signal sender is killed", p.tag)
return
}
p.sendHeartbeatSignal(true)
@ -269,28 +341,38 @@ func (p *Protocol) heartbeatSignalSender() {
}
func (p *Protocol) sendHeartbeatSignal(isReq bool) {
glog.Trace("%s send heartbeat signal", logPrefix)
glog.Trace("[protocol.%s] send heartbeat signal", p.tag)
var pkg *protocolPackage
if isReq {
p.writeQueue.push(newPackage(flagHeartbeatRequest, encryptNone, 0, nil))
pkg = newPackage(flagHeartbeatRequest, encryptNone, 0, nil)
} else {
p.writeQueue.push(newPackage(flagHeartbeat, encryptNone, 0, nil))
pkg = newPackage(flagHeartbeat, encryptNone, 0, nil)
}
for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
if p.getStatus() == statusKilled {
glog.Info("[protocol.%s] protocol is killed", p.tag)
return
}
}
p.setHeartbeatLastSend()
}
func (p *Protocol) setStatus(status int32) {
glog.Trace("%s set status %d", logPrefix, status)
glog.Trace("[protocol.%s] set status %d", p.tag, status)
if status == statusKilled {
p.killCallback()
}
atomic.StoreInt32(&p.status, status)
}
func (p *Protocol) getStatus() int32 {
glog.Trace("%s get status", logPrefix)
glog.Trace("[protocol.%s] get status", p.tag)
return atomic.LoadInt32(&p.status)
}
func (p *Protocol) SetHeartbeatInterval(interval uint8) {
if interval < 3 {
glog.Trace("%s heartbeatInterval is < 3, use 3", logPrefix)
glog.Trace("[protocol.%s] heartbeatInterval is < 3, use 3", p.tag)
interval = 3
}
atomic.StoreUint32(&p.heartbeatInterval, uint32(interval))
@ -302,7 +384,7 @@ func (p *Protocol) GetHeartbeatInterval() uint8 {
func (p *Protocol) SetHeartbeatTimeout(timeout uint8) {
if timeout < 6 {
glog.Trace("%s heartbeatTimeout is < 6, use 6", logPrefix)
glog.Trace("[protocol.%s] heartbeatTimeout is < 6, use 6", p.tag)
timeout = 6
}
atomic.StoreUint32(&p.heartbeatTimeout, uint32(timeout))
@ -328,19 +410,57 @@ func (p *Protocol) GetHeartbeatLastSend() int64 {
return atomic.LoadInt64(&p.heartbeatLastSend)
}
func (p *Protocol) incRunningRoutine() {
atomic.AddInt32(&p.runningRoutines, 1)
}
func (p *Protocol) decRunningRoutine() {
atomic.AddInt32(&p.runningRoutines, -1)
}
func (p *Protocol) GetRunningRoutine() int32 {
return atomic.LoadInt32(&p.runningRoutines)
}
func (p *Protocol) WaitKilled(timeout int) error {
out := time.After(time.Duration(timeout) * time.Second)
for {
select {
case <-out:
return ErrorTimeout
case <-time.After(time.Second):
if p.GetRunningRoutine() <= 0 {
return nil
}
}
}
}
func (p *Protocol) GetTag() string {
return p.tag
}
func (p *Protocol) SetTag(tag string) {
p.tag = tag
}
func (p *Protocol) Kill() {
p.setStatus(statusKilled)
}
func defaultReadCallback(data []byte) {
glog.Trace("%s default read callback %x", logPrefix, data)
glog.Trace("[protocol] default read callback %x", data)
}
func defaultHeartbeatTimeoutCallback() bool {
glog.Trace("%s default heartbeat timeout callback", logPrefix)
glog.Trace("[protocol] default heartbeat timeout callback")
return true
}
func defaultKillCallback() {
glog.Trace("[protocol] default kill callback")
}
func GetDataMaxSize() int {
return dataMaxSize
}

View File

@ -1,6 +1,7 @@
package protocol
import (
"errors"
"fmt"
"git.viry.cc/gomod/glog"
"net"
@ -9,63 +10,95 @@ import (
"time"
)
var protServer *Protocol
var protClient *Protocol
func TestProtocol(t *testing.T) {
// SetLogProd(false)
go testServer(t)
time.Sleep(time.Second)
testClient(t)
time.Sleep(15 * time.Second)
go testClient(t)
time.Sleep(30 * time.Second)
if protServer.GetHeartbeatLastSend() == 0 {
t.Error("server.GetHeartbeatLastSend is zero")
}
if protServer.GetHeartbeatLastReceived() == 0 {
t.Error("server.GetHeartbeatLastReceived is zero")
}
if protClient.GetHeartbeatLastSend() == 0 {
t.Error("client.GetHeartbeatLastSend is zero")
}
if protClient.GetHeartbeatLastReceived() == 0 {
t.Error("client.GetHeartbeatLastReceived is zero")
}
glog.Info("kill client")
protClient.Kill()
glog.Info("wait client killed")
err := protClient.WaitKilled(60)
if err != nil {
t.Errorf("kill client failed [%d]", protClient.GetRunningRoutine())
}
glog.Info("wait client killed [%d]", protClient.GetRunningRoutine())
glog.Info("wait server killed")
err = protServer.WaitKilled(60)
if err != nil {
t.Errorf("server killed failed [%d]", protServer.GetRunningRoutine())
}
glog.Info("wait server killed [%d]", protServer.GetRunningRoutine())
}
func testServer(t *testing.T) {
listen, err := net.Listen("tcp", "0.0.0.0:9999")
if err != nil {
glog.Error("[S] Listen() failed, err: %s", err)
glog.Error("[server] Listen() failed, err: %s", err)
return
}
glog.Info("[S] Listen 0.0.0.0:9999")
glog.Info("[server] Listen 0.0.0.0:9999")
for {
conn, err := listen.Accept() // 监听客户端的连接请求
if err != nil {
glog.Error("[S] Accept() failed, err: %s", err)
glog.Error("[server] Accept() failed, err: %s", err)
continue
}
glog.Info("[S] Accept %s %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
glog.Info("[server] Accept %s %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
var Index uint32 = 0
prot := New(conn, conn, 8, func(data []byte) {
fmt.Printf("[S] received [%s]\n", string(data))
protServer = New("server", conn, conn, 8, func(data []byte) {
// 处理获取到的数据
fmt.Printf("[server] received [%s]\n", string(data))
atomic.AddUint32(&Index, 1)
ans := fmt.Sprintf("client msg %d", atomic.LoadUint32(&Index))
if ans != string(data) {
t.Errorf("test client error need %s got %s", ans, string(data))
}
}, func() bool {
fmt.Println("[S] heartbeat timeout")
t.Error("heartbeat timeout")
// protocol还在运行但心跳超时
fmt.Println("[server] heartbeat timeout")
return false
}, func() {
// 每次conn.Read前运行
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
}, func() {
// 每次conn.Write前运行
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
}, func() {
// protocol状态更改为killed时运行
conn.Close()
})
prot.SetHeartbeatInterval(3)
prot.SetHeartbeatTimeout(10)
prot.Connect(true)
go func() {
time.Sleep(30 * time.Second)
if prot.GetHeartbeatLastSend() == 0 {
t.Error("GetHeartbeatLastSend is zero")
}
if prot.GetHeartbeatLastReceived() == 0 {
t.Error("GetHeartbeatLastReceived is zero")
}
prot.Kill()
}()
protServer.SetHeartbeatInterval(3)
protServer.SetHeartbeatTimeout(10)
protServer.Connect(true)
i := 0
for {
time.Sleep(5 * time.Second)
i++
msg := fmt.Sprintf("server msg %d", i)
fmt.Printf("[S] send [%s]\n", msg)
err := prot.Write([]byte(msg))
if err != nil {
glog.Warning("[S] failed to write %v", err)
fmt.Printf("[server] send [%s]\n", msg)
err = protServer.Write([]byte(msg))
if err != nil && !errors.Is(err, ErrorWriterIsKilled) {
glog.Warning("[server] failed to write %v", err)
return
}
}
@ -75,47 +108,42 @@ func testServer(t *testing.T) {
func testClient(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9999")
if err != nil {
glog.Error("[C] Dial() failed, err: %s", err)
glog.Error("[client] Dial() failed, err: %s", err)
return
}
glog.Info("[C] Connected")
glog.Info("[client] Connected")
var Index uint32 = 0
prot := New(conn, conn, 8, func(data []byte) {
fmt.Printf("[C] received [%s]\n", string(data))
protClient = New("client", conn, conn, 8, func(data []byte) {
fmt.Printf("[client] received [%s]\n", string(data))
atomic.AddUint32(&Index, 1)
ans := fmt.Sprintf("server msg %d", atomic.LoadUint32(&Index))
if ans != string(data) {
t.Errorf("test client error need %s got %s", ans, string(data))
}
}, func() bool {
fmt.Println("[C] heartbeat timeout")
t.Error("heartbeat timeout")
return false
fmt.Println("[client] heartbeat timeout")
return true
}, func() {
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
}, func() {
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
}, func() {
conn.Close()
})
prot.SetHeartbeatInterval(3)
prot.SetHeartbeatTimeout(10)
prot.Connect(false)
go func() {
time.Sleep(30 * time.Second)
if prot.GetHeartbeatLastSend() == 0 {
t.Error("GetHeartbeatLastSend is zero")
}
if prot.GetHeartbeatLastReceived() == 0 {
t.Error("GetHeartbeatLastReceived is zero")
}
prot.Kill()
}()
protClient.SetHeartbeatInterval(3)
protClient.SetHeartbeatTimeout(10)
protClient.Connect(false)
time.Sleep(1 * time.Second)
i := 0
for {
time.Sleep(5 * time.Second)
i++
msg := fmt.Sprintf("client msg %d", i)
fmt.Printf("[C] send [%s]\n", msg)
err = prot.Write([]byte(msg))
if err != nil {
glog.Warning("[C] failed to write %v", err)
fmt.Printf("[client] send [%s]\n", msg)
err = protClient.Write([]byte(msg))
if err != nil && !errors.Is(err, ErrorWriterIsKilled) {
glog.Warning("[client] failed to write %v", err)
return
}
}

View File

@ -3,6 +3,7 @@ package protocol
import (
"sync"
"sync/atomic"
"time"
)
type queue struct {
@ -43,17 +44,28 @@ func newQueue(size int) *queue {
}
}
func (q *queue) push(item *protocolPackage) {
func (q *queue) push(item *protocolPackage, timeout int) (res bool) {
q.pushLock.Lock()
defer func() {
// push成功后队列大小+1
atomic.AddInt32(&q.curSize, 1)
q.pushLock.Unlock()
// 操作必定成功向End信号池发送一个信号表示完成此次push
q.poolEnd <- true
if res {
// 向End信号池发送一个信号表示完成此次push
q.poolEnd <- true
}
}()
// 操作成功代表队列不满向Start信号池发送一个信号表示开始push
q.poolStart <- true
if timeout > 0 {
select {
case q.poolStart <- true:
case <-time.After(time.Duration(timeout) * time.Second):
res = false
return
}
} else {
q.poolStart <- true
}
q.queue[q.wIndex] = item
@ -61,19 +73,32 @@ func (q *queue) push(item *protocolPackage) {
if q.wIndex >= q.maxSize {
q.wIndex = 0
}
res = true
return
}
func (q *queue) pop() (item *protocolPackage) {
func (q *queue) pop(timeout int) (item *protocolPackage) {
q.popLock.Lock()
defer func() {
// pop成功后队列大小-1
atomic.AddInt32(&q.curSize, -1)
q.popLock.Unlock()
// 操作必定成功,当前元素已经成功取出,释放当前位置
<-q.poolStart
if item != nil {
// 当前元素已经成功取出,释放当前位置
<-q.poolStart
}
}()
// 操作成功代表队列非空只有End信号池中有信号才能保证有完整的元素在队列中
<-q.poolEnd
if timeout > 0 {
select {
case <-q.poolEnd:
case <-time.After(time.Duration(timeout) * time.Second):
item = nil
return
}
} else {
<-q.poolEnd
}
item = q.queue[q.rIndex]

View File

@ -15,24 +15,24 @@ func TestQueue(t *testing.T) {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
que.push(pkg1)
que.push(pkg1, 0)
if que.size() != 1 || que.isEmpty() {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
pkg11 := que.pop()
pkg11 := que.pop(0)
if que.size() != 0 || !que.isEmpty() {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
if pkg11.value != 1 || len(pkg11.data) != 1 || pkg11.data[0] != 1 {
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
}
que.push(pkg2)
que.push(pkg2, 0)
if que.size() != 1 || que.isEmpty() {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
pkg21 := que.pop()
pkg21 := que.pop(0)
if pkg21.value != 2 || len(pkg21.data) != 2 || pkg21.data[0] != 1 || pkg21.data[1] != 2 {
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
}
@ -48,17 +48,17 @@ func TestQueue(t *testing.T) {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
que2.push(pkg1)
que2.push(pkg1, 0)
if que2.size() != 1 {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
que2.push(pkg2)
que2.push(pkg2, 0)
if que2.size() != 2 {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
que2.push(pkg3)
que2.push(pkg3, 0)
if que2.size() != 3 {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
@ -66,29 +66,29 @@ func TestQueue(t *testing.T) {
go func() {
time.Sleep(3 * time.Second)
fmt.Println("pop")
pkg11 := que2.pop()
pkg11 := que2.pop(0)
if pkg11.value != 1 || pkg11.data != nil {
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
}
}()
fmt.Println("wait pop")
que2.push(pkg4)
que2.push(pkg4, 0)
if que2.size() != 3 {
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
}
pkg21 = que2.pop()
pkg21 = que2.pop(0)
if pkg21.value != 2 || pkg21.data != nil {
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
}
pkg31 := que2.pop()
pkg31 := que2.pop(0)
if pkg31.value != 3 || pkg31.data != nil {
t.Errorf("value:%d data:%v\n", pkg31.value, pkg31.data)
}
pkg41 := que2.pop()
pkg41 := que2.pop(0)
if pkg41.value != 4 || pkg31.data != nil {
t.Errorf("value:%d data:%v\n", pkg41.value, pkg41.data)
}
@ -96,11 +96,11 @@ func TestQueue(t *testing.T) {
go func() {
time.Sleep(3 * time.Second)
fmt.Println("push")
que2.push(pkg5)
que2.push(pkg5, 0)
}()
fmt.Println("wait push")
pkg51 := que2.pop()
pkg51 := que2.pop(0)
if pkg51.value != 5 || len(pkg51.data) != 1 || pkg51.data[0] != 55 {
t.Errorf("value:%d data:%v\n", pkg51.value, pkg51.data)
}