package protocol import ( "bytes" "errors" "io" "sync/atomic" "time" "git.viry.cc/gomod/glog" ) const VERSION uint8 = 1 var ErrorReadCallbackIsNil = errors.New("read callback is nil") var ErrorReaderIsNil = errors.New("reader is nil") var ErrorWriterIsNil = errors.New("writer is nil") var ErrorReaderIsKilled = errors.New("reader is killed") var ErrorWriterIsKilled = errors.New("writer is killed") var ErrorWriterQueueIsNil = errors.New("writer queue is nil") var ErrorHeartbeatIsKilled = errors.New("heartbeat is killed") var ErrorHeartbeatCallbackIsNil = errors.New("heartbeat callback is nil") var ErrorDataSizeExceedsLimit = errors.New("data size exceeds limit") var ErrorTimeout = errors.New("timeout") const ( statusRunning int32 = iota statusKilled ) type Protocol struct { // 标记protocol tag string r io.Reader w io.Writer // protocol的状态 status int32 // 用于处理获取到的数据,每个package中的数据都会完整的保存在data中 readCallback func(data []byte) // 写入等待队列 writeQueue *queue // 当前protocol正在运行的协程数量 runningRoutines int32 // 心跳信号,同时也是心跳响应信号 heartbeatSig chan uint8 // 心跳请求信号,收到此信号必须回复对方 heartbeatSigReq chan uint8 // 发送心跳请求的间隔 heartbeatInterval uint32 // 接收心跳请求的超时时间 heartbeatTimeout uint32 // 心跳请求超时后的处理函数 heartbeatTimeoutCallback func(p *Protocol) bool // 上次发送心跳的时间 heartbeatLastSend int64 // 上次收到心跳的时间 heartbeatLastReceived int64 // status被标记为statusKilled时执行,可以用于关闭reader和writer killCallback func() // 在reader读取数据前执行的函数,常用于设置reader的读取截止时间,防止协程卡死 setFuncBeforeRead func() // 在reader读取数据后执行的函数 setFuncAfterRead func(error) // 在writer发送数据前执行的函数,常用于设置writer的发送截止时间,防止协程卡死 setFuncBeforeWrite func() // 在writer发送数据后执行的函数 setFuncAfterWrite func(error) } // New 返回一个protocol实例 // // tag: 标签,用于区分protocol实例 // r: 数据流的reader // w: 数据流的writer // writeQueueSize: 发送等待队列长度 // readCallback: 用于处理获取到的数据,每个package中的数据都会完整的保存在data中 // heartbeatTimeoutCallback: 心跳请求超时后的处理函数 // setFuncBeforeRead: 在reader读取数据前,设置reader的读取截止时间 // setFuncBeforeWrite: 在writer发送数据前,设置writer的发送截止时间 // killCallback: status被标记为statusKilled时执行,可以用于关闭reader和writer func New(tag string, r io.Reader, w io.Writer, writeQueueSize int, readCallback func(data []byte), heartbeatTimeoutCallback func(p *Protocol) bool, setFuncBeforeRead func(), setFuncAfterRead func(error), setFuncBeforeWrite func(), setFuncAfterWrite func(error), killCallback func()) *Protocol { if r == nil { glog.Warning("[protocol.%s] reader is nil", tag) return nil } if w == nil { glog.Warning("[protocol.%s] writer is nil", tag) return nil } if writeQueueSize < 1 { glog.Trace("[protocol.%s] writeQueueSize is < 1, use 1", tag) writeQueueSize = 1 } if readCallback == nil { glog.Trace("[protocol.%s] readCallback is nil, use defaultReadCallback", tag) readCallback = defaultReadCallback } if heartbeatTimeoutCallback == nil { glog.Trace("[protocol.%s] heartbeatTimeoutCallback is nil, use defaultHeartbeatTimeoutCallback", tag) heartbeatTimeoutCallback = defaultHeartbeatTimeoutCallback } if killCallback == nil { glog.Trace("[protocol.%s] killCallback is nil, use defaultKillCallback", tag) killCallback = defaultKillCallback } return &Protocol{ tag: tag, r: r, w: w, status: statusRunning, readCallback: readCallback, writeQueue: newQueue(writeQueueSize), runningRoutines: 0, heartbeatSig: make(chan uint8, 1), heartbeatSigReq: make(chan uint8, 1), heartbeatInterval: 15, heartbeatTimeout: 40, heartbeatTimeoutCallback: heartbeatTimeoutCallback, heartbeatLastSend: 0, heartbeatLastReceived: 0, killCallback: killCallback, setFuncBeforeRead: setFuncBeforeRead, setFuncAfterRead: setFuncAfterRead, setFuncBeforeWrite: setFuncBeforeWrite, setFuncAfterWrite: setFuncAfterWrite, } } func (p *Protocol) Connect(activeHeartbeatSignalSender bool) { go p.reader() go p.writer() go p.heartbeat() if activeHeartbeatSignalSender { go p.heartbeatSignalSender() } } func (p *Protocol) handlePackage(pkg *protocolPackage) { glog.Trace("[protocol.%s] handle package", p.tag) if pkg == nil { glog.Trace("[protocol.%s] package is nil", p.tag) return } if pkg.isEncrypted() { glog.Trace("[protocol.%s] package is encrypted, decrypt package", p.tag) pkg.decrypt() } if !pkg.checkHead() { glog.Trace("[protocol.%s] package head broken", p.tag) return } if (pkg.flag & flagHeartbeat) != 0 { glog.Info("[protocol.%s] heartbeat signal in package", p.tag) p.heartbeatSig <- pkg.value } if (pkg.flag & flagHeartbeatRequest) != 0 { glog.Info("[protocol.%s] heartbeat request signal in package", p.tag) p.heartbeatSigReq <- pkg.value } if !pkg.checkData() { glog.Trace("[protocol.%s] package data broken", p.tag) return } if pkg.dataSize == 0 { glog.Trace("[protocol.%s] package data empty", p.tag) return } glog.Info("[protocol.%s] handle package successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize) p.readCallback(pkg.data) } // Reader 阻塞接收数据并提交给readCallback func (p *Protocol) reader() { glog.Trace("[protocol.%s] reader enable", p.tag) if p.r == nil { glog.Warning("[protocol.%s] reader is not ready", p.tag) return } p.incRunningRoutine() defer p.decRunningRoutine() buffer := &bytes.Buffer{} buf := make([]byte, packageMaxSize) var err error var n int // 监听并接收数据 for { if p.getStatus() == statusKilled { glog.Trace("[protocol.%s] reader is killed", p.tag) return } if p.setFuncBeforeRead != nil { glog.Trace("[protocol.%s] reader func before read", p.tag) p.setFuncBeforeRead() } glog.Trace("[protocol.%s] reader wait read", p.tag) n, err = p.r.Read(buf) if p.setFuncAfterRead != nil { glog.Trace("[protocol.%s] reader func after read", p.tag) p.setFuncAfterRead(err) } if err != nil { glog.Trace("[protocol.%s] read error %v", p.tag, err) time.Sleep(500 * time.Millisecond) continue } if n == 0 { glog.Trace("[protocol.%s] read empty", p.tag) time.Sleep(500 * time.Millisecond) continue } n, err = buffer.Write(buf[:n]) glog.Trace("[protocol.%s] write %d bytes, buffer already %d bytes, error is %v", p.tag, n, buffer.Len(), err) for buffer.Len() >= packageHeadSize { glog.Trace("[protocol.%s] complete package, buffer length %d", p.tag, buffer.Len()) pkg, err := parsePackage(buffer) if err != nil { if errors.Is(err, ErrorPackageIncomplete) { glog.Trace("[protocol.%s] incomplete package, buffer length %d", p.tag, buffer.Len()) break } glog.Info("[protocol.%s] parse package with error %v", p.tag, err) } if pkg != nil { glog.Info("[protocol.%s] receive new package, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize) go p.handlePackage(pkg) } } } } // Writer 创建发送队列并监听待发送数据 func (p *Protocol) writer() { glog.Trace("[protocol.%s] writer enable", p.tag) if p.w == nil { glog.Warning("[protocol.%s] writer is not ready", p.tag) return } p.incRunningRoutine() defer p.decRunningRoutine() var err error var n int for { if p.getStatus() == statusKilled { glog.Trace("[protocol.%s] writer is killed", p.tag) return } glog.Trace("[protocol.%s] writer wait pop", p.tag) pkg := p.writeQueue.pop(int(p.GetHeartbeatInterval())) if pkg == nil { glog.Trace("[protocol.%s] writer pop timeout", p.tag) continue } if p.setFuncBeforeWrite != nil { glog.Trace("[protocol.%s] writer func before write", p.tag) p.setFuncBeforeWrite() } glog.Trace("[protocol.%s] writer wait write", p.tag) n, err = p.w.Write(pkg.Bytes().Bytes()) if p.setFuncAfterWrite != nil { glog.Trace("[protocol.%s] writer func after write", p.tag) p.setFuncAfterWrite(err) } glog.Trace("[protocol.%s] write %d bytes, error is %v", p.tag, n, err) if err != nil { glog.Info("[protocol.%s] send package failed with error %v, re-push package", p.tag, err) time.Sleep(time.Second) for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) { if p.getStatus() == statusKilled { glog.Trace("[protocol.%s] writer is killed", p.tag) return } } } glog.Info("[protocol.%s] send package successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize) } } // Write 发送数据 func (p *Protocol) Write(data []byte) error { glog.Trace("[protocol.%s] write", p.tag) if len(data) > dataMaxSize { glog.Info("[protocol.%s] maximum supported data size exceeded", p.tag) return ErrorDataSizeExceedsLimit } pkg := newPackage(0, encryptNone, 0, data) for { if p.getStatus() == statusKilled { glog.Info("[protocol.%s] protocol is killed", p.tag) return ErrorWriterIsKilled } if p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) { glog.Info("[protocol.%s] write successful, crc32:[%d] flag:[%x] dataSize:[%d]", p.tag, pkg.crc32, pkg.flag, pkg.dataSize) return nil } } } // heartbeat 心跳服务 // // heartbeatTimeout: 被动接收心跳信号的超时时间(s),最小为3s,传入参数小于3时使用默认值30 // heartbeatTimeoutCallback: 没有按时收到心跳信号时调用,返回true继续等待,返回false退出 func (p *Protocol) heartbeat() { glog.Trace("[protocol.%s] heartbeat enable", p.tag) p.incRunningRoutine() defer p.decRunningRoutine() for { select { case <-time.After(time.Duration(p.GetHeartbeatTimeout()) * time.Second): glog.Info("[protocol.%s] heartbeat timeout", p.tag) if p.getStatus() == statusKilled { glog.Trace("[protocol.%s] heartbeat is killed", p.tag) return } if !p.heartbeatTimeoutCallback(p) { glog.Trace("[protocol.%s] heartbeat is killed, set status killed", p.tag) p.setStatus(statusKilled) return } case val := <-p.heartbeatSigReq: glog.Info("[protocol.%s] heartbeat request signal received", p.tag) p.setHeartbeatLastReceived() p.sendHeartbeatSignal(false) if val != 0 { p.SetHeartbeatTimeout(val) } case val := <-p.heartbeatSig: glog.Info("[protocol.%s] heartbeat signal received", p.tag) p.setHeartbeatLastReceived() if val != 0 { p.SetHeartbeatTimeout(val) } } if p.getStatus() == statusKilled { glog.Trace("[protocol.%s] heartbeat is killed", p.tag) return } } } // heartbeatSignalSender 主动触发心跳 // // heartbeatInterval: 主动发送心跳信号的间隔时间(s),最小为3s,传入参数小于3时使用默认值3 func (p *Protocol) heartbeatSignalSender() { p.incRunningRoutine() defer p.decRunningRoutine() for { if p.getStatus() == statusKilled { glog.Trace("[protocol.%s] heartbeat signal sender is killed", p.tag) return } p.sendHeartbeatSignal(true) time.Sleep(time.Duration(p.GetHeartbeatInterval()) * time.Second) } } func (p *Protocol) sendHeartbeatSignal(isReq bool) { glog.Trace("[protocol.%s] send heartbeat signal", p.tag) var pkg *protocolPackage if isReq { pkg = newPackage(flagHeartbeatRequest, encryptNone, 0, nil) } else { pkg = newPackage(flagHeartbeat, encryptNone, 0, nil) } for !p.writeQueue.push(pkg, int(p.GetHeartbeatInterval())) { if p.getStatus() == statusKilled { glog.Info("[protocol.%s] protocol is killed", p.tag) return } } p.setHeartbeatLastSend() } func (p *Protocol) setStatus(status int32) { glog.Trace("[protocol.%s] set status %d", p.tag, status) if status == statusKilled { p.killCallback() } atomic.StoreInt32(&p.status, status) } func (p *Protocol) getStatus() int32 { glog.Trace("[protocol.%s] get status", p.tag) return atomic.LoadInt32(&p.status) } func (p *Protocol) SetHeartbeatInterval(interval uint8) { if interval < 3 { glog.Trace("[protocol.%s] heartbeatInterval is < 3, use 3", p.tag) interval = 3 } atomic.StoreUint32(&p.heartbeatInterval, uint32(interval)) } func (p *Protocol) GetHeartbeatInterval() uint8 { return uint8(atomic.LoadUint32(&p.heartbeatInterval)) } func (p *Protocol) SetHeartbeatTimeout(timeout uint8) { if timeout < 6 { glog.Trace("[protocol.%s] heartbeatTimeout is < 6, use 6", p.tag) timeout = 6 } atomic.StoreUint32(&p.heartbeatTimeout, uint32(timeout)) } func (p *Protocol) GetHeartbeatTimeout() uint8 { return uint8(atomic.LoadUint32(&p.heartbeatTimeout)) } func (p *Protocol) setHeartbeatLastReceived() { atomic.StoreInt64(&p.heartbeatLastReceived, time.Now().Unix()) } func (p *Protocol) GetHeartbeatLastReceived() int64 { return atomic.LoadInt64(&p.heartbeatLastReceived) } func (p *Protocol) setHeartbeatLastSend() { atomic.StoreInt64(&p.heartbeatLastSend, time.Now().Unix()) } func (p *Protocol) GetHeartbeatLastSend() int64 { return atomic.LoadInt64(&p.heartbeatLastSend) } func (p *Protocol) incRunningRoutine() { atomic.AddInt32(&p.runningRoutines, 1) } func (p *Protocol) decRunningRoutine() { atomic.AddInt32(&p.runningRoutines, -1) } func (p *Protocol) GetRunningRoutine() int32 { return atomic.LoadInt32(&p.runningRoutines) } func (p *Protocol) WaitKilled(timeout int) error { out := time.After(time.Duration(timeout) * time.Second) for { select { case <-out: return ErrorTimeout case <-time.After(time.Second): if p.GetRunningRoutine() <= 0 { return nil } } } } func (p *Protocol) GetTag() string { return p.tag } func (p *Protocol) SetTag(tag string) { p.tag = tag } func (p *Protocol) Kill() { p.setStatus(statusKilled) } func defaultReadCallback(data []byte) { glog.Trace("[protocol] default read callback %x", data) } func defaultHeartbeatTimeoutCallback(*Protocol) bool { glog.Trace("[protocol] default heartbeat timeout callback") return true } func defaultKillCallback() { glog.Trace("[protocol] default kill callback") } func GetDataMaxSize() int { return dataMaxSize } func CalculateTheNumberOfPackages(size int64) int64 { res := size / dataMaxSize if size%dataMaxSize != 0 { res += 1 } return res }