diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 8f3062fed..9c04afed8 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -24,7 +24,7 @@ import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Server.StoreLog (StoreLog, openReadStoreLog, storeLogFilePath) import Simplex.Messaging.Transport (ATransport (..), TLS, Transport (..)) --- import Simplex.Messaging.Transport.WebSockets (WS) +import Simplex.Messaging.Transport.WebSockets (WS) import System.Directory (createDirectoryIfMissing, doesFileExist, removeFile) import System.Exit (exitFailure) import System.FilePath (combine) @@ -119,8 +119,7 @@ getConfig opts = do makeConfig :: IniOpts -> C.PrivateKey 'C.RSA -> Maybe (StoreLog 'ReadMode) -> ServerConfig makeConfig IniOpts {serverPort, blockSize, enableWebsockets, serverPrivateKeyFile, serverCertificateFile} pk storeLog = - -- let transports = (serverPort, transport @TLS) : [("80", transport @WS) | enableWebsockets] - let transports = [(serverPort, transport @TLS)] + let transports = (serverPort, transport @TLS) : [("80", transport @WS) | enableWebsockets] in serverConfig {transports, storeLog, blockSize, serverPrivateKey = pk, serverPrivateKeyFile, serverCertificateFile} printConfig :: ServerConfig -> IO () diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 492708131..81aae8d9a 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -65,7 +65,7 @@ import Simplex.Messaging.Agent.Protocol (SMPServer (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport (ATransport (..), SessionId (..), THandle (..), TLS, TProxy, Transport (..), TransportError, clientHandshake, runTransportClient) --- import Simplex.Messaging.Transport.WebSockets (WS) +import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (bshow, liftError, raceAny_) import System.Timeout (timeout) @@ -179,7 +179,7 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlock useTransport :: (ServiceName, ATransport) useTransport = case port smpServer of Nothing -> defaultTransport cfg - -- Just "80" -> ("80", transport @WS) + Just "80" -> ("80", transport @WS) Just p -> (p, transport @TLS) client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError (THandle c)) -> c -> IO () diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 3af37aed4..6d1b4067a 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -37,6 +37,7 @@ module Simplex.Messaging.Transport -- * TLS 1.3 Transport TLS (..), + closeTLS, -- * SMP encrypted transport THandle (..), diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index 29f3de2d1..4e0ba0ffb 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -1,62 +1,79 @@ {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} module Simplex.Messaging.Transport.WebSockets (WS (..)) where import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Network.Socket (Socket) +import qualified Data.ByteString.Lazy as BL +import qualified Network.TLS as T import Network.WebSockets import Network.WebSockets.Stream (Stream) import qualified Network.WebSockets.Stream as S -import Simplex.Messaging.Transport (TProxy, Transport (..), TransportError (..), trimCR) +import Simplex.Messaging.Transport (TLS (..), TProxy, Transport (..), TransportError (..), closeTLS, trimCR) data WS = WS {wsStream :: Stream, wsConnection :: Connection} --- websocketsOpts :: ConnectionOptions --- websocketsOpts = --- defaultConnectionOptions --- { connectionCompressionOptions = NoCompression, --- connectionFramePayloadSizeLimit = SizeLimit 8192, --- connectionMessageDataSizeLimit = SizeLimit 65536 --- } +websocketsOpts :: ConnectionOptions +websocketsOpts = + defaultConnectionOptions + { connectionCompressionOptions = NoCompression, + connectionFramePayloadSizeLimit = SizeLimit 8192, + connectionMessageDataSizeLimit = SizeLimit 65536 + } --- instance Transport WS where --- transportName :: TProxy WS -> String --- transportName _ = "WebSockets" +instance Transport WS where + transportName :: TProxy WS -> String + transportName _ = "WebSockets" --- getServerConnection :: Socket -> IO WS --- getServerConnection sock = do --- s <- S.makeSocketStream sock --- WS s <$> acceptClientRequest s --- where --- acceptClientRequest :: Stream -> IO Connection --- acceptClientRequest s = makePendingConnectionFromStream s websocketsOpts >>= acceptRequest + getServerConnection :: TLS -> IO WS + getServerConnection TLS {tlsContext} = do + s <- websocketsStream tlsContext + WS s <$> acceptClientRequest s + where + acceptClientRequest :: Stream -> IO Connection + acceptClientRequest s = makePendingConnectionFromStream s websocketsOpts >>= acceptRequest --- getClientConnection :: Socket -> IO WS --- getClientConnection sock = do --- s <- S.makeSocketStream sock --- WS s <$> sendClientRequest s --- where --- sendClientRequest :: Stream -> IO Connection --- sendClientRequest s = newClientConnection s "" "/" websocketsOpts [] + getClientConnection :: TLS -> IO WS + getClientConnection TLS {tlsContext} = do + s <- websocketsStream tlsContext + WS s <$> sendClientRequest s + where + sendClientRequest :: Stream -> IO Connection + sendClientRequest s = newClientConnection s "" "/" websocketsOpts [] --- closeConnection :: WS -> IO () --- closeConnection = S.close . wsStream + closeConnection :: WS -> IO () + closeConnection = S.close . wsStream --- cGet :: WS -> Int -> IO ByteString --- cGet c n = do --- s <- receiveData (wsConnection c) --- if B.length s == n --- then pure s --- else E.throwIO TEBadBlock + cGet :: WS -> Int -> IO ByteString + cGet c n = do + s <- receiveData (wsConnection c) + if B.length s == n + then pure s + else E.throwIO TEBadBlock --- cPut :: WS -> ByteString -> IO () --- cPut = sendBinaryData . wsConnection + cPut :: WS -> ByteString -> IO () + cPut = sendBinaryData . wsConnection --- getLn :: WS -> IO ByteString --- getLn c = do --- s <- trimCR <$> receiveData (wsConnection c) --- if B.null s || B.last s /= '\n' --- then E.throwIO TEBadBlock --- else pure $ B.init s + getLn :: WS -> IO ByteString + getLn c = do + s <- trimCR <$> receiveData (wsConnection c) + if B.null s || B.last s /= '\n' + then E.throwIO TEBadBlock + else pure $ B.init s + +websocketsStream :: T.Context -> IO S.Stream +websocketsStream tlsContext = + S.makeStream readStream writeStream + where + readStream :: IO (Maybe ByteString) + readStream = + (Just <$> T.recvData tlsContext) `E.catch` \case + T.Error_EOF -> pure Nothing + e -> E.throwIO e + writeStream :: Maybe BL.ByteString -> IO () + writeStream = \case + Nothing -> closeTLS tlsContext + Just bs -> T.sendData tlsContext bs diff --git a/tests/Test.hs b/tests/Test.hs index 9f946592f..25cfa3f3d 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -4,7 +4,7 @@ import AgentTests (agentTests) import ProtocolErrorTests import ServerTests import Simplex.Messaging.Transport (TLS, Transport (..)) --- import Simplex.Messaging.Transport.WebSockets (WS) +import Simplex.Messaging.Transport.WebSockets (WS) import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import Test.Hspec @@ -14,6 +14,6 @@ main = do hspec $ do describe "Protocol errors" protocolErrorTests describe "SMP server via TLS 1.3" $ serverTests (transport @TLS) - -- describe "SMP server via WebSockets" $ serverTests (transport @WS) + describe "SMP server via WebSockets" $ serverTests (transport @WS) describe "SMP client agent" $ agentTests (transport @TLS) removeDirectoryRecursive "tests/tmp"