remote: add controller address preferences (#905)

* remote: add controller address preferences

* suppress localhost from breaking multicast discovery w/o prefs

* rewrite findCtrlAddress

* refactor

* refactor2

* add tests

---------

Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com>
This commit is contained in:
Alexander Bondarenko
2023-11-28 16:12:29 +02:00
committed by GitHub
parent 2a6be894e1
commit febf9019e2
5 changed files with 94 additions and 55 deletions

View File

@@ -43,6 +43,7 @@ import qualified Data.List.NonEmpty as L
import Data.Maybe (isNothing)
import qualified Data.Text as T
import Data.Time.Clock.System (getSystemTime)
import Data.Word (Word16)
import qualified Data.X509 as X509
import Data.X509.Validation (Fingerprint (..), getFingerprint)
import Network.Socket (PortNumber, SockAddr (..), hostAddressToTuple)
@@ -101,26 +102,29 @@ data RCHClient_ = RCHClient_
endSession :: TMVar ()
}
type RCHostConnection = (RCSignedInvitation, RCHostClient, RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)))
type RCHostConnection = (NonEmpty RCCtrlAddress, RCSignedInvitation, RCHostClient, RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)))
connectRCHost :: TVar ChaChaDRG -> RCHostPairing -> J.Value -> Bool -> ExceptT RCErrorType IO RCHostConnection
connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ctrlAppInfo multicast = do
connectRCHost :: TVar ChaChaDRG -> RCHostPairing -> J.Value -> Bool -> Maybe RCCtrlAddress -> Maybe Word16 -> ExceptT RCErrorType IO RCHostConnection
connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ctrlAppInfo multicast rcAddrPrefs_ port_ = do
r <- newEmptyTMVarIO
host <- getLocalAddress >>= maybe (throwError RCENoLocalAddress) pure
found@(RCCtrlAddress {address} :| _) <- findCtrlAddress
c@RCHClient_ {startedPort, announcer} <- liftIO mkClient
hostKeys <- liftIO genHostKeys
action <- runClient c r hostKeys `putRCError` r
-- wait for the port to make invitation
-- TODO can't we actually find to which interface the server got connected to get host there?
portNum <- atomically $ readTMVar startedPort
signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys host) portNum
signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
when multicast $ case knownHost of
Nothing -> throwError RCENewController
Just KnownHostPairing {hostDhPubKey} -> do
ann <- async . liftIO . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
atomically $ putTMVar announcer ann
pure (signedInv, RCHostClient {action, client_ = c}, r)
pure (found, signedInv, RCHostClient {action, client_ = c}, r)
where
findCtrlAddress :: ExceptT RCErrorType IO (NonEmpty RCCtrlAddress)
findCtrlAddress = do
found' <- liftIO $ getLocalAddress rcAddrPrefs_
maybe (throwError RCENoLocalAddress) pure $ L.nonEmpty found'
mkClient :: IO RCHClient_
mkClient = do
startedPort <- newEmptyTMVarIO
@@ -131,7 +135,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> ExceptT RCErrorType IO (Async ())
runClient RCHClient_ {startedPort, announcer, hostCAHash, endSession} r hostKeys = do
tlsCreds <- liftIO $ genTLSCredentials caKey caCert
startTLSServer startedPort tlsCreds (tlsHooks r knownHost hostCAHash) $ \tls ->
startTLSServer port_ startedPort tlsCreds (tlsHooks r knownHost hostCAHash) $ \tls ->
void . runExceptT $ do
r' <- newEmptyTMVarIO
whenM (atomically $ tryPutTMVar r $ Right (tlsUniq tls, tls, r')) $

View File

@@ -7,21 +7,22 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
-- XXX: remove non-discovery functions
module Simplex.RemoteControl.Discovery where
import Control.Applicative ((<|>))
import Control.Logger.Simple
import Control.Monad
import Crypto.Random (getRandomBytes)
import Data.ByteString (ByteString)
import Data.Default (def)
import Data.Maybe (listToMaybe, mapMaybe)
import Data.List (delete, find)
import Data.Maybe (mapMaybe)
import Data.String (IsString)
import qualified Data.Text as T
import Data.Word (Word16)
import Network.Info (IPv4 (..), NetworkInterface (..), getNetworkInterfaces)
import qualified Network.Socket as N
import qualified Network.TLS as TLS
import qualified Network.UDP as UDP
import Simplex.Messaging.Encoding (Encoding (..))
import Simplex.Messaging.Transport (supportedParameters)
import qualified Simplex.Messaging.Transport as Transport
import Simplex.Messaging.Transport.Client (TransportHost (..))
@@ -41,49 +42,36 @@ pattern ANY_ADDR_V4 = "0.0.0.0"
pattern DISCOVERY_PORT :: (IsString a, Eq a) => a
pattern DISCOVERY_PORT = "5227"
getLocalAddress :: MonadIO m => m (Maybe TransportHost)
getLocalAddress = listToMaybe . mapMaybe usable <$> liftIO getNetworkInterfaces
getLocalAddress :: Maybe RCCtrlAddress -> IO [RCCtrlAddress]
getLocalAddress preferred_ =
maybe id preferAddress preferred_ . mkLastLocalHost . mapMaybe toCtrlAddr <$> getNetworkInterfaces
where
usable NetworkInterface {ipv4 = IPv4 ha} = case N.hostAddressToTuple ha of
toCtrlAddr NetworkInterface {name, ipv4 = IPv4 ha} = case N.hostAddressToTuple ha of
(0, 0, 0, 0) -> Nothing -- "no" address
(255, 255, 255, 255) -> Nothing -- broadcast
(127, _, _, _) -> Nothing -- localhost
(169, 254, _, _) -> Nothing -- link-local
ok -> Just $ THIPv4 ok
ok -> Just RCCtrlAddress {address = THIPv4 ok, interface = T.pack name}
getLocalAddressMulticast :: MonadIO m => TMVar Int -> m (Maybe TransportHost)
getLocalAddressMulticast subscribers = liftIO $ do
probe <- mkIpProbe
let bytes = smpEncode probe
withListener subscribers $ \receiver ->
withSender $ \sender -> do
UDP.send sender bytes
let expect = do
UDP.recvFrom receiver >>= \case
(p, _) | p /= bytes -> expect
(_, UDP.ClientSockAddr (N.SockAddrInet _port host) _cmsg) -> pure $ THIPv4 (N.hostAddressToTuple host)
(_, UDP.ClientSockAddr _badAddr _) -> error "receiving from IPv4 socket"
timeout 1000000 expect
mkLastLocalHost :: [RCCtrlAddress] -> [RCCtrlAddress]
mkLastLocalHost addrs = case find localHost addrs of
Nothing -> addrs
Just lh -> delete lh addrs <> [lh]
where
localHost RCCtrlAddress {address = a} = a == THIPv4 (127, 0, 0, 1)
mkIpProbe :: MonadIO m => m IpProbe
mkIpProbe = do
randomNonce <- liftIO $ getRandomBytes 32
pure IpProbe {versionRange = ipProbeVersionRange, randomNonce}
preferAddress :: RCCtrlAddress -> [RCCtrlAddress] -> [RCCtrlAddress]
preferAddress RCCtrlAddress {address, interface} addrs =
case find matchAddr addrs <|> find matchIface addrs of
Nothing -> addrs
Just p -> p : delete p addrs
where
matchAddr RCCtrlAddress {address = a} = a == address
matchIface RCCtrlAddress {interface = i} = i == interface
-- | Send replay-proof announce datagrams
-- runAnnouncer :: (C.PrivateKeyEd25519, Announce) -> IO ()
-- runAnnouncer (announceKey, initialAnnounce) = withSender $ loop initialAnnounce
-- where
-- loop announce sock = do
-- UDP.send sock $ smpEncode (signAnnounce announceKey announce)
-- threadDelay 1000000
-- loop announce {announceCounter = announceCounter announce + 1} sock
-- XXX: move to RemoteControl.Client
startTLSServer :: MonadUnliftIO m => TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> m (Async ())
startTLSServer startedOnPort credentials hooks server = async . liftIO $ do
startTLSServer :: MonadUnliftIO m => Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> m (Async ())
startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ do
started <- newEmptyTMVarIO
bracketOnError (startTCPServer started "0") (\_e -> setPort Nothing) $ \socket ->
bracketOnError (startTCPServer started $ maybe "0" show port_) (\_e -> setPort Nothing) $ \socket ->
ifM
(atomically $ readTMVar started)
(runServer started socket)

View File

@@ -14,6 +14,7 @@ import qualified Data.Aeson as J
import qualified Data.Aeson.TH as JQ
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString (ByteString)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import qualified Simplex.Messaging.Crypto as C
@@ -23,6 +24,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON)
import Simplex.Messaging.Transport (TLS)
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (safeDecodeUtf8)
import Simplex.Messaging.Version (Version, VersionRange, mkVersionRange)
import UnliftIO
@@ -134,6 +136,12 @@ data KnownHostPairing = KnownHostPairing
hostDhPubKey :: C.PublicKeyX25519
}
data RCCtrlAddress = RCCtrlAddress
{ address :: TransportHost, -- allows any interface when found exactly
interface :: Text
}
deriving (Show, Eq)
-- | Long-term part of host (mobile) connection to controller (desktop)
data RCCtrlPairing = RCCtrlPairing
{ caKey :: C.APrivateSignKey,
@@ -226,3 +234,5 @@ cancelTasks :: MonadIO m => Tasks -> m ()
cancelTasks tasks = readTVarIO tasks >>= mapM_ cancel
$(JQ.deriveJSON (sumTypeJSON $ dropPrefix "RCE") ''RCErrorType)
$(JQ.deriveJSON defaultJSON ''RCCtrlAddress)