{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.Server.TLS12 (
    recvClientSecondFlight12,
) where

import Control.Monad.State.Strict (gets)
import qualified Data.ByteString as B

import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Server.Common
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Packet hiding (getSession)
import Network.TLS.Parameters
import Network.TLS.Session
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.X509 hiding (Certificate)

----------------------------------------------------------------

-- | Process the client's second flight of a TLS 1.2 handshake and
-- finish the handshake on the server side.
--
-- * Full handshake (no resumed 'SessionData'): receive the client
--   messages with 'recvClientCCC', register the session, send a
--   @NewSessionTicket@ when the session manager handed back a ticket,
--   then send our own ChangeCipherSpec and Finished.
--
-- * Resumption (resumed 'SessionData' present): register the session
--   and wait for the client's ChangeCipherSpec and Finished.
--
-- In both cases 'handshakeDone12' is called last.
recvClientSecondFlight12
    :: ServerParams
    -> Context
    -> Maybe SessionData
    -> IO ()
recvClientSecondFlight12 sparams ctx resumeSessionData = do
    case resumeSessionData of
        Nothing -> do
            recvClientCCC sparams ctx
            mticket <- sessionEstablished ctx
            case mticket of
                Nothing -> return ()
                Just ticket -> do
                    let life = adjustLifetime $ serverTicketLifetime sparams
                    sendPacket12 ctx $ Handshake [NewSessionTicket life ticket]
            sendCCSandFinished ctx ServerRole
        Just _ -> do
            -- The ticket (if any) is deliberately ignored on resumption.
            _ <- sessionEstablished ctx
            recvCCSandFinished ctx
    handshakeDone12 ctx
  where
    -- Clamp the configured lifetime into [0, 604800] seconds
    -- (604800 s = 7 days) and convert it to the type expected by
    -- 'NewSessionTicket'.
    adjustLifetime i
        | i < 0 = 0
        | i > 604800 = 604800
        | otherwise = fromIntegral i

-- | Notify the shared session manager that a session has been
-- established and return whatever it hands back.
--
-- Returns 'Nothing' when the context has no session identifier, when
-- no session data is available, or when the session manager itself
-- returns 'Nothing'.
sessionEstablished :: Context -> IO (Maybe Ticket)
sessionEstablished ctx = do
    session <- usingState_ ctx getSession
    -- only callback the session established if we have a session
    case session of
        Session (Just sessionId) -> do
            msessionData <- getSessionData ctx
            case msessionData of
                Just sessionData ->
                    -- SessionID method: SessionID is used as key to store
                    -- SessionData. Nothing is returned.
                    --
                    -- Session ticket method: SessionID is ignored. SessionData
                    -- is encrypted and returned.
                    sessionEstablish
                        (sharedSessionManager $ ctxShared ctx)
                        -- The key is copied so that it has its own storage
                        -- (presumably to avoid retaining the buffer it was
                        -- sliced from -- confirm against bytestring docs).
                        (B.copy sessionId)
                        sessionData
                -- Without session data there is nothing to store or to
                -- encrypt into a ticket, so report "no ticket" instead
                -- of crashing on a partial 'fromJust'.
                Nothing -> return Nothing
        _ -> return Nothing -- never reach

----------------------------------------------------------------

-- | Receive the client's handshake messages up to and including the
-- Finished message.
--
-- Expected message order (bracketed messages are optional):
--
-- >      <- [certificate]
-- >      <- client key xchg
-- >      <- [cert verify]
-- >      <- change cipher
-- >      <- finish
--
-- The receive loop is driven by 'runRecvState'; each local handler
-- consumes one message and returns the handler for the next one.  Any
-- message arriving out of order is rejected with 'unexpected'.
recvClientCCC :: ServerParams -> Context -> IO ()
recvClientCCC sparams ctx = runRecvState ctx (RecvStateHandshake expectClientCertificate)
  where
    -- Optional first message.  When a Certificate is present, a
    -- CertificateVerify is expected after the key exchange; any other
    -- message is handed directly to the key-exchange handler.
    expectClientCertificate :: Handshake -> IO (RecvState IO)
    expectClientCertificate (Certificate certs) = do
        clientCertificate sparams ctx certs
        processCertificate ctx ServerRole certs

        -- FIXME: We should check whether the certificate
        -- matches our request and that we support
        -- verifying with that certificate.

        return $ RecvStateHandshake $ expectClientKeyExchange True
    expectClientCertificate p = expectClientKeyExchange False p

    -- cannot use RecvStateHandshake, as the next message could be a ChangeCipher,
    -- so we must process any packet, and in case of handshake call processHandshake manually.
    expectClientKeyExchange :: Bool -> Handshake -> IO (RecvState IO)
    expectClientKeyExchange followedCertVerify (ClientKeyXchg ckx) = do
        processClientKeyXchg ctx ckx
        if followedCertVerify
            then return $ RecvStateHandshake expectCertificateVerify
            else return $ RecvStatePacket $ expectChangeCipherSpec ctx
    expectClientKeyExchange _ p = unexpected (show p) (Just "client key exchange")

    -- Check the client's signature over the handshake transcript and let
    -- 'processClientCertVerify' decide whether the chain is accepted.
    expectCertificateVerify :: Handshake -> IO (RecvState IO)
    expectCertificateVerify (CertVerify dsig) = do
        certs <- checkValidClientCertChain ctx "change cipher message expected"

        usedVersion <- usingState_ ctx getVersion
        -- Fetch all handshake messages up to now.
        msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages

        pubKey <- usingHState ctx getRemotePublicKey
        checkDigitalSignatureKey usedVersion pubKey

        verif <- checkCertificateVerify ctx usedVersion pubKey msgs dsig
        processClientCertVerify sparams ctx certs verif
        return $ RecvStatePacket $ expectChangeCipherSpec ctx
    expectCertificateVerify p = unexpected (show p) (Just "client certificate verify")

----------------------------------------------------------------

-- | Receive-state handler that accepts only a ChangeCipherSpec packet.
--
-- On success the next expected message is the peer's Finished
-- handshake ('expectFinished'); any other packet is rejected with
-- 'unexpected'.
expectChangeCipherSpec :: Context -> Packet -> IO (RecvState IO)
expectChangeCipherSpec ctx ChangeCipherSpec =
    return $ RecvStateHandshake $ expectFinished ctx
expectChangeCipherSpec _ p = unexpected (show p) (Just "change cipher")

----------------------------------------------------------------

-- | Process the ClientKeyExchange message and derive the main secret,
-- which is then passed to 'logKey'.
--
-- RSA: the protocol expects the pre-main secret to carry the initial
-- client version received in ClientHello, not the negotiated version.
-- If decryption fails, the decrypted value cannot be decoded, or the
-- embedded version does not match, a random 48-byte value is used as
-- the pre-main secret instead of reporting an error.
--
-- DH / ECDH: the client's public value is validated (DH) or decoded
-- (ECDH) and combined with the server's private key; failures are
-- reported as an 'Error_Protocol' with 'IllegalParameter'.
processClientKeyXchg :: Context -> ClientKeyXchgAlgorithmData -> IO ()
processClientKeyXchg ctx (CKX_RSA encryptedPreMain) = do
    -- The random replacement is generated unconditionally, before the
    -- outcome of the decryption is known.
    (rver, role, random) <-
        usingState_ ctx $
            (,,) <$> getVersion <*> getRole <*> genRandom 48
    ePreMain <- decryptRSA ctx encryptedPreMain
    mainSecret <- usingHState ctx $ do
        expectedVer <- gets hstClientVersion
        -- Keep the decrypted pre-main secret only when every check
        -- passes; fall back to the random value on any failure.
        let chosen = case ePreMain of
                Right preMain
                    | Right (ver, _) <- decodePreMainSecret preMain
                    , ver == expectedVer ->
                        preMain
                _ -> random
        setMainSecretFromPre rver role chosen
    logKey ctx (MainSecret mainSecret)
processClientKeyXchg ctx (CKX_DH clientDHValue) = do
    rver <- usingState_ ctx getVersion
    role <- usingState_ ctx getRole

    serverParams <- usingHState ctx getServerDHParams
    let params = serverDHParamsToParams serverParams
    -- Reject public values that are not valid for our DH parameters.
    unless (dhValid params $ dhUnwrapPublic clientDHValue) $
        throwCore $
            Error_Protocol "invalid client public key" IllegalParameter

    dhpriv <- usingHState ctx getDHPrivate
    let preMain = dhGetShared params dhpriv clientDHValue
    mainSecret <- usingHState ctx $ setMainSecretFromPre rver role preMain
    logKey ctx (MainSecret mainSecret)
processClientKeyXchg ctx (CKX_ECDH bytes) = do
    ServerECDHParams grp _ <- usingHState ctx getServerECDHParams
    case decodeGroupPublic grp bytes of
        Left _ ->
            throwCore $
                Error_Protocol "client public key cannot be decoded" IllegalParameter
        Right clipub -> do
            srvpri <- usingHState ctx getGroupPrivate
            case groupGetShared clipub srvpri of
                Just preMain -> do
                    rver <- usingState_ ctx getVersion
                    role <- usingState_ ctx getRole
                    mainSecret <- usingHState ctx $ setMainSecretFromPre rver role preMain
                    logKey ctx (MainSecret mainSecret)
                Nothing ->
                    throwCore $
                        Error_Protocol "cannot generate a shared secret on ECDH" IllegalParameter

----------------------------------------------------------------

-- | Act on the outcome of the client's CertificateVerify check.
--
-- The final 'Bool' is the verification result.  When it is 'True' the
-- client certificate chain is committed to the context.  When it is
-- 'False' the application's 'onUnverifiedClientCert' hook is asked
-- whether to proceed: if it accepts, the chain is committed as well,
-- otherwise the handshake is aborted through 'decryptError'.
processClientCertVerify
    :: ServerParams -> Context -> CertificateChain -> Bool -> IO ()
processClientCertVerify _sparams ctx certs True =
    -- When verification succeeds, commit the
    -- client certificate chain to the context.
    usingState_ ctx $ setClientCertificateChain certs
processClientCertVerify sparams ctx certs False = do
    -- Either verification failed because of an
    -- invalid format (with an error message), or
    -- the signature is wrong.  In either case,
    -- ask the application whether it wants to
    -- proceed, and do as it says.
    res <- onUnverifiedClientCert (serverHooks sparams)
    if res
        then
            -- When verification fails, but the
            -- application callback accepts, we
            -- also commit the client certificate
            -- chain to the context.
            usingState_ ctx $ setClientCertificateChain certs
        else decryptError "verification failed"

----------------------------------------------------------------

-- | Receive the peer's ChangeCipherSpec followed by its Finished
-- message.
--
-- Runs the receive loop starting at 'expectChangeCipherSpec', which in
-- turn expects the Finished handshake next.
recvCCSandFinished :: Context -> IO ()
recvCCSandFinished ctx =
    runRecvState ctx $ RecvStatePacket $ expectChangeCipherSpec ctx