module PostgresWebsockets.Claims
( ConnectionInfo,validateClaims
) where
import Protolude hiding (toS)
import Protolude.Conv
import Control.Lens
import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.HashMap.Strict as M
import qualified Data.Aeson as JSON
-- | Every claim carried by a JWT, keyed by claim name.
type Claims = M.HashMap Text JSON.Value
-- | What a validated token authorizes: (allowed channels, mode, all claims).
type ConnectionInfo = ([Text], Text, Claims)
-- | Verify a client JWT and derive the connection it is allowed to open.
--
-- Returns either an error message or a triple of
-- (allowed channels, mode, all claims).
--
-- Errors are reported in this order:
--
-- * @"Token expired"@, or @"Error: ..."@ for any other JWT failure.
-- * @"Missing mode"@ when the @mode@ claim is absent or not a string.
--   An empty token yields no claims, so it ends up here.
-- * @"No allowed channels"@ when no channel remains after matching the
--   request against the token.
validateClaims
  :: Maybe Text
  -- ^ channel requested by the client, if any
  -> ByteString
  -- ^ JWT secret: either a JSON-encoded JWK or a raw key used for HS256
  -> LByteString
  -- ^ compact-serialized JWT
  -> UTCTime
  -- ^ time at which the claims are validated
  -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time = runExceptT $ do
  cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
  cl' <- case cl of
    JWTClaims c -> pure c
    JWTInvalid JWTExpired -> throwError "Token expired"
    JWTInvalid err -> throwError $ "Error: " <> show err
  -- Merge the single "channel" claim with the "channels" list claim.
  -- NOTE(review): a "channels" claim that is present but not an array of
  -- strings is treated like an absent one. If "channel" is also missing,
  -- such a token accepts any requested channel (see below). Confirm this
  -- is intended.
  let channels = case (claimAsJSON "channel" cl', claimAsJSONList "channels" cl') of
        (Just c, Just cs) -> nub (c : cs)
        (Just c, Nothing) -> [c]
        (Nothing, chs) -> fromMaybe [] chs
  mode <- maybe (throwError "Missing mode") pure (claimAsJSON "mode" cl')
  -- A token naming no channels accepts the requested channel as-is;
  -- otherwise the request must match one of the token's channels.
  let requestedAllowedChannels = case (requestChannel, channels) of
        (Just rc, []) -> [rc]
        (Just rc, _) -> filter (== rc) channels
        (Nothing, _) -> channels
  when (null requestedAllowedChannels) $ throwError "No allowed channels"
  pure (requestedAllowedChannels, mode, cl')
  where
    -- Look up a claim whose value must be a JSON string; anything else
    -- (including absence) yields Nothing.
    claimAsJSON :: Text -> Claims -> Maybe Text
    claimAsJSON name cl = case M.lookup name cl of
      Just (JSON.String s) -> Just s
      _ -> Nothing

    -- Look up a claim whose value must be a JSON array of strings; any
    -- other shape (or absence) yields Nothing.
    claimAsJSONList :: Text -> Claims -> Maybe [Text]
    claimAsJSONList name cl =
      M.lookup name cl >>= \channelsJson ->
        case JSON.fromJSON channelsJson :: JSON.Result [Text] of
          JSON.Success channelsList -> Just channelsList
          JSON.Error _ -> Nothing
-- | Result of decoding and verifying a client JWT.
data JWTAttempt
  = JWTInvalid JWTError
  -- ^ decoding or verification failed with the given error
  | JWTClaims Claims
  -- ^ the token verified; these are its claims
  deriving (Eq)
-- | Decode and verify a compact JWT with the given key at the given time.
--
-- An empty token is not an error: it yields an empty claim map.
-- Audience checking accepts every audience (@const True@).
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = pure (JWTClaims M.empty)
jwtClaims time jwk' payload =
  either JWTInvalid (JWTClaims . claims2map) <$> runExceptT verified
  where
    settings = defaultJWTValidationSettings (const True)
    -- Parse the compact serialization, then check signature and claims.
    verified = decodeCompact payload >>= verifyClaimsAt settings jwk' time
-- | Flatten a verified claims set into a map from claim name to JSON value.
-- If the claims do not serialize to a JSON object, the map is empty.
claims2map :: ClaimsSet -> M.HashMap Text JSON.Value
claims2map claimsSet = case JSON.toJSON claimsSet of
  JSON.Object fields -> fields
  _ -> M.empty
-- | Build a symmetric JWK from raw key bytes.
--
-- The key is marked for signature use and bound to the HS256 algorithm.
-- 'parseJWK' uses it when the configured secret is not a JSON JWK.
hs256jwk :: ByteString -> JWK
hs256jwk key = restrictToHS256 (markForSigning octetKey)
  where
    octetKey = fromKeyMaterial (OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)))
    markForSigning = jwkUse ?~ Sig
    restrictToHS256 = jwkAlg ?~ JWSAlg HS256
-- | Interpret the configured secret as a JSON-encoded JWK.
--
-- If the secret does not decode as one, it falls back to treating the
-- bytes as a raw HS256 key (see 'hs256jwk').
parseJWK :: ByteString -> JWK
parseJWK str = case JSON.decode (toS str) of
  Just parsed -> parsed
  Nothing -> hs256jwk str