{-# LANGUAGE RecordWildCards #-}
module Network.TLS.Handshake.Server.ClientHello (
processClientHello,
) where
import Network.TLS.Context.Internal
import Network.TLS.Extension
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Process
import Network.TLS.Imports
import Network.TLS.Measurement
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
-- | Validate a received ClientHello and negotiate the protocol version.
--
-- Any handshake message other than 'ClientHello' is rejected with a
-- 'HandshakeFailure' protocol error (second equation).
--
-- For a ClientHello, in order:
--
-- * renegotiation is refused if the stored version is TLS 1.3, and also
--   when 'supportedClientInitiatedRenegotiation' does not allow it;
-- * the 'onNewHandshake' server hook must authorize the handshake;
-- * the handshake is started (unless the TLS 1.3 HRR flag is set) and the
--   message is fed to 'processHandshake12';
-- * @legacy_version@ must be exactly TLS 1.2;
-- * the version is chosen: a 'debugVersionForced' version wins; otherwise
--   the @supported_versions@ extension is used when TLS 1.3 is enabled and
--   the client sent a non-empty list, else @legacy_version@ is used;
-- * for TLS 1.3, every offered compression method is passed to
--   'ensureNullCompression';
-- * the first host name of the @server_name@ extension, if any, is stored
--   with 'setClientSNI'.
--
-- Returns the negotiated version together with the ClientHello body.
processClientHello
    :: ServerParams -> Context -> Handshake -> IO (Version, CH)
processClientHello sparams ctx clientHello@(ClientHello legacyVersion cran compressions ch@CH{..}) = do
    established <- ctxEstablished ctx
    -- A connection that already ran TLS 1.3 may not be renegotiated.
    when (established /= NotEstablished) $ do
        ver <- usingState_ ctx (getVersionWithDefault TLS12)
        when (ver == TLS13) $
            throwCore $
                Error_Protocol "renegotiation is not allowed in TLS 1.3" UnexpectedMessage
    -- A ClientHello on an established, still-open connection is a
    -- client-initiated renegotiation; refuse it unless explicitly enabled.
    eof <- ctxEOF ctx
    let renegotiation = established == Established && not eof
    when (renegotiation && not (supportedClientInitiatedRenegotiation $ ctxSupported ctx)) $
        throwCore $
            Error_Protocol_Warning "renegotiation is not allowed" NoRenegotiation
    -- The application's policy hook may veto this handshake.
    handshakeAuthorized <- withMeasure ctx (onNewHandshake $ serverHooks sparams)
    unless handshakeAuthorized $
        throwCore $
            Error_HandshakePolicy "server: handshake denied"
    updateMeasure ctx incrementNbHandshakes

    -- Only start the handshake when the TLS 1.3 HRR flag is not set
    -- (presumably this ClientHello then answers our HelloRetryRequest and
    -- the handshake is already in progress -- TODO confirm).
    hrr <- usingState_ ctx getTLS13HRR
    unless hrr $ startHandshake ctx legacyVersion cran
    processHandshake12 ctx clientHello

    when (legacyVersion /= TLS12) $
        throwCore $
            Error_Protocol (show legacyVersion ++ " is not supported") ProtocolVersion

    -- Fallback SCSV (RFC 7507); 0x5600 is TLS_FALLBACK_SCSV.
    -- NOTE(review): the check just above already rejects every
    -- legacyVersion other than TLS12, so `legacyVersion < TLS12` can never
    -- hold here and this branch is unreachable. RFC 7507 section 3 allows
    -- answering with protocol_version instead, so the current behavior is
    -- compliant; move this check above the version check if an
    -- inappropriate_fallback alert is preferred.
    when
        ( supportedFallbackScsv (ctxSupported ctx)
            && (0x5600 `elem` chCiphers)
            && legacyVersion < TLS12
        )
        $ throwCore
        $ Error_Protocol "fallback is not allowed" InappropriateFallback

    -- Version negotiation inputs.
    let clientVersions =
            case extensionLookup EID_SupportedVersions chExtensions
                >>= extensionDecode MsgTClientHello of
                Just (SupportedVersionsClientHello vers) -> vers
                _ -> []
        clientVersion = min TLS12 legacyVersion
        supported = supportedVersions $ ctxSupported ctx
        -- TLS 1.3 is never offered on a renegotiated connection.
        serverVersions
            | renegotiation = filter (< TLS13) supported
            | otherwise = supported
        -- Turn a failed negotiation into a ProtocolVersion error.
        orFail msg = maybe (throwCore $ Error_Protocol msg ProtocolVersion) return
    chosenVersion <- case debugVersionForced (serverDebug sparams) of
        Just forced -> return forced
        Nothing
            | TLS13 `elem` serverVersions && not (null clientVersions) ->
                orFail
                    ("client versions " ++ show clientVersions ++ " is not supported")
                    (findHighestVersionFrom13 clientVersions serverVersions)
            | otherwise ->
                orFail
                    ("client version " ++ show clientVersion ++ " is not supported")
                    (findHighestVersionFrom clientVersion serverVersions)

    -- First host name carried by the server_name extension, ignoring any
    -- other name types.
    let serverName =
            case extensionLookup EID_ServerName chExtensions
                >>= extensionDecode MsgTClientHello of
                Just (ServerName ns) -> listToMaybe [host | ServerNameHostName host <- ns]
                _ -> Nothing
    -- TLS 1.3 (RFC 8446 section 4.1.2) only permits the null compression
    -- method; each offered method is checked by ensureNullCompression.
    when (chosenVersion == TLS13) $
        mapM_ ensureNullCompression compressions
    mapM_ (usingState_ ctx . setClientSNI) serverName
    return (chosenVersion, ch)
processClientHello _ _ _ =
    throwCore $
        Error_Protocol
            "unexpected handshake message received in handshakeServerWith"
            HandshakeFailure
-- | Pick the highest of the allowed versions that does not exceed the
-- version offered by the client.
--
-- Returns 'Nothing' when every allowed version is newer than
-- @clientVersion@ (or the allowed list is empty).
findHighestVersionFrom :: Version -> [Version] -> Maybe Version
findHighestVersionFrom clientVersion allowedVersions =
    case filter (<= clientVersion) allowedVersions of
        [] -> Nothing
        -- 'maximum' is safe here: the list is known to be non-empty.
        candidates -> Just (maximum candidates)
-- | Pick the highest server version that also appears in the client's
-- version list.
--
-- Only versions at or above TLS 1.2 are eligible. Returns 'Nothing' when
-- there is no such common version.
findHighestVersionFrom13 :: [Version] -> [Version] -> Maybe Version
findHighestVersionFrom13 clientVersions serverVersions =
    listToMaybe
        [ v
        | v <- sortOn Down serverVersions -- highest server version first
        , v >= TLS12
        , v `elem` clientVersions
        ]