diff --git a/src/Transport.hs b/src/Transport.hs index 7953c3b48..427b3bc59 100644 --- a/src/Transport.hs +++ b/src/Transport.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} module Transport where @@ -84,14 +85,12 @@ tPutRaw h (signature, connId, command) = do putLn h (encode connId) putLn h command -tGetRaw :: MonadIO m => Handle -> m (Maybe RawTransmission) -tGetRaw h = - getDecodedLn $ \signature -> - getDecodedLn $ \connId -> do - command <- getLn h - return $ Just (signature, connId, command) - where - getDecodedLn f = getLn h >>= either (\_ -> return Nothing) f . decode +tGetRaw :: MonadIO m => Handle -> m (Either String RawTransmission) +tGetRaw h = do + signature <- decode <$> getLn h + connId <- decode <$> getLn h + command <- getLn h + return $ liftM2 (,,command) signature connId tPut :: MonadIO m => Handle -> Transmission -> m () tPut h (signature, (connId, command)) = tPutRaw h (signature, connId, serializeCommand command) @@ -109,13 +108,13 @@ fromServer = \case -- | get client and server transmissions -- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> Handle -> m TransmissionOrError -tGet fromParty h = tGetRaw h >>= maybe badTransmission tParseComplete +tGet fromParty h = tGetRaw h >>= either (const tError) tParseLoadBody where - badTransmission :: m TransmissionOrError - badTransmission = return (B.empty, (B.empty, Left $ SYNTAX errBadTransmission)) + tError :: m TransmissionOrError + tError = return (B.empty, (B.empty, Left $ SYNTAX errBadTransmission)) - tParseComplete :: RawTransmission -> m TransmissionOrError - tParseComplete t@(signature, connId, command) = do + tParseLoadBody :: RawTransmission -> m TransmissionOrError + tParseLoadBody t@(signature, connId, command) = do let cmd = parseCommand command >>= fromParty >>= tCredentials t fullCmd <- either (return . Left) cmdWithMsgBody cmd return (signature, (connId, fullCmd)) diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index fda1a1343..84ca3a3e7 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -5,7 +5,6 @@ module SMPClient where import Control.Monad.IO.Unlift import Crypto.Random -import Data.Maybe import Network.Socket import Numeric.Natural import Server @@ -44,4 +43,4 @@ smpServerTest :: [RawTransmission] -> IO [RawTransmission] smpServerTest commands = runSmpTest \h -> mapM (sendReceive h) commands where sendReceive :: Handle -> RawTransmission -> IO RawTransmission - sendReceive h t = tPutRaw h t >> fromJust <$> tGetRaw h + sendReceive h t = tPutRaw h t >> either (error "bad transmission") id <$> tGetRaw h