{-|
Module      : PostgresWebsockets.Claims
Description : Parse and validate JWT to open postgres-websockets channels.

This module provides the JWT claims validation. Since websockets and
listening connections in the database tend to be resource intensive
(not to mention stateful) we need claims authorizing a specific channel and
mode of operation.
-}
module PostgresWebsockets.Claims
  ( ConnectionInfo,validateClaims
  ) where

import Protolude hiding (toS)
import Protolude.Conv
import Control.Lens
import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.HashMap.Strict as M
import qualified Data.Aeson as JSON


-- | JWT claims as a map from claim name to its JSON value.
type Claims = M.HashMap Text JSON.Value
-- | Result of a successful validation: (allowed channels, mode, claims).
type ConnectionInfo = ([Text], Text, Claims)

{-| Given an optional requested channel, a secret, a token and a timestamp,
    validate the token's claims. Returns either an error message or a triple
    of (allowed channels, mode, claims map).

    Rules, in the order they are applied:

    * A token that fails verification is rejected. An expired token gets the
      message @"Token expired"@; any other failure is reported as
      @"Error: "@ followed by the shown error.
    * The token's channels are built from the @channel@ claim (a string) and
      the @channels@ claim (a list of strings). When both are present they are
      merged with 'nub'.
    * The @mode@ claim (a string) is required (@"Missing mode"@ otherwise).
    * If a channel is requested and the token lists channels, only a matching
      listed channel is allowed. If the token lists no channels, the requested
      channel is accepted as-is. Without a request, all token channels are
      used.
    * An empty resulting channel list is rejected with @"No allowed channels"@.
-}
validateClaims
  :: Maybe Text
  -> ByteString
  -> LByteString
  -> UTCTime
  -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time = runExceptT $ do
  cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
  cl' <- case cl of
    JWTClaims c -> pure c
    JWTInvalid JWTExpired -> throwError "Token expired"
    JWTInvalid err -> throwError $ "Error: " <> show err
  let channels = tokenChannels cl'
  mode <- maybe (throwError "Missing mode") pure (claimAsJSON "mode" cl')
  let requestedAllowedChannels = case requestChannel of
        Just rc
          -- A token without channel claims does not restrict the request.
          | null channels -> [rc]
          | otherwise -> filter (== rc) channels
        Nothing -> channels
  validChannels <-
    if null requestedAllowedChannels
      then throwError "No allowed channels"
      else pure requestedAllowedChannels
  pure (validChannels, mode, cl')

 where
  -- | Channels granted by the token: the single @channel@ claim first,
  -- merged (without duplicates) with the @channels@ list when both exist.
  tokenChannels :: Claims -> [Text]
  tokenChannels claims =
    let chs = claimAsJSONList "channels" claims
    in case claimAsJSON "channel" claims of
         Just c -> maybe [c] (nub . (c :)) chs
         Nothing -> fromMaybe [] chs

  -- | Look up a claim that must be a JSON string; anything else is 'Nothing'.
  claimAsJSON :: Text -> Claims -> Maybe Text
  claimAsJSON name cl = case M.lookup name cl of
    Just (JSON.String s) -> Just s
    _ -> Nothing

  -- | Look up a claim that must decode as a JSON list of strings; a missing
  -- claim or a value of any other shape is 'Nothing'.
  claimAsJSONList :: Text -> Claims -> Maybe [Text]
  claimAsJSONList name cl = case JSON.fromJSON <$> M.lookup name cl of
    Just (JSON.Success channelsList) -> Just channelsList
    _ -> Nothing

{-|
  Outcome of decoding and verifying a client JWT: either the verification
  error reported by jose, or the token's claims as a map.
-}
data JWTAttempt = JWTInvalid JWTError
                | JWTClaims (M.HashMap Text JSON.Value)
                deriving Eq

{-|
  Receives the verification key (built from the configured secret), the
  current time and a compact-serialised JWT, and returns either the
  verification error or a map of the token's claims.

  An empty token is not an error here: it yields an empty claims map. In
  'validateClaims' this then fails with @"Missing mode"@.
-}
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = pure $ JWTClaims M.empty
jwtClaims time jwk' payload = do
  eJwt <- runExceptT $ do
    jwt <- decodeCompact payload
    verifyClaimsAt config jwk' time jwt
  pure $ either JWTInvalid (JWTClaims . claims2map) eJwt
 where
  -- The audience predicate accepts every audience value.
  config :: JWTValidationSettings
  config = defaultJWTValidationSettings (const True)

{-|
  Internal helper that turns a JWT 'ClaimsSet' into a plain map from claim
  name to JSON value, which is easier to work with.
-}
claims2map :: ClaimsSet -> M.HashMap Text JSON.Value
claims2map cs = case JSON.toJSON cs of
  JSON.Object o -> o
  -- NOTE(review): presumably a ClaimsSet always encodes to an object; any
  -- other JSON value falls back to an empty map.
  _ -> M.empty

{-|
  Internal helper to build an HMAC-SHA256 key. When the JWT key in the
  config file is a plain string rather than a JWK object, we apply this
  function to it. The raw bytes become symmetric ("oct") key material,
  marked for signature use with algorithm HS256.
-}
hs256jwk :: ByteString -> JWK
hs256jwk key =
  fromKeyMaterial km
    & jwkUse ?~ Sig
    & jwkAlg ?~ JWSAlg HS256
 where
  km :: KeyMaterial
  km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key))

{-|
  Interpret the configured secret. If it decodes as a JSON Web Key, it is
  used as-is. Otherwise the raw bytes are treated as an HS256 shared secret
  (see 'hs256jwk').
-}
parseJWK :: ByteString -> JWK
parseJWK str =
  fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK)