diff --git a/common/commonspace/synctree/queue.go b/common/commonspace/synctree/queue.go new file mode 100644 index 00000000..3e71e1b8 --- /dev/null +++ b/common/commonspace/synctree/queue.go @@ -0,0 +1,66 @@ +package synctree + +import ( + "errors" + "sync" +) + +type ReceiveQueue interface { + AddMessage(senderId string, msg treeMsg) (queueFull bool) + GetMessage(senderId string) (msg treeMsg, err error) + ClearQueue(senderId string) +} + +type receiveQueue struct { + sync.Mutex + handlerMap map[string][]treeMsg + maxSize int +} + +func newReceiveQueue(maxSize int) ReceiveQueue { + return &receiveQueue{ + Mutex: sync.Mutex{}, + handlerMap: map[string][]treeMsg{}, + maxSize: maxSize, + } +} + +var errEmptyQueue = errors.New("the queue is empty") + +func (q *receiveQueue) AddMessage(senderId string, msg treeMsg) (queueFull bool) { + q.Lock() + defer q.Unlock() + + queue := q.handlerMap[senderId] + queueFull = len(queue) >= maxQueueSize + queue = append(queue, msg) + q.handlerMap[senderId] = queue + + return +} + +func (q *receiveQueue) GetMessage(senderId string) (msg treeMsg, err error) { + q.Lock() + defer q.Unlock() + + if len(q.handlerMap) == 0 { + err = errEmptyQueue + return + } + + msg = q.handlerMap[senderId][0] + return +} + +func (q *receiveQueue) ClearQueue(senderId string) { + q.Lock() + defer q.Unlock() + + queue := q.handlerMap[senderId] + excessLen := len(queue) - q.maxSize + 1 + if excessLen <= 0 { + excessLen = 1 + } + queue = queue[excessLen:] + q.handlerMap[senderId] = queue +} diff --git a/common/commonspace/synctree/synctreehandler.go b/common/commonspace/synctree/synctreehandler.go index 8d97d5d2..77f17bc1 100644 --- a/common/commonspace/synctree/synctreehandler.go +++ b/common/commonspace/synctree/synctreehandler.go @@ -16,7 +16,7 @@ type syncTreeHandler struct { objTree tree.ObjectTree syncClient SyncClient handlerLock sync.Mutex - handlerMap map[string][]treeMsg + queue ReceiveQueue } const maxQueueSize = 5 @@ -30,7 +30,7 @@ func newSyncTreeHandler(objTree tree.ObjectTree, syncClient SyncClient) synchand return &syncTreeHandler{ objTree: objTree, syncClient: syncClient, - handlerMap: map[string][]treeMsg{}, + queue: newReceiveQueue(maxQueueSize), } } @@ -43,13 +43,7 @@ func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, ms return } - s.handlerLock.Lock() - queue := s.handlerMap[senderId] - queueFull := len(queue) >= maxQueueSize - queue = append(queue, treeMsg{msg.ReplyId, unmarshalled}) - s.handlerMap[senderId] = queue - s.handlerLock.Unlock() - + queueFull := s.queue.AddMessage(senderId, treeMsg{msg.ReplyId, unmarshalled}) if queueFull { return } @@ -71,30 +65,19 @@ func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, ms func (s *syncTreeHandler) handleMessage(ctx context.Context, senderId string) (actions []sendFunc, err error) { s.objTree.Lock() defer s.objTree.Unlock() - s.handlerLock.Lock() - treeMessage := s.handlerMap[senderId][0] - unmarshalled := treeMessage.syncMessage - replyId := treeMessage.replyId - s.handlerLock.Unlock() + msg, err := s.queue.GetMessage(senderId) + if err != nil { + return + } - defer func() { - s.handlerLock.Lock() - defer s.handlerLock.Unlock() - queue := s.handlerMap[senderId] - excessLen := len(queue) - maxQueueSize + 1 - if excessLen <= 0 { - excessLen = 1 - } - queue = queue[excessLen:] - s.handlerMap[senderId] = queue - }() + defer s.queue.ClearQueue(senderId) - content := unmarshalled.GetContent() + content := msg.syncMessage.GetContent() switch { case content.GetHeadUpdate() != nil: - return s.handleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), replyId) + return s.handleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), msg.replyId) case content.GetFullSyncRequest() != nil: - return s.handleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), replyId) + return s.handleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), msg.replyId) case content.GetFullSyncResponse() != nil: return s.handleFullSyncResponse(ctx, senderId, content.GetFullSyncResponse()) }