diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index f7af991a4..1a7f81992 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -26,6 +26,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Char (isAsciiLower, isDigit) import Data.Default (def) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L @@ -70,11 +71,14 @@ instance StrEncoding TransportHost where strP = A.choice [ THIPv4 <$> ((,,,) <$> ipNum <*> ipNum <*> ipNum <*> A.decimal), - THOnionHost <$> ((<>) <$> A.takeTill (== '.') <*> A.string ".onion"), - THDomainName . B.unpack <$> A.takeWhile1 (A.notInClass ":#,;/ ") + THOnionHost <$> ((<>) <$> A.takeWhile (\c -> isAsciiLower c || isDigit c) <*> A.string ".onion"), + THDomainName . B.unpack <$> (notOnion <$?> A.takeWhile1 (A.notInClass ":#,;/ \n\r\t")) ] where - ipNum = A.decimal <* A.char '.' + ipNum = validIP <$?> (A.decimal <* A.char '.') + validIP :: Int -> Either String Word8 + validIP n = if 0 <= n && n <= 255 then Right $ fromIntegral n else Left "invalid IP address" + notOnion s = if ".onion" `B.isSuffixOf` s then Left "invalid onion host" else Right s instance ToJSON TransportHost where toEncoding = strToJEncoding diff --git a/tests/CoreTests/EncodingTests.hs b/tests/CoreTests/EncodingTests.hs index 52ed73807..a89499777 100644 --- a/tests/CoreTests/EncodingTests.hs +++ b/tests/CoreTests/EncodingTests.hs @@ -1,4 +1,7 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module CoreTests.EncodingTests where @@ -10,7 +13,9 @@ import Data.Int (Int64) import Data.Time.Clock.System (SystemTime (..), getSystemTime, utcToSystemTime) import Data.Time.ISO8601 (parseISO8601) import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Transport.Client (TransportHost (..)) import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck @@ -39,8 +44,36 @@ encodingTests = modifyMaxSuccess (const 1000) $ do testSystemTime t' it "parse(encode(SystemTime) should equal the same Int64" . property $ \i -> parseAll smpP (smpEncode i) == Right (i :: Int64) + describe "Encoding transport hosts" $ do + describe "domain name hosts" $ do + it "should encode / decode domain name" $ THDomainName "smp.simplex.im" #==# "smp.simplex.im" + it "should not allow whitespace or punctuation" $ do + shouldNotParse @TransportHost "smp,simplex.im" "endOfInput" + shouldNotParse @TransportHost "smp:simplex.im" "endOfInput" + shouldNotParse @TransportHost "smp#simplex.im" "endOfInput" + shouldNotParse @TransportHost "smp simplex.im" "endOfInput" + shouldNotParse @TransportHost "smp\nsimplex.im" "endOfInput" + describe "onion hosts" $ do + it "should encode / decode onion host" $ THOnionHost "beccx4yfxxbvyhqypaavemqurytl6hozr47wfc7uuecacjqdvwpw2xid.onion" #==# "beccx4yfxxbvyhqypaavemqurytl6hozr47wfc7uuecacjqdvwpw2xid.onion" + it "should only allow latin letters and digits" $ do + shouldNotParse @TransportHost "beccx4yfxxbvyhqypaavemqurytl 6hozr47wfc7uuecacjqdvwpw2xid.onion" "endOfInput" + shouldNotParse @TransportHost "beccx4yfxxbvyhqypaavemqurytl\n6hozr47wfc7uuecacjqdvwpw2xid.onion" "endOfInput" + shouldNotParse @TransportHost "bèccx4yfxxbvyhqypaavemqurytl6hozr47wfc7uuecacjqdvwpw2xid.onion" "Failed reading: empty" + describe "IP address hosts" $ do + it "should encode / decode IP address" $ THIPv4 (192, 168, 0, 1) #==# "192.168.0.1" + it "should be valid" $ do + THDomainName "192.168.1" #==# "192.168.1" + THDomainName "192.256.0.1" #==# "192.256.0.1" + THDomainName "192.168.0.-1" #==# "192.168.0.-1" + shouldNotParse @TransportHost "192.168.0.0.1" "endOfInput" where testSystemTime :: SystemTime -> Expectation testSystemTime t = do smpEncode t `shouldBe` smpEncode (systemSeconds t) - parseAll smpP (smpEncode t) `shouldBe` Right t {systemNanoseconds = 0} + smpDecode (smpEncode t) `shouldBe` Right t {systemNanoseconds = 0} + (#==#) :: (StrEncoding s, Eq s, Show s) => s -> ByteString -> Expectation + (#==#) x s = do + strEncode x `shouldBe` s + strDecode s `shouldBe` Right x + shouldNotParse :: forall s. (StrEncoding s, Eq s, Show s) => ByteString -> String -> Expectation + shouldNotParse s err = strDecode s `shouldBe` (Left err :: Either String s)