diff --git a/src/Server.hs b/src/Server.hs index 7450f3fb7..b86e0d339 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -111,9 +111,9 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = processCommand (connId, cmd) = do st <- asks connStore case cmd of - Cmd SBroker END -> unsubscribeConn >> return (connId, cmd) + Cmd SBroker END -> unsubscribeConn connId >> return (connId, cmd) Cmd SBroker _ -> return (connId, cmd) - Cmd SSender (SEND msgBody) -> sendMessage st msgBody + Cmd SSender (SEND msgBody) -> sendMessage st connId msgBody Cmd SRecipient command -> case command of CONN rKey -> createConn st rKey SUB -> subscribeConn connId @@ -122,13 +122,21 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = OFF -> okResponse <$> suspendConn st connId DEL -> okResponse <$> deleteConn st connId where + ok :: Signed + ok = (connId, Cmd SBroker OK) + + okResponse :: Either ErrorType () -> Signed + okResponse = mkSigned connId . either ERR (const OK) + createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed - createConn st rKey = - addConn st rKey >>= \case - Right Connection {recipientId = rId, senderId = sId} -> do - void $ subscribeConn rId - return . mkSigned rId $ IDS rId sId - Left e -> return . mkSigned "" $ ERR e + createConn st rKey = mkSigned "" <$> addSubscribe + where + addSubscribe = + addConn st rKey >>= \case + Right Connection {recipientId = rId, senderId = sId} -> do + void $ subscribeConn rId + return $ IDS rId sId + Left e -> return $ ERR e subscribeConn :: RecipientId -> m Signed subscribeConn rId = do @@ -139,28 +147,28 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = writeTVar connections $ M.insert rId (Left ()) cs deliverMessage tryPeekMsg rId - unsubscribeConn :: m () - unsubscribeConn = do + unsubscribeConn :: RecipientId -> m () + unsubscribeConn rId = do cs <- readTVarIO connections - atomically . writeTVar connections $ M.delete connId cs - case M.lookup connId cs of + atomically . writeTVar connections $ M.delete rId cs + case M.lookup rId cs of Just (Right threadId) -> killThread threadId _ -> return () - sendMessage :: MonadConnStore s m => s -> MsgBody -> m Signed - sendMessage st msgBody = - getConn st SSender connId - >>= fmap (mkSigned connId) . either (return . ERR) (storeMessage msgBody) - - storeMessage :: MsgBody -> Connection -> m (Command 'Broker) - storeMessage msgBody c = case status c of - ConnActive -> do - ms <- asks msgStore - q <- getMsgQueue ms (recipientId c) - msg <- newMessage msgBody - writeMsg q msg - return OK - ConnOff -> return $ ERR AUTH + sendMessage :: MonadConnStore s m => s -> SenderId -> MsgBody -> m Signed + sendMessage st sId msgBody = + getConn st SSender sId + >>= fmap (mkSigned sId) . either (return . ERR) storeMessage + where + storeMessage :: Connection -> m (Command 'Broker) + storeMessage c = case status c of + ConnActive -> do + ms <- asks msgStore + q <- getMsgQueue ms (recipientId c) + msg <- newMessage msgBody + writeMsg q msg + return OK + ConnOff -> return $ ERR AUTH deliverMessage :: (MsgQueue -> m (Maybe Message)) -> RecipientId -> m Signed deliverMessage tryPeek rId = do @@ -182,17 +190,12 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} = where trackSubscriber sThrd = atomically . modifyTVar connections $ M.insert rId sThrd subscriber = do - peekMsg q >>= atomically . writeTBQueue sndQ . msgResponse rId + msg <- peekMsg q + atomically . writeTBQueue sndQ $ msgResponse rId msg trackSubscriber $ Left () - ok :: Signed - ok = (connId, Cmd SBroker OK) - mkSigned :: ConnId -> Command 'Broker -> Signed mkSigned cId command = (cId, Cmd SBroker command) - okResponse :: Either ErrorType () -> Signed - okResponse = mkSigned connId . either ERR (const OK) - msgResponse :: RecipientId -> Message -> Signed msgResponse rId Message {msgId, ts, msgBody} = mkSigned rId $ MSG msgId ts msgBody diff --git a/src/Transport.hs b/src/Transport.hs index 696fe9389..776af57cb 100644 --- a/src/Transport.hs +++ b/src/Transport.hs @@ -113,6 +113,8 @@ tGet fromParty h = do where tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd tCredentials (signature, connId, _) cmd = case cmd of + -- IDS response should not have connection ID + Cmd SBroker (IDS _ _) -> Right cmd -- ERROR response does not always have connection ID Cmd SBroker (ERR _) -> Right cmd -- other responses must have connection ID diff --git a/tests/Test.hs b/tests/Test.hs index 7abb2f41b..6e7cdd407 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -37,8 +37,8 @@ testCreateSecure :: SpecWith () testCreateSecure = do it "CONN and KEY commands, SEND messages (no delivery yet)" $ smpTest \h -> do - Resp rId (IDS rId1 sId) <- sendRecv h ("", "", "CONN 123") - (rId1, rId) #== "creates connection" + Resp rId1 (IDS rId sId) <- sendRecv h ("", "", "CONN 123") + (rId1, "") #== "creates connection" Resp sId1 ok1 <- sendRecv h ("", sId, "SEND :hello") (ok1, OK) #== "accepts unsigned SEND" @@ -83,8 +83,8 @@ testCreateDelete :: SpecWith () testCreateDelete = do it "CONN, OFF and DEL commands, SEND messages (no delivery yet)" $ smpTest \h -> do - Resp rId (IDS rId1 sId) <- sendRecv h ("", "", "CONN 123") - (rId1, rId) #== "creates connection" + Resp rId1 (IDS rId sId) <- sendRecv h ("", "", "CONN 123") + (rId1, "") #== "creates connection" Resp _ ok1 <- sendRecv h ("123", rId, "KEY 456") (ok1, OK) #== "secures connection"