heartbeat
This commit is contained in:
parent
51d2857220
commit
86c89364cc
114
README.md
114
README.md
|
@ -11,107 +11,11 @@
|
|||
import "git.viry.cc/gomod/protocol"
|
||||
```
|
||||
|
||||
```go
|
||||
func testServer(t *testing.T) {
|
||||
listen, err := net.Listen("tcp", "0.0.0.0:9999")
|
||||
if err != nil {
|
||||
glog.Error("[S] Listen() failed, err: %s", err)
|
||||
return
|
||||
}
|
||||
glog.Info("[S] Listen 0.0.0.0:9999")
|
||||
for {
|
||||
conn, err := listen.Accept() // 监听客户端的连接请求
|
||||
if err != nil {
|
||||
glog.Error("[S] Accept() failed, err: %s", err)
|
||||
continue
|
||||
}
|
||||
glog.Info("[S] Accept %s %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
|
||||
var Index = 0
|
||||
prot := New(conn, conn, 8, func(data []byte) {
|
||||
fmt.Printf("[S] received [%s]\n", string(data))
|
||||
Index++
|
||||
if fmt.Sprintf("client msg %d", Index) != string(data) {
|
||||
t.Errorf("test client error need %s got %s", fmt.Sprintf("client msg %d", Index), string(data))
|
||||
}
|
||||
}, func() bool {
|
||||
fmt.Println("[S] heartbeat timeout")
|
||||
t.Error("heartbeat timeout")
|
||||
return false
|
||||
})
|
||||
prot.Connect(true)
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second)
|
||||
if prot.GetHeartbeatLastSend() == 0 {
|
||||
t.Error("GetHeartbeatLastSend is zero")
|
||||
}
|
||||
if prot.GetHeartbeatLastReceived() == 0 {
|
||||
t.Error("GetHeartbeatLastReceived is zero")
|
||||
}
|
||||
prot.Kill()
|
||||
}()
|
||||
i := 0
|
||||
for {
|
||||
time.Sleep(5 * time.Second)
|
||||
i++
|
||||
msg := fmt.Sprintf("server msg %d", i)
|
||||
fmt.Printf("[S] send [%s]\n", msg)
|
||||
err := prot.Write([]byte(msg))
|
||||
if err != nil {
|
||||
glog.Warning("[S] failed to write %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testClient(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:9999")
|
||||
if err != nil {
|
||||
glog.Error("[C] Dial() failed, err: %s", err)
|
||||
return
|
||||
}
|
||||
glog.Info("[C] Connected")
|
||||
|
||||
var Index = 0
|
||||
prot := New(conn, conn, 8, func(data []byte) {
|
||||
fmt.Printf("[C] received [%s]\n", string(data))
|
||||
Index++
|
||||
if fmt.Sprintf("server msg %d", Index) != string(data) {
|
||||
t.Errorf("test client error need %s got %s", fmt.Sprintf("server msg %d", Index), string(data))
|
||||
}
|
||||
}, func() bool {
|
||||
fmt.Println("[C] heartbeat timeout")
|
||||
t.Error("heartbeat timeout")
|
||||
return false
|
||||
})
|
||||
prot.Connect(false)
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second)
|
||||
if prot.GetHeartbeatLastSend() == 0 {
|
||||
t.Error("GetHeartbeatLastSend is zero")
|
||||
}
|
||||
if prot.GetHeartbeatLastReceived() == 0 {
|
||||
t.Error("GetHeartbeatLastReceived is zero")
|
||||
}
|
||||
prot.Kill()
|
||||
}()
|
||||
time.Sleep(1 * time.Second)
|
||||
i := 0
|
||||
for {
|
||||
time.Sleep(5 * time.Second)
|
||||
i++
|
||||
msg := fmt.Sprintf("client msg %d", i)
|
||||
fmt.Printf("[C] send [%s]\n", msg)
|
||||
err = prot.Write([]byte(msg))
|
||||
if err != nil {
|
||||
glog.Warning("[C] failed to write %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
```
|
||||
样例请参照`protocol_test.go`中的示例,示例中包含了对tcp的封装
|
||||
|
||||
```go
|
||||
// 创建protocol封装
|
||||
New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func() bool, setReadDeadline, setWriteDeadline, killCallback func()) *Protocol
|
||||
// 启动传输,通过参数确定是否是心跳服务端(主动发出心跳信号一方)
|
||||
// 如果传输双方均为发出方,不会影响正常服务,但会产生不必要的心跳
|
||||
Connect(bool)
|
||||
|
@ -127,9 +31,21 @@ GetHeartbeatTimeout() uint8
|
|||
GetHeartbeatLastReceived()
|
||||
// 获取上一次发送心跳的时间
|
||||
GetHeartbeatLastSend()
|
||||
// 获取正在运行的协程的数量
|
||||
GetRunningRoutine() int32
|
||||
// 等待Protocol关闭
|
||||
WaitKilled(timeout int) error
|
||||
// 获取tag
|
||||
GetTag() string
|
||||
SetTag(tag string)
|
||||
// 关闭Protocol
|
||||
Kill()
|
||||
|
||||
// 设置log模式
|
||||
SetLogProd(isProd bool)
|
||||
SetLogMask(mask uint32)
|
||||
SetLogFlag(f uint32)
|
||||
|
||||
// Protocol版本号,不同版本存在不兼容的可能性
|
||||
protocol.VERSION
|
||||
// 获取每次Write的最大数据长度
|
||||
|
|
6
log.go
6
log.go
|
@ -1,8 +1,8 @@
|
|||
package protocol
|
||||
|
||||
import "git.viry.cc/gomod/glog"
|
||||
|
||||
const logPrefix = "[protocol]"
|
||||
import (
|
||||
"git.viry.cc/gomod/glog"
|
||||
)
|
||||
|
||||
const (
|
||||
MaskUNKNOWN = glog.MaskUNKNOWN
|
||||
|
|
34
package.go
34
package.go
|
@ -52,7 +52,9 @@ const (
|
|||
|
||||
// flag标志位
|
||||
const (
|
||||
// 普通心跳信号,心跳响应信号
|
||||
flagHeartbeat uint8 = 1 << iota
|
||||
// 心跳请求信号,接收方必须回复flagHeartbeat
|
||||
flagHeartbeatRequest
|
||||
)
|
||||
|
||||
|
@ -133,7 +135,7 @@ func (p *protocolPackage) headNeedCheckBytes() *bytes.Buffer {
|
|||
// 生成head的crc32
|
||||
func (p *protocolPackage) generateHeadCheck() {
|
||||
p.crc32 = util.NewCRC32().FromBytes(p.headNeedCheckBytes().Bytes()).Value()
|
||||
glog.Trace("%shead crc32 is %d", logPrefix, p.crc32)
|
||||
glog.Trace("[protocol_package] head crc32 is %d", p.crc32)
|
||||
}
|
||||
|
||||
// 校验head的crc32
|
||||
|
@ -144,13 +146,13 @@ func (p *protocolPackage) checkHead() bool {
|
|||
// 生成data的crc32
|
||||
func (p *protocolPackage) generateDataCheck() {
|
||||
p.dataCrc32 = util.NewCRC32().FromBytes(p.data).Value()
|
||||
glog.Trace("data crc32 is %d", p.dataCrc32)
|
||||
glog.Trace("[protocol_package] data crc32 is %d", p.dataCrc32)
|
||||
}
|
||||
|
||||
// 校验data的crc32
|
||||
func (p *protocolPackage) checkData() bool {
|
||||
if int(p.dataSize) != len(p.data) {
|
||||
glog.Trace("pkg.dataSize != len(pkg.data)")
|
||||
glog.Trace("[protocol_package] pkg.dataSize != len(pkg.data)")
|
||||
return false
|
||||
}
|
||||
return p.dataCrc32 == util.NewCRC32().FromBytes(p.data).Value()
|
||||
|
@ -159,20 +161,20 @@ func (p *protocolPackage) checkData() bool {
|
|||
// encrypt 加密data
|
||||
func (p *protocolPackage) encrypt(method uint8) {
|
||||
if p.encryptMethod == method {
|
||||
glog.Trace("is already encrypted [%d]", method)
|
||||
glog.Trace("[protocol_package] is already encrypted [%d]", method)
|
||||
return // 已经加密
|
||||
}
|
||||
if p.encryptMethod != encryptNone {
|
||||
glog.Trace("encrypt with other method got [%d] need encryptNone[%d]", p.encryptMethod, encryptNone)
|
||||
glog.Trace("[protocol_package] encrypt with other method got [%d] need encryptNone[%d]", p.encryptMethod, encryptNone)
|
||||
return // 已经通过其他方式加密
|
||||
}
|
||||
glog.Warning("unknown encrypt method")
|
||||
glog.Warning("[protocol_package] unknown encrypt method")
|
||||
}
|
||||
|
||||
// decrypt 解密data
|
||||
func (p *protocolPackage) decrypt() {
|
||||
if !p.isEncrypted() {
|
||||
glog.Trace("is not encrypted")
|
||||
glog.Trace("[protocol_package] is not encrypted")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -191,24 +193,24 @@ func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
|
|||
// 判断package的版本
|
||||
// 暂时只处理VERSION版本的package
|
||||
if buf.Len() < headOffsetVersion+headLengthVersion {
|
||||
glog.Trace("incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
|
||||
glog.Trace("[protocol_package] incomplete version information, need %d got %d", headOffsetVersion+headLengthVersion, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
if buf.Bytes()[headOffsetVersion] != VERSION {
|
||||
glog.Trace("unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
|
||||
glog.Trace("[protocol_package] unsupported version need %d got %d", VERSION, buf.Bytes()[headOffsetVersion])
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorUnsupportedVersion
|
||||
}
|
||||
// 开始判断是否为package并提取package
|
||||
if buf.Len() < packageHeadSize {
|
||||
glog.Trace("incomplete head, need %d got %d", packageHeadSize, buf.Len())
|
||||
glog.Trace("[protocol_package] incomplete head, need %d got %d", packageHeadSize, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
head := make([]byte, packageHeadSize)
|
||||
copy(head, buf.Bytes()[:packageHeadSize])
|
||||
// 协议头标志不匹配,删除未知数据,寻找下一个package起始位置
|
||||
if !bytes.Equal(prefix[:], head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix]) {
|
||||
glog.Trace("prefix does not match, need %v got %v", prefix, head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
|
||||
glog.Trace("[protocol_package] prefix does not match, need %v got %v", prefix, head[headOffsetPrefix:headOffsetPrefix+headLengthPrefix])
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorWrongPrefix
|
||||
}
|
||||
|
@ -216,14 +218,14 @@ func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
|
|||
headChecksum := util.BytesSliceToUInt32(head[headOffsetCRC32Checksum : headOffsetCRC32Checksum+headLengthCRC32Checksum])
|
||||
headCrc32 := util.NewCRC32().FromBytes(head[headOffsetNeedCheck:]).Value()
|
||||
if headChecksum != headCrc32 {
|
||||
glog.Trace("head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
|
||||
glog.Trace("[protocol_package] head crc32 checksum does not match, need %d got %d", headChecksum, headCrc32)
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorBrokenHead
|
||||
}
|
||||
// 检查package是否完整,不完整则等待
|
||||
packageDataSize := util.BytesSliceToUInt32(head[headOffsetDataSize : headOffsetDataSize+headLengthDataSize])
|
||||
if packageHeadSize+int(packageDataSize) > buf.Len() {
|
||||
glog.Trace("incomplete data, need %d got %d", packageHeadSize+packageDataSize, buf.Len())
|
||||
glog.Trace("[protocol_package] incomplete data, need %d got %d", packageHeadSize+packageDataSize, buf.Len())
|
||||
return nil, ErrorPackageIncomplete
|
||||
}
|
||||
// package完整
|
||||
|
@ -250,7 +252,7 @@ func parsePackage(buf *bytes.Buffer) (*protocolPackage, error) {
|
|||
_, _ = buf.Read(pkg.data)
|
||||
dataCrc32 := util.NewCRC32().FromBytes(pkg.data).Value()
|
||||
if pkg.dataCrc32 != dataCrc32 {
|
||||
glog.Trace("data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
|
||||
glog.Trace("[protocol_package] data crc32 checksum does not match, need %d got %d", pkg.dataCrc32, dataCrc32)
|
||||
nextPackageHead(buf)
|
||||
return nil, ErrorBrokenData
|
||||
}
|
||||
|
@ -264,9 +266,9 @@ func nextPackageHead(buf *bytes.Buffer) {
|
|||
_, err = buf.ReadBytes(prefix[0]) // 只搜索与prefix[0]相同元素,防止prefix[0]出现在buf末尾
|
||||
if err == nil { // 找到下一个协议头标志,把删掉的prefix回退到buffer
|
||||
_ = buf.UnreadByte()
|
||||
glog.Trace("prefix does not match, prefix[0] found, trim buf, buf length [%d]", buf.Len())
|
||||
glog.Trace("[protocol_package] prefix does not match, prefix[0] found, trim buf, buf length [%d]", buf.Len())
|
||||
} else { // 找不到下一个协议头标志,清空buffer
|
||||
buf.Reset()
|
||||
glog.Trace("prefix does not match, prefix[0] not found, reset buf, buf length [%d]", buf.Len())
|
||||
glog.Trace("[protocol_package] prefix does not match, prefix[0] not found, reset buf, buf length [%d]", buf.Len())
|
||||
}
|
||||
}
|
||||
|
|
266
protocol.go
266
protocol.go
|
@ -21,6 +21,7 @@ var ErrorWriterQueueIsNil = errors.New("writer queue is nil")
|
|||
var ErrorHeartbeatIsKilled = errors.New("heartbeat is killed")
|
||||
var ErrorHeartbeatCallbackIsNil = errors.New("heartbeat callback is nil")
|
||||
var ErrorDataSizeExceedsLimit = errors.New("data size exceeds limit")
|
||||
var ErrorTimeout = errors.New("timeout")
|
||||
|
||||
const (
|
||||
statusRunning int32 = iota
|
||||
|
@ -28,49 +29,87 @@ const (
|
|||
)
|
||||
|
||||
type Protocol struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
// 标记protocol
|
||||
tag string
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
|
||||
status int32
|
||||
// protocol的状态
|
||||
status int32
|
||||
// 用于处理获取到的数据,每个package中的数据都会完整的保存在data中
|
||||
readCallback func(data []byte)
|
||||
writeQueue *queue
|
||||
// 写入等待队列
|
||||
writeQueue *queue
|
||||
// 当前protocol正在运行的协程数量
|
||||
runningRoutines int32
|
||||
|
||||
heartbeatSig chan uint8
|
||||
heartbeatSigReq chan uint8
|
||||
heartbeatInterval uint32
|
||||
heartbeatTimeout uint32
|
||||
// 心跳信号,同时也是心跳响应信号
|
||||
heartbeatSig chan uint8
|
||||
// 心跳请求信号,收到此信号必须回复对方
|
||||
heartbeatSigReq chan uint8
|
||||
// 发送心跳请求的间隔
|
||||
heartbeatInterval uint32
|
||||
// 接收心跳请求的超时时间
|
||||
heartbeatTimeout uint32
|
||||
// 心跳请求超时后的处理函数
|
||||
heartbeatTimeoutCallback func() bool
|
||||
heartbeatLastSend int64
|
||||
heartbeatLastReceived int64
|
||||
// 上次发送心跳的时间
|
||||
heartbeatLastSend int64
|
||||
// 上次收到心跳的时间
|
||||
heartbeatLastReceived int64
|
||||
|
||||
// status被标记为statusKilled时执行,可以用于关闭reader和writer
|
||||
killCallback func()
|
||||
// 在reader读取数据前,设置reader的读取截止时间
|
||||
setReadDeadline func()
|
||||
// 在writer读取数据前,设置writer的读取截止时间
|
||||
setWriteDeadline func()
|
||||
}
|
||||
|
||||
func New(r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func() bool) *Protocol {
|
||||
// New 返回一个protocol实例
|
||||
//
|
||||
// tag: 标签,用于区分protocol实例
|
||||
// r: 数据流的reader
|
||||
// w: 数据流的writer
|
||||
// writeQueueSize: 发送等待队列长度
|
||||
// readCallback: 用于处理获取到的数据,每个package中的数据都会完整的保存在data中
|
||||
// heartbeatTimeoutCallback: 心跳请求超时后的处理函数
|
||||
// setReadDeadline: 在reader读取数据前,设置reader的读取截止时间
|
||||
// setWriteDeadline: 在writer读取数据前,设置writer的读取截止时间
|
||||
// killCallback: status被标记为statusKilled时执行,可以用于关闭reader和writer
|
||||
func New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func() bool, setReadDeadline, setWriteDeadline, killCallback func()) *Protocol {
|
||||
if r == nil {
|
||||
glog.Warning("%s reader is nil", logPrefix)
|
||||
glog.Warning("[protocol.%s] reader is nil", tag)
|
||||
return nil
|
||||
}
|
||||
if w == nil {
|
||||
glog.Warning("%s writer is nil", logPrefix)
|
||||
glog.Warning("[protocol.%s] writer is nil", tag)
|
||||
return nil
|
||||
}
|
||||
if writeQueueSize < 1 {
|
||||
glog.Trace("%s writeQueueSize is < 1, use 1", logPrefix)
|
||||
glog.Trace("[protocol.%s] writeQueueSize is < 1, use 1", tag)
|
||||
writeQueueSize = 1
|
||||
}
|
||||
if readCallback == nil {
|
||||
glog.Trace("%s readCallback is nil, use defaultReadCallback", logPrefix)
|
||||
glog.Trace("[protocol.%s] readCallback is nil, use defaultReadCallback", tag)
|
||||
readCallback = defaultReadCallback
|
||||
}
|
||||
if heartbeatTimeoutCallback == nil {
|
||||
glog.Trace("%s heartbeatTimeoutCallback is nil, use defaultHeartbeatTimeoutCallback", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeatTimeoutCallback is nil, use defaultHeartbeatTimeoutCallback", tag)
|
||||
heartbeatTimeoutCallback = defaultHeartbeatTimeoutCallback
|
||||
}
|
||||
if killCallback == nil {
|
||||
glog.Trace("[protocol.%s] killCallback is nil, use defaultKillCallback", tag)
|
||||
killCallback = defaultKillCallback
|
||||
}
|
||||
return &Protocol{
|
||||
tag: tag,
|
||||
r: r,
|
||||
w: w,
|
||||
status: statusRunning,
|
||||
readCallback: readCallback,
|
||||
writeQueue: newQueue(writeQueueSize),
|
||||
runningRoutines: 0,
|
||||
heartbeatSig: make(chan uint8, 1),
|
||||
heartbeatSigReq: make(chan uint8, 1),
|
||||
heartbeatInterval: 15,
|
||||
|
@ -78,6 +117,9 @@ func New(r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []
|
|||
heartbeatTimeoutCallback: heartbeatTimeoutCallback,
|
||||
heartbeatLastSend: 0,
|
||||
heartbeatLastReceived: 0,
|
||||
killCallback: killCallback,
|
||||
setReadDeadline: setReadDeadline,
|
||||
setWriteDeadline: setWriteDeadline,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,46 +133,48 @@ func (p *Protocol) Connect(activeHeartbeatSignalSender bool) {
|
|||
}
|
||||
|
||||
func (p *Protocol) handlePackage(pkg *protocolPackage) {
|
||||
glog.Trace("%s handle package", logPrefix)
|
||||
glog.Trace("[protocol.%s] handle package", p.tag)
|
||||
if pkg == nil {
|
||||
glog.Trace("%s package is nil", logPrefix)
|
||||
glog.Trace("[protocol.%s] package is nil", p.tag)
|
||||
return
|
||||
}
|
||||
if pkg.isEncrypted() {
|
||||
glog.Trace("%s package is encrypted, decrypt package", logPrefix)
|
||||
glog.Trace("[protocol.%s] package is encrypted, decrypt package", p.tag)
|
||||
pkg.decrypt()
|
||||
}
|
||||
if !pkg.checkHead() {
|
||||
glog.Trace("%s package head broken", logPrefix)
|
||||
glog.Trace("[protocol.%s] package head broken", p.tag)
|
||||
return
|
||||
}
|
||||
if (pkg.flag & flagHeartbeat) != 0 {
|
||||
glog.Trace("%s heartbeat signal in package", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat signal in package", p.tag)
|
||||
p.heartbeatSig <- pkg.value
|
||||
}
|
||||
if (pkg.flag & flagHeartbeatRequest) != 0 {
|
||||
glog.Trace("%s heartbeat request signal in package", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat request signal in package", p.tag)
|
||||
p.heartbeatSigReq <- pkg.value
|
||||
}
|
||||
if !pkg.checkData() {
|
||||
glog.Trace("%s package data broken", logPrefix)
|
||||
glog.Trace("[protocol.%s] package data broken", p.tag)
|
||||
return
|
||||
}
|
||||
if pkg.dataSize == 0 {
|
||||
glog.Trace("%s package data empty", logPrefix)
|
||||
glog.Trace("[protocol.%s] package data empty", p.tag)
|
||||
return
|
||||
}
|
||||
glog.Trace("%s handle package successful", logPrefix)
|
||||
glog.Trace("[protocol.%s] handle package successful", p.tag)
|
||||
p.readCallback(pkg.data)
|
||||
}
|
||||
|
||||
// Reader 阻塞接收数据并提交给readCallback
|
||||
func (p *Protocol) reader() {
|
||||
glog.Trace("%s reader enable", logPrefix)
|
||||
glog.Trace("[protocol.%s] reader enable", p.tag)
|
||||
if p.r == nil {
|
||||
glog.Warning("%s reader is not ready", logPrefix)
|
||||
glog.Warning("[protocol.%s] reader is not ready", p.tag)
|
||||
return
|
||||
}
|
||||
p.incRunningRoutine()
|
||||
defer p.decRunningRoutine()
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
buf := make([]byte, packageMaxSize)
|
||||
|
@ -139,31 +183,36 @@ func (p *Protocol) reader() {
|
|||
// 监听并接收数据
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Trace("%s reader is killed", logPrefix)
|
||||
glog.Trace("[protocol.%s] reader is killed", p.tag)
|
||||
return
|
||||
}
|
||||
if p.setReadDeadline != nil {
|
||||
glog.Trace("[protocol.%s] reader set deadline", p.tag)
|
||||
p.setReadDeadline()
|
||||
}
|
||||
glog.Trace("[protocol.%s] reader wait read", p.tag)
|
||||
n, err = p.r.Read(buf)
|
||||
if err != nil {
|
||||
glog.Trace("%s read error %v", logPrefix, err)
|
||||
glog.Trace("[protocol.%s] read error %v", p.tag, err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if n == 0 {
|
||||
glog.Trace("%s read empty", logPrefix)
|
||||
glog.Trace("[protocol.%s] read empty", p.tag)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
n, err = buffer.Write(buf[:n])
|
||||
glog.Trace("%s write %d bytes, buffer already %d bytes, error is %v", logPrefix, n, buffer.Len(), err)
|
||||
glog.Trace("[protocol.%s] write %d bytes, buffer already %d bytes, error is %v", p.tag, n, buffer.Len(), err)
|
||||
for buffer.Len() >= packageHeadSize {
|
||||
glog.Trace("%s complete package, buffer length %d", logPrefix, buffer.Len())
|
||||
glog.Trace("[protocol.%s] complete package, buffer length %d", p.tag, buffer.Len())
|
||||
pkg, err := parsePackage(buffer)
|
||||
if errors.Is(err, ErrorPackageIncomplete) {
|
||||
glog.Trace("%s incomplete package, buffer length %d", logPrefix, buffer.Len())
|
||||
glog.Trace("[protocol.%s] incomplete package, buffer length %d", p.tag, buffer.Len())
|
||||
break
|
||||
}
|
||||
if pkg != nil {
|
||||
glog.Trace("%s receive new package", logPrefix)
|
||||
glog.Trace("[protocol.%s] receive new package", p.tag)
|
||||
go p.handlePackage(pkg)
|
||||
}
|
||||
}
|
||||
|
@ -172,50 +221,64 @@ func (p *Protocol) reader() {
|
|||
|
||||
// Writer 创建发送队列并监听待发送数据
|
||||
func (p *Protocol) writer() {
|
||||
glog.Trace("%s writer enable", logPrefix)
|
||||
glog.Trace("[protocol.%s] writer enable", p.tag)
|
||||
if p.w == nil {
|
||||
glog.Warning("%s writer is not ready", logPrefix)
|
||||
glog.Warning("[protocol.%s] writer is not ready", p.tag)
|
||||
return
|
||||
}
|
||||
p.incRunningRoutine()
|
||||
defer p.decRunningRoutine()
|
||||
|
||||
var err error
|
||||
var n int
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Trace("%s writer is killed", logPrefix)
|
||||
glog.Trace("[protocol.%s] writer is killed", p.tag)
|
||||
return
|
||||
}
|
||||
pkg := p.writeQueue.pop()
|
||||
glog.Trace("[protocol.%s] writer wait pop", p.tag)
|
||||
pkg := p.writeQueue.pop(int(p.GetHeartbeatInterval()))
|
||||
if pkg == nil {
|
||||
glog.Trace("[protocol.%s] writer pop timeout", p.tag)
|
||||
continue
|
||||
}
|
||||
if p.setWriteDeadline != nil {
|
||||
glog.Trace("[protocol.%s] writer set deadline", p.tag)
|
||||
p.setWriteDeadline()
|
||||
}
|
||||
glog.Trace("[protocol.%s] writer wait write", p.tag)
|
||||
n, err = p.w.Write(pkg.Bytes().Bytes())
|
||||
glog.Trace("%s write %d bytes, error is %v", logPrefix, n, err)
|
||||
glog.Trace("[protocol.%s] write %d bytes, error is %v", p.tag, n, err)
|
||||
if err != nil {
|
||||
glog.Trace("%s send package failed, re-push package", logPrefix)
|
||||
go p.writeQueue.push(pkg)
|
||||
glog.Trace("[protocol.%s] send package failed, re-push package", p.tag)
|
||||
time.Sleep(time.Second)
|
||||
for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Trace("[protocol.%s] writer is killed", p.tag)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write 发送数据
|
||||
func (p *Protocol) Write(data []byte) error {
|
||||
glog.Trace("%s write", logPrefix)
|
||||
glog.Trace("[protocol.%s] write", p.tag)
|
||||
if len(data) > dataMaxSize {
|
||||
glog.Warning("%s maximum supported data size exceeded", logPrefix)
|
||||
glog.Info("[protocol.%s] maximum supported data size exceeded", p.tag)
|
||||
return ErrorDataSizeExceedsLimit
|
||||
}
|
||||
if p.w == nil {
|
||||
glog.Warning("%s writer is not ready", logPrefix)
|
||||
return ErrorWriterIsNil
|
||||
}
|
||||
if p.writeQueue == nil {
|
||||
glog.Warning("%s protocol is not ready", logPrefix)
|
||||
return ErrorWriterQueueIsNil
|
||||
}
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Warning("%s protocol is killed", logPrefix)
|
||||
return ErrorWriterIsKilled
|
||||
}
|
||||
pkg := newPackage(0, encryptNone, 0, data)
|
||||
p.writeQueue.push(pkg)
|
||||
return nil
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Info("[protocol.%s] protocol is killed", p.tag)
|
||||
return ErrorWriterIsKilled
|
||||
}
|
||||
if p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat 心跳服务
|
||||
|
@ -223,32 +286,39 @@ func (p *Protocol) Write(data []byte) error {
|
|||
// heartbeatTimeout: 被动接收心跳信号的超时时间(s),最小为3s,传入参数小于3时使用默认值30
|
||||
// heartbeatTimeoutCallback: 没有按时收到心跳信号时调用,返回true继续等待,返回false退出
|
||||
func (p *Protocol) heartbeat() {
|
||||
glog.Trace("%s heartbeat enable", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat enable", p.tag)
|
||||
p.incRunningRoutine()
|
||||
defer p.decRunningRoutine()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Duration(p.GetHeartbeatTimeout()) * time.Second):
|
||||
glog.Trace("%s heartbeat timeout", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat timeout", p.tag)
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Trace("[protocol.%s] heartbeat is killed", p.tag)
|
||||
return
|
||||
}
|
||||
if !p.heartbeatTimeoutCallback() {
|
||||
glog.Trace("%s heartbeat is killed, set status killed", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat is killed, set status killed", p.tag)
|
||||
p.setStatus(statusKilled)
|
||||
return
|
||||
}
|
||||
case val := <-p.heartbeatSigReq:
|
||||
glog.Trace("%s heartbeat request signal received", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat request signal received", p.tag)
|
||||
p.setHeartbeatLastReceived()
|
||||
p.sendHeartbeatSignal(false)
|
||||
if val != 0 {
|
||||
p.SetHeartbeatTimeout(val)
|
||||
}
|
||||
case val := <-p.heartbeatSig:
|
||||
glog.Trace("%s heartbeat signal received", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat signal received", p.tag)
|
||||
p.setHeartbeatLastReceived()
|
||||
if val != 0 {
|
||||
p.SetHeartbeatTimeout(val)
|
||||
}
|
||||
}
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Trace("%s heartbeat is killed", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat is killed", p.tag)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -258,9 +328,11 @@ func (p *Protocol) heartbeat() {
|
|||
//
|
||||
// heartbeatInterval: 主动发送心跳信号的间隔时间(s),最小为3s,传入参数小于3时使用默认值3
|
||||
func (p *Protocol) heartbeatSignalSender() {
|
||||
p.incRunningRoutine()
|
||||
defer p.decRunningRoutine()
|
||||
for {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Trace("%s heartbeat signal sender is killed", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeat signal sender is killed", p.tag)
|
||||
return
|
||||
}
|
||||
p.sendHeartbeatSignal(true)
|
||||
|
@ -269,28 +341,38 @@ func (p *Protocol) heartbeatSignalSender() {
|
|||
}
|
||||
|
||||
func (p *Protocol) sendHeartbeatSignal(isReq bool) {
|
||||
glog.Trace("%s send heartbeat signal", logPrefix)
|
||||
glog.Trace("[protocol.%s] send heartbeat signal", p.tag)
|
||||
var pkg *protocolPackage
|
||||
if isReq {
|
||||
p.writeQueue.push(newPackage(flagHeartbeatRequest, encryptNone, 0, nil))
|
||||
pkg = newPackage(flagHeartbeatRequest, encryptNone, 0, nil)
|
||||
} else {
|
||||
p.writeQueue.push(newPackage(flagHeartbeat, encryptNone, 0, nil))
|
||||
pkg = newPackage(flagHeartbeat, encryptNone, 0, nil)
|
||||
}
|
||||
for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) {
|
||||
if p.getStatus() == statusKilled {
|
||||
glog.Info("[protocol.%s] protocol is killed", p.tag)
|
||||
return
|
||||
}
|
||||
}
|
||||
p.setHeartbeatLastSend()
|
||||
}
|
||||
|
||||
func (p *Protocol) setStatus(status int32) {
|
||||
glog.Trace("%s set status %d", logPrefix, status)
|
||||
glog.Trace("[protocol.%s] set status %d", p.tag, status)
|
||||
if status == statusKilled {
|
||||
p.killCallback()
|
||||
}
|
||||
atomic.StoreInt32(&p.status, status)
|
||||
}
|
||||
|
||||
func (p *Protocol) getStatus() int32 {
|
||||
glog.Trace("%s get status", logPrefix)
|
||||
glog.Trace("[protocol.%s] get status", p.tag)
|
||||
return atomic.LoadInt32(&p.status)
|
||||
}
|
||||
|
||||
func (p *Protocol) SetHeartbeatInterval(interval uint8) {
|
||||
if interval < 3 {
|
||||
glog.Trace("%s heartbeatInterval is < 3, use 3", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeatInterval is < 3, use 3", p.tag)
|
||||
interval = 3
|
||||
}
|
||||
atomic.StoreUint32(&p.heartbeatInterval, uint32(interval))
|
||||
|
@ -302,7 +384,7 @@ func (p *Protocol) GetHeartbeatInterval() uint8 {
|
|||
|
||||
func (p *Protocol) SetHeartbeatTimeout(timeout uint8) {
|
||||
if timeout < 6 {
|
||||
glog.Trace("%s heartbeatTimeout is < 6, use 6", logPrefix)
|
||||
glog.Trace("[protocol.%s] heartbeatTimeout is < 6, use 6", p.tag)
|
||||
timeout = 6
|
||||
}
|
||||
atomic.StoreUint32(&p.heartbeatTimeout, uint32(timeout))
|
||||
|
@ -328,19 +410,57 @@ func (p *Protocol) GetHeartbeatLastSend() int64 {
|
|||
return atomic.LoadInt64(&p.heartbeatLastSend)
|
||||
}
|
||||
|
||||
func (p *Protocol) incRunningRoutine() {
|
||||
atomic.AddInt32(&p.runningRoutines, 1)
|
||||
}
|
||||
|
||||
func (p *Protocol) decRunningRoutine() {
|
||||
atomic.AddInt32(&p.runningRoutines, -1)
|
||||
}
|
||||
|
||||
func (p *Protocol) GetRunningRoutine() int32 {
|
||||
return atomic.LoadInt32(&p.runningRoutines)
|
||||
}
|
||||
|
||||
func (p *Protocol) WaitKilled(timeout int) error {
|
||||
out := time.After(time.Duration(timeout) * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-out:
|
||||
return ErrorTimeout
|
||||
case <-time.After(time.Second):
|
||||
if p.GetRunningRoutine() <= 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Protocol) GetTag() string {
|
||||
return p.tag
|
||||
}
|
||||
|
||||
func (p *Protocol) SetTag(tag string) {
|
||||
p.tag = tag
|
||||
}
|
||||
|
||||
func (p *Protocol) Kill() {
|
||||
p.setStatus(statusKilled)
|
||||
}
|
||||
|
||||
func defaultReadCallback(data []byte) {
|
||||
glog.Trace("%s default read callback %x", logPrefix, data)
|
||||
glog.Trace("[protocol] default read callback %x", data)
|
||||
}
|
||||
|
||||
func defaultHeartbeatTimeoutCallback() bool {
|
||||
glog.Trace("%s default heartbeat timeout callback", logPrefix)
|
||||
glog.Trace("[protocol] default heartbeat timeout callback")
|
||||
return true
|
||||
}
|
||||
|
||||
func defaultKillCallback() {
|
||||
glog.Trace("[protocol] default kill callback")
|
||||
}
|
||||
|
||||
func GetDataMaxSize() int {
|
||||
return dataMaxSize
|
||||
}
|
||||
|
|
130
protocol_test.go
130
protocol_test.go
|
@ -1,6 +1,7 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.viry.cc/gomod/glog"
|
||||
"net"
|
||||
|
@ -9,63 +10,95 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var protServer *Protocol
|
||||
var protClient *Protocol
|
||||
|
||||
func TestProtocol(t *testing.T) {
|
||||
// SetLogProd(false)
|
||||
go testServer(t)
|
||||
time.Sleep(time.Second)
|
||||
testClient(t)
|
||||
time.Sleep(15 * time.Second)
|
||||
go testClient(t)
|
||||
time.Sleep(30 * time.Second)
|
||||
|
||||
if protServer.GetHeartbeatLastSend() == 0 {
|
||||
t.Error("server.GetHeartbeatLastSend is zero")
|
||||
}
|
||||
if protServer.GetHeartbeatLastReceived() == 0 {
|
||||
t.Error("server.GetHeartbeatLastReceived is zero")
|
||||
}
|
||||
|
||||
if protClient.GetHeartbeatLastSend() == 0 {
|
||||
t.Error("client.GetHeartbeatLastSend is zero")
|
||||
}
|
||||
if protClient.GetHeartbeatLastReceived() == 0 {
|
||||
t.Error("client.GetHeartbeatLastReceived is zero")
|
||||
}
|
||||
|
||||
glog.Info("kill client")
|
||||
protClient.Kill()
|
||||
glog.Info("wait client killed")
|
||||
err := protClient.WaitKilled(60)
|
||||
if err != nil {
|
||||
t.Errorf("kill client failed [%d]", protClient.GetRunningRoutine())
|
||||
}
|
||||
glog.Info("wait client killed [%d]", protClient.GetRunningRoutine())
|
||||
glog.Info("wait server killed")
|
||||
err = protServer.WaitKilled(60)
|
||||
if err != nil {
|
||||
t.Errorf("server killed failed [%d]", protServer.GetRunningRoutine())
|
||||
}
|
||||
glog.Info("wait server killed [%d]", protServer.GetRunningRoutine())
|
||||
}
|
||||
|
||||
func testServer(t *testing.T) {
|
||||
listen, err := net.Listen("tcp", "0.0.0.0:9999")
|
||||
if err != nil {
|
||||
glog.Error("[S] Listen() failed, err: %s", err)
|
||||
glog.Error("[server] Listen() failed, err: %s", err)
|
||||
return
|
||||
}
|
||||
glog.Info("[S] Listen 0.0.0.0:9999")
|
||||
glog.Info("[server] Listen 0.0.0.0:9999")
|
||||
for {
|
||||
conn, err := listen.Accept() // 监听客户端的连接请求
|
||||
if err != nil {
|
||||
glog.Error("[S] Accept() failed, err: %s", err)
|
||||
glog.Error("[server] Accept() failed, err: %s", err)
|
||||
continue
|
||||
}
|
||||
glog.Info("[S] Accept %s %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
|
||||
glog.Info("[server] Accept %s %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
|
||||
var Index uint32 = 0
|
||||
prot := New(conn, conn, 8, func(data []byte) {
|
||||
fmt.Printf("[S] received [%s]\n", string(data))
|
||||
protServer = New("server", conn, conn, 8, func(data []byte) {
|
||||
// 处理获取到的数据
|
||||
fmt.Printf("[server] received [%s]\n", string(data))
|
||||
atomic.AddUint32(&Index, 1)
|
||||
ans := fmt.Sprintf("client msg %d", atomic.LoadUint32(&Index))
|
||||
if ans != string(data) {
|
||||
t.Errorf("test client error need %s got %s", ans, string(data))
|
||||
}
|
||||
}, func() bool {
|
||||
fmt.Println("[S] heartbeat timeout")
|
||||
t.Error("heartbeat timeout")
|
||||
// protocol还在运行,但心跳超时
|
||||
fmt.Println("[server] heartbeat timeout")
|
||||
return false
|
||||
}, func() {
|
||||
// 每次conn.Read前运行
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
}, func() {
|
||||
// 每次conn.Write前运行
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
}, func() {
|
||||
// protocol状态更改为killed时运行
|
||||
conn.Close()
|
||||
})
|
||||
prot.SetHeartbeatInterval(3)
|
||||
prot.SetHeartbeatTimeout(10)
|
||||
prot.Connect(true)
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second)
|
||||
if prot.GetHeartbeatLastSend() == 0 {
|
||||
t.Error("GetHeartbeatLastSend is zero")
|
||||
}
|
||||
if prot.GetHeartbeatLastReceived() == 0 {
|
||||
t.Error("GetHeartbeatLastReceived is zero")
|
||||
}
|
||||
prot.Kill()
|
||||
}()
|
||||
protServer.SetHeartbeatInterval(3)
|
||||
protServer.SetHeartbeatTimeout(10)
|
||||
protServer.Connect(true)
|
||||
i := 0
|
||||
for {
|
||||
time.Sleep(5 * time.Second)
|
||||
i++
|
||||
msg := fmt.Sprintf("server msg %d", i)
|
||||
fmt.Printf("[S] send [%s]\n", msg)
|
||||
err := prot.Write([]byte(msg))
|
||||
if err != nil {
|
||||
glog.Warning("[S] failed to write %v", err)
|
||||
fmt.Printf("[server] send [%s]\n", msg)
|
||||
err = protServer.Write([]byte(msg))
|
||||
if err != nil && !errors.Is(err, ErrorWriterIsKilled) {
|
||||
glog.Warning("[server] failed to write %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -75,47 +108,42 @@ func testServer(t *testing.T) {
|
|||
func testClient(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:9999")
|
||||
if err != nil {
|
||||
glog.Error("[C] Dial() failed, err: %s", err)
|
||||
glog.Error("[client] Dial() failed, err: %s", err)
|
||||
return
|
||||
}
|
||||
glog.Info("[C] Connected")
|
||||
glog.Info("[client] Connected")
|
||||
|
||||
var Index uint32 = 0
|
||||
prot := New(conn, conn, 8, func(data []byte) {
|
||||
fmt.Printf("[C] received [%s]\n", string(data))
|
||||
protClient = New("client", conn, conn, 8, func(data []byte) {
|
||||
fmt.Printf("[client] received [%s]\n", string(data))
|
||||
atomic.AddUint32(&Index, 1)
|
||||
ans := fmt.Sprintf("server msg %d", atomic.LoadUint32(&Index))
|
||||
if ans != string(data) {
|
||||
t.Errorf("test client error need %s got %s", ans, string(data))
|
||||
}
|
||||
}, func() bool {
|
||||
fmt.Println("[C] heartbeat timeout")
|
||||
t.Error("heartbeat timeout")
|
||||
return false
|
||||
fmt.Println("[client] heartbeat timeout")
|
||||
return true
|
||||
}, func() {
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
}, func() {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
}, func() {
|
||||
conn.Close()
|
||||
})
|
||||
prot.SetHeartbeatInterval(3)
|
||||
prot.SetHeartbeatTimeout(10)
|
||||
prot.Connect(false)
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second)
|
||||
if prot.GetHeartbeatLastSend() == 0 {
|
||||
t.Error("GetHeartbeatLastSend is zero")
|
||||
}
|
||||
if prot.GetHeartbeatLastReceived() == 0 {
|
||||
t.Error("GetHeartbeatLastReceived is zero")
|
||||
}
|
||||
prot.Kill()
|
||||
}()
|
||||
protClient.SetHeartbeatInterval(3)
|
||||
protClient.SetHeartbeatTimeout(10)
|
||||
protClient.Connect(false)
|
||||
time.Sleep(1 * time.Second)
|
||||
i := 0
|
||||
for {
|
||||
time.Sleep(5 * time.Second)
|
||||
i++
|
||||
msg := fmt.Sprintf("client msg %d", i)
|
||||
fmt.Printf("[C] send [%s]\n", msg)
|
||||
err = prot.Write([]byte(msg))
|
||||
if err != nil {
|
||||
glog.Warning("[C] failed to write %v", err)
|
||||
fmt.Printf("[client] send [%s]\n", msg)
|
||||
err = protClient.Write([]byte(msg))
|
||||
if err != nil && !errors.Is(err, ErrorWriterIsKilled) {
|
||||
glog.Warning("[client] failed to write %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
41
queue.go
41
queue.go
|
@ -3,6 +3,7 @@ package protocol
|
|||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type queue struct {
|
||||
|
@ -43,17 +44,28 @@ func newQueue(size int) *queue {
|
|||
}
|
||||
}
|
||||
|
||||
func (q *queue) push(item *protocolPackage) {
|
||||
func (q *queue) push(item *protocolPackage, timeout int) (res bool) {
|
||||
q.pushLock.Lock()
|
||||
defer func() {
|
||||
// push成功后队列大小+1
|
||||
atomic.AddInt32(&q.curSize, 1)
|
||||
q.pushLock.Unlock()
|
||||
// 操作必定成功,向End信号池发送一个信号,表示完成此次push
|
||||
q.poolEnd <- true
|
||||
if res {
|
||||
// 向End信号池发送一个信号,表示完成此次push
|
||||
q.poolEnd <- true
|
||||
}
|
||||
}()
|
||||
// 操作成功代表队列不满,向Start信号池发送一个信号,表示开始push
|
||||
q.poolStart <- true
|
||||
if timeout > 0 {
|
||||
select {
|
||||
case q.poolStart <- true:
|
||||
case <-time.After(time.Duration(timeout) * time.Second):
|
||||
res = false
|
||||
return
|
||||
}
|
||||
} else {
|
||||
q.poolStart <- true
|
||||
}
|
||||
|
||||
q.queue[q.wIndex] = item
|
||||
|
||||
|
@ -61,19 +73,32 @@ func (q *queue) push(item *protocolPackage) {
|
|||
if q.wIndex >= q.maxSize {
|
||||
q.wIndex = 0
|
||||
}
|
||||
res = true
|
||||
return
|
||||
}
|
||||
|
||||
func (q *queue) pop() (item *protocolPackage) {
|
||||
func (q *queue) pop(timeout int) (item *protocolPackage) {
|
||||
q.popLock.Lock()
|
||||
defer func() {
|
||||
// pop成功后队列大小-1
|
||||
atomic.AddInt32(&q.curSize, -1)
|
||||
q.popLock.Unlock()
|
||||
// 操作必定成功,当前元素已经成功取出,释放当前位置
|
||||
<-q.poolStart
|
||||
if item != nil {
|
||||
// 当前元素已经成功取出,释放当前位置
|
||||
<-q.poolStart
|
||||
}
|
||||
}()
|
||||
// 操作成功代表队列非空,只有End信号池中有信号,才能保证有完整的元素在队列中
|
||||
<-q.poolEnd
|
||||
if timeout > 0 {
|
||||
select {
|
||||
case <-q.poolEnd:
|
||||
case <-time.After(time.Duration(timeout) * time.Second):
|
||||
item = nil
|
||||
return
|
||||
}
|
||||
} else {
|
||||
<-q.poolEnd
|
||||
}
|
||||
|
||||
item = q.queue[q.rIndex]
|
||||
|
||||
|
|
|
@ -15,24 +15,24 @@ func TestQueue(t *testing.T) {
|
|||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que.push(pkg1)
|
||||
que.push(pkg1, 0)
|
||||
if que.size() != 1 || que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
pkg11 := que.pop()
|
||||
pkg11 := que.pop(0)
|
||||
if que.size() != 0 || !que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
if pkg11.value != 1 || len(pkg11.data) != 1 || pkg11.data[0] != 1 {
|
||||
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
|
||||
}
|
||||
que.push(pkg2)
|
||||
que.push(pkg2, 0)
|
||||
if que.size() != 1 || que.isEmpty() {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
pkg21 := que.pop()
|
||||
pkg21 := que.pop(0)
|
||||
if pkg21.value != 2 || len(pkg21.data) != 2 || pkg21.data[0] != 1 || pkg21.data[1] != 2 {
|
||||
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
|
||||
}
|
||||
|
@ -48,17 +48,17 @@ func TestQueue(t *testing.T) {
|
|||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que2.push(pkg1)
|
||||
que2.push(pkg1, 0)
|
||||
if que2.size() != 1 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que2.push(pkg2)
|
||||
que2.push(pkg2, 0)
|
||||
if que2.size() != 2 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
que2.push(pkg3)
|
||||
que2.push(pkg3, 0)
|
||||
if que2.size() != 3 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
@ -66,29 +66,29 @@ func TestQueue(t *testing.T) {
|
|||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
fmt.Println("pop")
|
||||
pkg11 := que2.pop()
|
||||
pkg11 := que2.pop(0)
|
||||
if pkg11.value != 1 || pkg11.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg11.value, pkg11.data)
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Println("wait pop")
|
||||
que2.push(pkg4)
|
||||
que2.push(pkg4, 0)
|
||||
if que2.size() != 3 {
|
||||
t.Errorf("size:%d isEmpty:%v\n", que.size(), que.isEmpty())
|
||||
}
|
||||
|
||||
pkg21 = que2.pop()
|
||||
pkg21 = que2.pop(0)
|
||||
if pkg21.value != 2 || pkg21.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg21.value, pkg21.data)
|
||||
}
|
||||
|
||||
pkg31 := que2.pop()
|
||||
pkg31 := que2.pop(0)
|
||||
if pkg31.value != 3 || pkg31.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg31.value, pkg31.data)
|
||||
}
|
||||
|
||||
pkg41 := que2.pop()
|
||||
pkg41 := que2.pop(0)
|
||||
if pkg41.value != 4 || pkg31.data != nil {
|
||||
t.Errorf("value:%d data:%v\n", pkg41.value, pkg41.data)
|
||||
}
|
||||
|
@ -96,11 +96,11 @@ func TestQueue(t *testing.T) {
|
|||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
fmt.Println("push")
|
||||
que2.push(pkg5)
|
||||
que2.push(pkg5, 0)
|
||||
}()
|
||||
|
||||
fmt.Println("wait push")
|
||||
pkg51 := que2.pop()
|
||||
pkg51 := que2.pop(0)
|
||||
if pkg51.value != 5 || len(pkg51.data) != 1 || pkg51.data[0] != 55 {
|
||||
t.Errorf("value:%d data:%v\n", pkg51.value, pkg51.data)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue