From 344c537446cc950cc7a3b5211fb33d39f66de090 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Thu, 12 Dec 2024 21:56:08 -0500 Subject: [PATCH 01/11] Thighten hasql-notifications dependency --- postgres-websockets.cabal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 4a5f2ff..a0e59a4 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -43,7 +43,7 @@ library , either >= 5.0.1.1 && < 5.1 , envparse >= 0.5.0 && < 0.6 , hasql ^>= 1.7 - , hasql-notifications >= 0.1.0.0 && < 0.3 + , hasql-notifications >= 0.2.3.0 && < 0.3 , hasql-pool ^>= 1.2 , http-types >= 0.12.3 && < 0.13 , jose >= 0.11 && < 0.12 @@ -95,7 +95,7 @@ test-suite postgres-websockets-test , aeson >= 2.0 && < 2.3 , hasql ^>= 1.7 , hasql-pool ^>= 1.2 - , hasql-notifications >= 0.1.0.0 && < 0.3 + , hasql-notifications >= 0.2.3.0 && < 0.3 , http-types >= 0.9 , time >= 1.8.0.2 && < 1.13 , unordered-containers >= 0.2 From 7600c89096afc80055854fe9519e2dbd7b318593 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Wed, 19 Mar 2025 14:28:41 -0400 Subject: [PATCH 02/11] Update actions/cache for our CI actions are erroring and it seems the version 2 is no longer supported. --- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 60c3b54..d0feba1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,7 +37,7 @@ jobs: with: ghc-version: ${{ matrix.ghc }} - - uses: actions/cache@v2.1.3 + - uses: actions/cache@v4 name: Cache ~/dist-newstyle with: path: ~/dist-newstyle diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f1efe0f..821bcdd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -37,7 +37,7 @@ jobs: with: ghc-version: ${{ matrix.ghc }} - - uses: actions/cache@v2.1.3 + - uses: actions/cache@v4 name: Cache ~/dist-newstyle with: path: ~/dist-newstyle From 586b55dd7787cda8dacb144caab06d4b584d306f Mon Sep 17 00:00:00 2001 From: Wolfgang Walther Date: Wed, 19 Mar 2025 13:56:38 +0100 Subject: [PATCH 03/11] Relax dependencies --- postgres-websockets.cabal | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index a0e59a4..908c60b 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -37,18 +37,18 @@ library build-depends: base >= 4.7 && < 5 , aeson >= 2.0 && < 2.3 , alarmclock >= 0.7.0.2 && < 0.8 - , auto-update >= 0.1.6 && < 0.2 + , auto-update >= 0.1.6 && < 0.3 , base64-bytestring >= 1.0.0.3 && < 1.3 - , bytestring >= 0.11.5 && < 0.12 + , bytestring >= 0.11.5 && < 0.13 , either >= 5.0.1.1 && < 5.1 - , envparse >= 0.5.0 && < 0.6 - , hasql ^>= 1.7 + , envparse >= 0.5.0 && < 0.7 + , hasql >= 1.7 && < 1.9 , hasql-notifications >= 0.2.3.0 && < 0.3 , hasql-pool ^>= 1.2 , http-types >= 0.12.3 && < 0.13 , jose >= 0.11 && < 0.12 - , lens >= 5.2.3 && < 5.3 - , postgresql-libpq >= 0.10.0 && < 0.11 + , lens >= 5.2.3 && < 5.4 + , postgresql-libpq >= 0.10.0 && < 0.12 , protolude >= 0.2.3 && < 0.4 , retry >= 0.8.1.0 && < 0.10 , stm >= 2.5.0.0 && < 2.6 @@ -62,7 +62,7 @@ library , wai-websockets >= 3.0 && < 4 , warp >= 3.2 && < 4 , warp-tls >= 3.2 && < 4 - , websockets >= 0.9 && < 0.13 + , websockets >= 0.9 && < 0.14 default-language: Haskell2010 @@ -93,7 +93,7 @@ test-suite postgres-websockets-test , postgres-websockets , hspec >= 2.7.1 && < 2.12 , aeson >= 2.0 && < 2.3 - , hasql ^>= 1.7 + , hasql >= 1.7 && < 1.9 , hasql-pool ^>= 1.2 , hasql-notifications >= 0.2.3.0 && < 0.3 , http-types >= 0.9 @@ -101,9 +101,9 @@ test-suite postgres-websockets-test , unordered-containers >= 0.2 , wai-extra >= 3.0.29 && < 3.2 , stm >= 2.5.0.0 && < 2.6 - , websockets >= 0.12.7.0 && < 0.13 - , network >= 2.8.0.1 && < 3.2 - , lens >= 4.17.1 && < 5.3 + , websockets >= 0.12.7.0 && < 0.14 + , network >= 2.8.0.1 && < 3.3 + , lens >= 4.17.1 && < 5.4 , lens-aeson >= 1.0.0 && < 1.3 ghc-options: -threaded -rtsopts -with-rtsopts=-N default-language: Haskell2010 From 72889ce261309e4bc0a3633f3922dc5313e2d31b Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Wed, 19 Mar 2025 15:15:22 -0400 Subject: [PATCH 04/11] Bump patch version number since we have new dependencies --- postgres-websockets.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 908c60b..7cde957 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -1,6 +1,6 @@ cabal-version: 3.0 name: postgres-websockets -version: 0.11.2.2 +version: 0.11.2.3 synopsis: Middleware to map LISTEN/NOTIFY messages to Websockets description: WAI middleware that adds websockets capabilites on top of PostgreSQL's asynchronous notifications using LISTEN and NOTIFY commands. Fully functioning server included. homepage: https://github.com/diogob/postgres-websockets#readme From 2bdcaa8227e7679583b10cf5c622f8908e65563d Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Wed, 19 Mar 2025 15:20:28 -0400 Subject: [PATCH 05/11] Hackage will not accept the package without proper description of autogen-modules --- postgres-websockets.cabal | 1 + 1 file changed, 1 insertion(+) diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 7cde957..95868fa 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -30,6 +30,7 @@ library , PostgresWebsockets.Claims , PostgresWebsockets.Config + autogen-modules: Paths_postgres_websockets other-modules: Paths_postgres_websockets , PostgresWebsockets.Server , PostgresWebsockets.Middleware From 3bb1028c289e8fa19ffe1027839aecaab8128744 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Fri, 18 Apr 2025 16:46:11 -0400 Subject: [PATCH 06/11] Remove protolude since the library seems a bit stale and we can replace the interesting bits with a simpler module. --- Setup.hs | 1 + app/Main.hs | 14 +-- postgres-websockets.cabal | 8 +- src/APrelude.hs | 104 ++++++++++++++++ src/PostgresWebsockets.hs | 28 ++--- src/PostgresWebsockets/Broadcast.hs | 9 +- src/PostgresWebsockets/Claims.hs | 148 +++++++++++------------ src/PostgresWebsockets/Config.hs | 16 +-- src/PostgresWebsockets/Context.hs | 9 +- src/PostgresWebsockets/HasqlBroadcast.hs | 27 ++--- src/PostgresWebsockets/Middleware.hs | 135 ++++++++++----------- src/PostgresWebsockets/Server.hs | 8 +- test/BroadcastSpec.hs | 27 +++-- test/ClaimsSpec.hs | 69 +++++++---- test/HasqlBroadcastSpec.hs | 28 ++--- test/ServerSpec.hs | 8 +- 16 files changed, 380 insertions(+), 259 deletions(-) create mode 100644 src/APrelude.hs diff --git a/Setup.hs b/Setup.hs index 9a994af..e8ef27d 100644 --- a/Setup.hs +++ b/Setup.hs @@ -1,2 +1,3 @@ import Distribution.Simple + main = defaultMain diff --git a/app/Main.hs b/app/Main.hs index c831e20..fcc41bd 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -1,19 +1,19 @@ module Main where -import Protolude +import APrelude import PostgresWebsockets - import System.IO (BufferMode (..), hSetBuffering) main :: IO () main = do hSetBuffering stdout LineBuffering - hSetBuffering stdin LineBuffering + hSetBuffering stdin LineBuffering hSetBuffering stderr NoBuffering - putStrLn $ ("postgres-websockets " :: Text) - <> prettyVersion - <> " / Connects websockets to PostgreSQL asynchronous notifications." + putStrLn $ + ("postgres-websockets ") + <> (unpack prettyVersion) + <> " / Connects websockets to PostgreSQL asynchronous notifications." conf <- loadConfig - void $ serve conf \ No newline at end of file + void $ serve conf diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 95868fa..f77b2a2 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -18,7 +18,7 @@ common warnings common language default-language: Haskell2010 - default-extensions: OverloadedStrings, NoImplicitPrelude, LambdaCase, RecordWildCards, QuasiQuotes + default-extensions: OverloadedStrings, LambdaCase, RecordWildCards, QuasiQuotes library import: warnings @@ -29,6 +29,7 @@ library , PostgresWebsockets.HasqlBroadcast , PostgresWebsockets.Claims , PostgresWebsockets.Config + , APrelude autogen-modules: Paths_postgres_websockets other-modules: Paths_postgres_websockets @@ -49,8 +50,9 @@ library , http-types >= 0.12.3 && < 0.13 , jose >= 0.11 && < 0.12 , lens >= 5.2.3 && < 5.4 + , mtl + , async , postgresql-libpq >= 0.10.0 && < 0.12 - , protolude >= 0.2.3 && < 0.4 , retry >= 0.8.1.0 && < 0.10 , stm >= 2.5.0.0 && < 2.6 , stm-containers >= 1.1.0.2 && < 1.3 @@ -76,7 +78,6 @@ executable postgres-websockets ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: base >= 4.7 && < 5 , postgres-websockets - , protolude >= 0.2.3 && < 0.4 default-language: Haskell2010 test-suite postgres-websockets-test @@ -90,7 +91,6 @@ test-suite postgres-websockets-test , HasqlBroadcastSpec , ServerSpec build-depends: base - , protolude >= 0.2.3 && < 0.4 , postgres-websockets , hspec >= 2.7.1 && < 2.12 , aeson >= 2.0 && < 2.3 diff --git a/src/APrelude.hs b/src/APrelude.hs new file mode 100644 index 0000000..7c29c83 --- /dev/null +++ b/src/APrelude.hs @@ -0,0 +1,104 @@ +module APrelude + ( Text, + ByteString, + LByteString, + Generic, + fromMaybe, + putErrLn, + fromRight, + isJust, + decodeUtf8, + encodeUtf8, + MVar, + readMVar, + swapMVar, + newMVar, + STM, + atomically, + ThreadId, + forkFinally, + forkIO, + killThread, + threadDelay, + (>=>), + when, + forever, + void, + panic, + SomeException, + throwError, + liftIO, + runExceptT, + unpack, + pack, + showText, + showBS, + LBS.fromStrict, + stdin, + stdout, + stderr, + hPutStrLn, + Word16, + forM, + forM_, + takeMVar, + newEmptyMVar, + wait, + headDef, + tailSafe, + withAsync, + putMVar, + die, + myThreadId, + replicateM, + bracket, + ) +where + +import Control.Concurrent (ThreadId, forkFinally, forkIO, killThread, myThreadId, threadDelay) +import Control.Concurrent.Async (wait, withAsync) +import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar, putMVar, readMVar, swapMVar, takeMVar) +import Control.Concurrent.STM (STM, atomically) +import Control.Exception (Exception, SomeException, bracket, throw) +import Control.Monad (forM, forM_, forever, replicateM, void, when, (>=>)) +import Control.Monad.Error.Class (throwError) +import Control.Monad.Except (runExceptT) +import Control.Monad.IO.Class (liftIO) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS +import qualified Data.ByteString.Lazy as LBS +import Data.Either (fromRight) +import Data.Maybe (fromMaybe, isJust, listToMaybe) +import Data.Text (Text, pack, unpack) +import qualified Data.Text as T +import Data.Text.Encoding +import Data.Word (Word16) +import GHC.Generics (Generic) +import System.Exit (die) +import System.IO (hPutStrLn, stderr, stdin, stdout) + +showBS :: (Show a) => a -> BS.ByteString +showBS = BS.pack . show + +showText :: (Show a) => a -> Text +showText = T.pack . show + +type LByteString = LBS.ByteString + +-- | Uncatchable exceptions thrown and never caught. +newtype FatalError = FatalError {fatalErrorMessage :: Text} + deriving (Show) + +instance Exception FatalError + +panic :: Text -> a +panic a = throw (FatalError a) + +putErrLn :: Text -> IO () +putErrLn = hPutStrLn stderr . unpack + +headDef :: a -> [a] -> a +headDef def = fromMaybe def . listToMaybe + +tailSafe :: [a] -> [a] +tailSafe = drop 1 diff --git a/src/PostgresWebsockets.hs b/src/PostgresWebsockets.hs index f42e887..26a6306 100644 --- a/src/PostgresWebsockets.hs +++ b/src/PostgresWebsockets.hs @@ -1,16 +1,16 @@ -{-| -Module : PostgresWebsockets -Description : PostgresWebsockets main library interface. - -These are all function necessary to configure and start the server. --} +-- | +-- Module : PostgresWebsockets +-- Description : PostgresWebsockets main library interface. +-- +-- These are all function necessary to configure and start the server. module PostgresWebsockets - ( prettyVersion - , loadConfig - , serve - , postgresWsMiddleware - ) where + ( prettyVersion, + loadConfig, + serve, + postgresWsMiddleware, + ) +where -import PostgresWebsockets.Middleware ( postgresWsMiddleware ) -import PostgresWebsockets.Server ( serve ) -import PostgresWebsockets.Config ( prettyVersion, loadConfig ) +import PostgresWebsockets.Config (loadConfig, prettyVersion) +import PostgresWebsockets.Middleware (postgresWsMiddleware) +import PostgresWebsockets.Server (serve) diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index 8fac2db..99262a8 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -26,11 +26,10 @@ module PostgresWebsockets.Broadcast ) where +import APrelude import Control.Concurrent.STM.TChan import Control.Concurrent.STM.TQueue import qualified Data.Aeson as A -import Protolude hiding (toS) -import Protolude.Conv (toS) import qualified StmContainers.Map as M data Message = Message @@ -63,7 +62,7 @@ instance A.ToJSON MultiplexerSnapshot -- | Given a multiplexer derive a type that can be printed for debugging or logging purposes takeSnapshot :: Multiplexer -> IO MultiplexerSnapshot takeSnapshot multi = - MultiplexerSnapshot <$> size <*> e <*> thread + MultiplexerSnapshot <$> size <*> e <*> (pack <$> thread) where size = atomically $ M.size $ channels multi thread = show <$> readMVar (producerThreadId multi) @@ -113,7 +112,7 @@ superviseMultiplexer multi msInterval shouldRestart = do new <- reopenProducer multi void $ swapMVar (producerThreadId multi) new snapAfter <- takeSnapshot multi - putStrLn $ + print $ "Restarting producer. Multiplexer updated: " <> A.encode snapBefore <> " -> " @@ -142,7 +141,7 @@ onMessage multi chan action = do where disposeListener _ = atomically $ do mC <- M.lookup chan (channels multi) - let c = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> toS chan) mC + let c = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> chan) mC M.delete chan (channels multi) when (listeners c - 1 > 0) $ M.insert Channel {broadcast = broadcast c, listeners = listeners c - 1} chan (channels multi) diff --git a/src/PostgresWebsockets/Claims.hs b/src/PostgresWebsockets/Claims.hs index 9310c07..cd9fee5 100644 --- a/src/PostgresWebsockets/Claims.hs +++ b/src/PostgresWebsockets/Claims.hs @@ -1,54 +1,56 @@ -{-| -Module : PostgresWebsockets.Claims -Description : Parse and validate JWT to open postgres-websockets channels. - -This module provides the JWT claims validation. Since websockets and -listening connections in the database tend to be resource intensive -(not to mention stateful) we need claims authorizing a specific channel and -mode of operation. --} +-- | +-- Module : PostgresWebsockets.Claims +-- Description : Parse and validate JWT to open postgres-websockets channels. +-- +-- This module provides the JWT claims validation. Since websockets and +-- listening connections in the database tend to be resource intensive +-- (not to mention stateful) we need claims authorizing a specific channel and +-- mode of operation. module PostgresWebsockets.Claims - ( ConnectionInfo,validateClaims - ) where + ( ConnectionInfo, + validateClaims, + ) +where -import Protolude hiding (toS) -import Protolude.Conv +import APrelude import Control.Lens -import Crypto.JWT -import Data.List -import Data.Time.Clock (UTCTime) import qualified Crypto.JOSE.Types as JOSE.Types +import Crypto.JWT import qualified Data.Aeson as JSON -import qualified Data.Aeson.KeyMap as JSON import qualified Data.Aeson.Key as Key +import qualified Data.Aeson.KeyMap as JSON +import Data.List +import Data.Time.Clock (UTCTime) type Claims = JSON.KeyMap JSON.Value + type ConnectionInfo = ([Text], Text, Claims) -{-| Given a secret, a token and a timestamp it validates the claims and returns - either an error message or a triple containing channel, mode and claims KeyMap. --} -validateClaims - :: Maybe Text - -> ByteString - -> LByteString - -> UTCTime - -> IO (Either Text ConnectionInfo) +-- | Given a secret, a token and a timestamp it validates the claims and returns +-- either an error message or a triple containing channel, mode and claims KeyMap. +validateClaims :: + Maybe Text -> + ByteString -> + LByteString -> + UTCTime -> + IO (Either Text ConnectionInfo) validateClaims requestChannel secret jwtToken time = runExceptT $ do - cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken + cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken cl' <- case cl of - JWTClaims c -> pure c + JWTClaims c -> pure c JWTInvalid JWTExpired -> throwError "Token expired" - JWTInvalid err -> throwError $ "Error: " <> show err - channels <- let chs = claimAsJSONList "channels" cl' in pure $ case claimAsJSON "channel" cl' of - Just c -> case chs of - Just cs -> nub (c : cs) - Nothing -> [c] - Nothing -> fromMaybe [] chs + JWTInvalid err -> throwError $ "Error: " <> showText err + channels <- + let chs = claimAsJSONList "channels" cl' + in pure $ case claimAsJSON "channel" cl' of + Just c -> case chs of + Just cs -> nub (c : cs) + Nothing -> [c] + Nothing -> fromMaybe [] chs mode <- let md = claimAsJSON "mode" cl' - in case md of - Just m -> pure m + in case md of + Just m -> pure m Nothing -> throwError "Missing mode" requestedAllowedChannels <- case (requestChannel, length channels) of (Just rc, 0) -> pure [rc] @@ -56,32 +58,30 @@ validateClaims requestChannel secret jwtToken time = runExceptT $ do (Nothing, _) -> pure channels validChannels <- if null requestedAllowedChannels then throwError "No allowed channels" else pure requestedAllowedChannels pure (validChannels, mode, cl') + where + claimAsJSON :: Text -> Claims -> Maybe Text + claimAsJSON name cl = case JSON.lookup (Key.fromText name) cl of + Just (JSON.String s) -> Just s + _ -> Nothing - where - claimAsJSON :: Text -> Claims -> Maybe Text - claimAsJSON name cl = case JSON.lookup (Key.fromText name) cl of - Just (JSON.String s) -> Just s - _ -> Nothing - - claimAsJSONList :: Text -> Claims -> Maybe [Text] - claimAsJSONList name cl = case JSON.lookup (Key.fromText name) cl of - Just channelsJson -> - case JSON.fromJSON channelsJson :: JSON.Result [Text] of - JSON.Success channelsList -> Just channelsList - _ -> Nothing - Nothing -> Nothing + claimAsJSONList :: Text -> Claims -> Maybe [Text] + claimAsJSONList name cl = case JSON.lookup (Key.fromText name) cl of + Just channelsJson -> + case JSON.fromJSON channelsJson :: JSON.Result [Text] of + JSON.Success channelsList -> Just channelsList + _ -> Nothing + Nothing -> Nothing -{-| - Possible situations encountered with client JWTs --} -data JWTAttempt = JWTInvalid JWTError - | JWTClaims (JSON.KeyMap JSON.Value) - deriving Eq +-- | +-- Possible situations encountered with client JWTs +data JWTAttempt + = JWTInvalid JWTError + | JWTClaims (JSON.KeyMap JSON.Value) + deriving (Eq) -{-| - Receives the JWT secret (from config) and a JWT and returns a map - of JWT claims. --} +-- | +-- Receives the JWT secret (from config) and a JWT and returns a map +-- of JWT claims. jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt jwtClaims _ _ "" = return $ JWTClaims JSON.empty jwtClaims time jwk' payload = do @@ -90,32 +90,30 @@ jwtClaims time jwk' payload = do jwt <- decodeCompact payload verifyClaimsAt config jwk' time jwt return $ case eJwt of - Left e -> JWTInvalid e + Left e -> JWTInvalid e Right jwt -> JWTClaims . claims2map $ jwt -{-| - Internal helper used to turn JWT ClaimSet into something - easier to work with --} +-- | +-- Internal helper used to turn JWT ClaimSet into something +-- easier to work with claims2map :: ClaimsSet -> JSON.KeyMap JSON.Value claims2map = val2map . JSON.toJSON - where - val2map (JSON.Object o) = o - val2map _ = JSON.empty + where + val2map (JSON.Object o) = o + val2map _ = JSON.empty -{-| - Internal helper to generate HMAC-SHA256. When the jwt key in the - config file is a simple string rather than a JWK object, we'll - apply this function to it. --} +-- | +-- Internal helper to generate HMAC-SHA256. When the jwt key in the +-- config file is a simple string rather than a JWK object, we'll +-- apply this function to it. hs256jwk :: ByteString -> JWK hs256jwk key = fromKeyMaterial km & jwkUse ?~ Sig & jwkAlg ?~ JWSAlg HS256 - where - km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)) + where + km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)) parseJWK :: ByteString -> JWK parseJWK str = - fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK) + fromMaybe (hs256jwk str) (JSON.decode (fromStrict str) :: Maybe JWK) diff --git a/src/PostgresWebsockets/Config.hs b/src/PostgresWebsockets/Config.hs index 3436845..7215e56 100644 --- a/src/PostgresWebsockets/Config.hs +++ b/src/PostgresWebsockets/Config.hs @@ -14,6 +14,7 @@ module PostgresWebsockets.Config ) where +import APrelude import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import Data.String (IsString (..)) @@ -22,8 +23,6 @@ import Data.Version (versionBranch) import Env import Network.Wai.Handler.Warp import Paths_postgres_websockets (version) -import Protolude hiding (intercalate, optional, replace, toS, (<>)) -import Protolude.Conv -- | Config file settings for the server data AppConfig = AppConfig @@ -45,7 +44,7 @@ data AppConfig = AppConfig -- | User friendly version number prettyVersion :: Text -prettyVersion = intercalate "." $ map show $ versionBranch version +prettyVersion = intercalate "." $ map showText $ versionBranch version -- | Load all postgres-websockets config from Environment variables. This can be used to use just the middleware or to feed into warpSettings loadConfig :: IO AppConfig @@ -58,9 +57,9 @@ loadConfig = -- | Given a shutdown handler and an AppConfig builds a Warp Settings to start a stand-alone server warpSettings :: (IO () -> IO ()) -> AppConfig -> Settings warpSettings waitForShutdown AppConfig {..} = - setHost (fromString $ toS configHost) + setHost (fromString $ unpack configHost) . setPort configPort - . setServerName (toS $ "postgres-websockets/" <> prettyVersion) + . setServerName ("postgres-websockets/" <> encodeUtf8 prettyVersion) . setTimeout 3600 . setInstallShutdownHandler waitForShutdown . setGracefulShutdownTimeout (Just 5) @@ -72,7 +71,8 @@ warpSettings waitForShutdown AppConfig {..} = readOptions :: IO AppConfig readOptions = Env.parse (header "You need to configure some environment variables to start the service.") $ - AppConfig <$> var (str <=< nonempty) "PGWS_DB_URI" (help "String to connect to PostgreSQL") + AppConfig + <$> var (str <=< nonempty) "PGWS_DB_URI" (help "String to connect to PostgreSQL") <*> optional (var str "PGWS_ROOT_PATH" (help "Root path to serve static files, unset to disable.")) <*> var str "PGWS_HOST" (def "*4" <> helpDef show <> help "Address the server will listen for websocket connections") <*> var auto "PGWS_PORT" (def 3000 <> helpDef show <> help "Port the server will listen for websocket connections") @@ -96,7 +96,7 @@ loadDatabaseURIFile :: AppConfig -> IO AppConfig loadDatabaseURIFile conf@AppConfig {..} = case stripPrefix "@" configDatabase of Nothing -> pure conf - Just filename -> setDatabase . strip <$> readFile (toS filename) + Just filename -> setDatabase . strip . pack <$> readFile (unpack filename) where setDatabase uri = conf {configDatabase = uri} @@ -112,7 +112,7 @@ loadSecretFile conf = extractAndTransform secret transformString isB64 =<< case stripPrefix "@" s of Nothing -> return . encodeUtf8 $ s - Just filename -> chomp <$> BS.readFile (toS filename) + Just filename -> chomp <$> BS.readFile (unpack filename) where chomp bs = fromMaybe bs (BS.stripSuffix "\n" bs) diff --git a/src/PostgresWebsockets/Context.hs b/src/PostgresWebsockets/Context.hs index a944f14..86e25ca 100644 --- a/src/PostgresWebsockets/Context.hs +++ b/src/PostgresWebsockets/Context.hs @@ -7,6 +7,7 @@ module PostgresWebsockets.Context ) where +import APrelude import Control.AutoUpdate ( defaultUpdateSettings, mkAutoUpdate, @@ -18,8 +19,6 @@ import qualified Hasql.Pool.Config as P import PostgresWebsockets.Broadcast (Multiplexer) import PostgresWebsockets.Config (AppConfig (..)) import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster) -import Protolude hiding (toS) -import Protolude.Conv data Context = Context { ctxConfig :: AppConfig, @@ -33,15 +32,15 @@ mkContext :: AppConfig -> IO () -> IO Context mkContext conf@AppConfig {..} shutdownServer = do Context conf <$> P.acquire config - <*> newHasqlBroadcaster shutdown (toS configListenChannel) configRetries configReconnectInterval pgSettings + <*> newHasqlBroadcaster shutdown configListenChannel configRetries configReconnectInterval pgSettings <*> mkGetTime where config = P.settings [P.staticConnectionSettings pgSettings] shutdown = maybe shutdownServer - (const $ putText "Producer thread is dead") + (const $ putStrLn "Producer thread is dead") configReconnectInterval mkGetTime :: IO (IO UTCTime) mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime} - pgSettings = toS configDatabase + pgSettings = encodeUtf8 configDatabase diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs index c622aff..9bb39a2 100644 --- a/src/PostgresWebsockets/HasqlBroadcast.hs +++ b/src/PostgresWebsockets/HasqlBroadcast.hs @@ -15,11 +15,11 @@ module PostgresWebsockets.HasqlBroadcast ) where +import APrelude import Control.Retry (RetryStatus (..), capDelay, exponentialBackoff, retrying) import Data.Aeson (Value (..), decode) -import qualified Data.Aeson.KeyMap as JSON import qualified Data.Aeson.Key as Key - +import qualified Data.Aeson.KeyMap as JSON import Data.Either.Combinators (mapBoth) import Data.Function (id) import GHC.Show @@ -30,8 +30,6 @@ import Hasql.Notifications import qualified Hasql.Session as H import qualified Hasql.Statement as H import PostgresWebsockets.Broadcast -import Protolude hiding (putErrLn, show, toS) -import Protolude.Conv -- | Returns a multiplexer from a connection URI, keeps trying to connect in case there is any error. -- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners @@ -44,7 +42,7 @@ newHasqlBroadcaster onConnectionFailure ch maxRetries checkInterval = newHasqlBr -- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners newHasqlBroadcasterOrError :: IO () -> Text -> ByteString -> IO (Either ByteString Multiplexer) newHasqlBroadcasterOrError onConnectionFailure ch = - acquire >=> (sequence . mapBoth (toSL . show) (newHasqlBroadcasterForConnection . return)) + acquire >=> (sequence . mapBoth showBS (newHasqlBroadcasterForConnection . return)) where newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch Nothing @@ -60,7 +58,7 @@ tryUntilConnected maxRetries = shouldRetry RetryStatus {..} con = case con of Left err -> do - putErrLn $ "Error connecting notification listener to database: " <> (toS . show) err + putErrLn $ "Error connecting notification listener to database: " <> showText err pure $ rsIterNumber < maxRetries - 1 _ -> return False @@ -94,16 +92,16 @@ newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do return multi where toMsg :: Text -> Text -> Message - toMsg c m = case decode (toS m) of + toMsg c m = case decode (fromStrict $ encodeUtf8 m) of Just v -> Message (channelDef c v) m Nothing -> Message c m lookupStringDef :: Text -> Text -> Value -> Text lookupStringDef key d (Object obj) = - case lookupDefault (String $ toS d) key obj of - String s -> toS s - _ -> toS d - lookupStringDef _ d _ = toS d + case lookupDefault (String d) key obj of + String s -> s + _ -> d + lookupStringDef _ d _ = d lookupDefault d key obj = fromMaybe d $ JSON.lookup (Key.fromText key) obj @@ -116,12 +114,9 @@ newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do con <- getCon listen con $ toPgIdentifier ch waitForNotifications - (\c m -> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m)) + (\c m -> atomically $ writeTQueue msgQ $ toMsg (decodeUtf8 c) (decodeUtf8 m)) con -putErrLn :: Text -> IO () -putErrLn = hPutStrLn stderr - isListening :: Connection -> Text -> IO Bool isListening con ch = do resultOrError <- H.run session con @@ -136,4 +131,4 @@ isListeningStatement = where sql = "select exists (select * from pg_stat_activity where datname = current_database() and query ilike $1);" encoder = HE.param $ HE.nonNullable HE.text - decoder = HD.singleRow (HD.column (HD.nonNullable HD.bool)) \ No newline at end of file + decoder = HD.singleRow (HD.column (HD.nonNullable HD.bool)) diff --git a/src/PostgresWebsockets/Middleware.hs b/src/PostgresWebsockets/Middleware.hs index 7c0108b..e0e06d9 100644 --- a/src/PostgresWebsockets/Middleware.hs +++ b/src/PostgresWebsockets/Middleware.hs @@ -1,53 +1,50 @@ -{-| -Module : PostgresWebsockets.Middleware -Description : PostgresWebsockets WAI middleware, add functionality to any WAI application. - -Allow websockets connections that will communicate with the database through LISTEN/NOTIFY channels. --} {-# LANGUAGE DeriveGeneric #-} +-- | +-- Module : PostgresWebsockets.Middleware +-- Description : PostgresWebsockets WAI middleware, add functionality to any WAI application. +-- +-- Allow websockets connections that will communicate with the database through LISTEN/NOTIFY channels. module PostgresWebsockets.Middleware - ( postgresWsMiddleware - ) where + ( postgresWsMiddleware, + ) +where -import Protolude hiding (toS) -import Protolude.Conv -import Data.Time.Clock (UTCTime) -import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime) +import APrelude import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm) +import qualified Data.Aeson as A +import qualified Data.Aeson.Key as Key +import qualified Data.Aeson.KeyMap as A +import qualified Data.ByteString.Lazy as BL +import qualified Data.Text as T +import Data.Time.Clock (UTCTime) +import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds) import qualified Hasql.Notifications as H import qualified Hasql.Pool as H import qualified Network.Wai as Wai import qualified Network.Wai.Handler.WebSockets as WS import qualified Network.WebSockets as WS - -import qualified Data.Aeson as A -import qualified Data.Aeson.KeyMap as A -import qualified Data.Aeson.Key as Key - -import qualified Data.Text as T -import qualified Data.ByteString.Lazy as BL - import PostgresWebsockets.Broadcast (onMessage) -import PostgresWebsockets.Claims ( ConnectionInfo, validateClaims ) -import PostgresWebsockets.Context ( Context(..) ) -import PostgresWebsockets.Config (AppConfig(..)) import qualified PostgresWebsockets.Broadcast as B +import PostgresWebsockets.Claims (ConnectionInfo, validateClaims) +import PostgresWebsockets.Config (AppConfig (..)) +import PostgresWebsockets.Context (Context (..)) - -data Event = - WebsocketMessage +data Event + = WebsocketMessage | ConnectionOpen deriving (Show, Eq, Generic) data Message = Message - { claims :: A.Object - , event :: Event - , payload :: Text - , channel :: Text - } deriving (Show, Eq, Generic) + { claims :: A.Object, + event :: Event, + payload :: Text, + channel :: Text + } + deriving (Show, Eq, Generic) instance A.ToJSON Event + instance A.ToJSON Message -- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware. @@ -62,8 +59,8 @@ jwtExpirationStatusCode = 3001 -- when the websocket is closed a ConnectionClosed Exception is triggered -- this kills all children and frees resources for us wsApp :: Context -> WS.ServerApp -wsApp Context{..} pendingConn = - ctxGetTime >>= validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken) >>= either rejectRequest forkSessions +wsApp Context {..} pendingConn = + ctxGetTime >>= validateClaims requestChannel (configJwtSecret ctxConfig) (fromStrict $ encodeUtf8 jwtToken) >>= either rejectRequest forkSessions where hasRead m = m == ("r" :: Text) || m == ("rw" :: Text) hasWrite m = m == ("w" :: Text) || m == ("rw" :: Text) @@ -71,10 +68,10 @@ wsApp Context{..} pendingConn = rejectRequest :: Text -> IO () rejectRequest msg = do putErrLn $ "Rejecting Request: " <> msg - WS.rejectRequest pendingConn (toS msg) + WS.rejectRequest pendingConn (encodeUtf8 msg) -- the URI has one of the two formats - /:jwt or /:channel/:jwt - pathElements = T.split (== '/') $ T.drop 1 $ (toSL . WS.requestPath) $ WS.pendingRequest pendingConn + pathElements = T.split (== '/') $ T.drop 1 $ (decodeUtf8 . WS.requestPath) $ WS.pendingRequest pendingConn jwtToken = case length pathElements `compare` 1 of GT -> headDef "" $ tailSafe pathElements @@ -85,35 +82,37 @@ wsApp Context{..} pendingConn = _ -> Nothing forkSessions :: ConnectionInfo -> IO () forkSessions (chs, mode, validClaims) = do - -- We should accept only after verifying JWT - conn <- WS.acceptRequest pendingConn - -- Fork a pinging thread to ensure browser connections stay alive - WS.withPingThread conn 30 (pure ()) $ do - case A.lookup "exp" validClaims of - Just (A.Number expClaim) -> do - connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)) - setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim) - Just _ -> pure () - Nothing -> pure () - - let sendNotification msg channel = sendMessageWithTimestamp $ websocketMessageForChannel msg channel - sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig) - sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase - websocketMessageForChannel = Message validClaims WebsocketMessage - connectionOpenMessage = Message validClaims ConnectionOpen - - case configMetaChannel ctxConfig of - Nothing -> pure () - Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (toS $ T.intercalate "," chs) ch - - when (hasRead mode) $ - forM_ chs $ flip (onMessage ctxMulti) $ WS.sendTextData conn . B.payload - - when (hasWrite mode) $ - notifySession conn sendNotification chs - - waitForever <- newEmptyMVar - void $ takeMVar waitForever + -- We should accept only after verifying JWT + conn <- WS.acceptRequest pendingConn + -- Fork a pinging thread to ensure browser connections stay alive + WS.withPingThread conn 30 (pure ()) $ do + case A.lookup "exp" validClaims of + Just (A.Number expClaim) -> do + connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)) + setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim) + Just _ -> pure () + Nothing -> pure () + + let sendNotification msg channel = sendMessageWithTimestamp $ websocketMessageForChannel msg channel + sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig) + sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase + websocketMessageForChannel = Message validClaims WebsocketMessage + connectionOpenMessage = Message validClaims ConnectionOpen + + case configMetaChannel ctxConfig of + Nothing -> pure () + Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (T.intercalate "," chs) ch + + when (hasRead mode) $ + forM_ chs $ + flip (onMessage ctxMulti) $ + WS.sendTextData conn . B.payload + + when (hasWrite mode) $ + notifySession conn sendNotification chs + + waitForever <- newEmptyMVar + void $ takeMVar waitForever -- Having both channel and claims as parameters seem redundant -- But it allows the function to ignore the claims structure and the source @@ -124,16 +123,16 @@ notifySession wsCon sendToChannel chs = where relayData = do msg <- WS.receiveData wsCon - forM_ chs (sendToChannel msg . toS) + forM_ chs (sendToChannel msg) sendToDatabase :: H.Pool -> Text -> Message -> IO () sendToDatabase pool dbChannel = notify . jsonMsg where - notify = void . H.notifyPool pool dbChannel . toS + notify = void . H.notifyPool pool dbChannel . decodeUtf8 jsonMsg = BL.toStrict . A.encode timestampMessage :: IO UTCTime -> Message -> IO Message -timestampMessage getTime msg@Message{..} = do +timestampMessage getTime msg@Message {..} = do time <- utcTimeToPOSIXSeconds <$> getTime - return $ msg{ claims = A.insert (Key.fromText "message_delivered_at") (A.Number $ realToFrac time) claims} + return $ msg {claims = A.insert (Key.fromText "message_delivered_at") (A.Number $ realToFrac time) claims} diff --git a/src/PostgresWebsockets/Server.hs b/src/PostgresWebsockets/Server.hs index e579777..1cf33ba 100644 --- a/src/PostgresWebsockets/Server.hs +++ b/src/PostgresWebsockets/Server.hs @@ -6,6 +6,7 @@ module PostgresWebsockets.Server ) where +import APrelude import Network.HTTP.Types (status200) import Network.Wai (Application, responseLBS) import Network.Wai.Application.Static (defaultFileServerSettings, staticApp) @@ -15,13 +16,12 @@ import Network.Wai.Middleware.RequestLogger (logStdout) import PostgresWebsockets.Config (AppConfig (..), warpSettings) import PostgresWebsockets.Context (mkContext) import PostgresWebsockets.Middleware (postgresWsMiddleware) -import Protolude -- | Start a stand-alone warp server using the parameters from AppConfig and a opening a database connection pool. serve :: AppConfig -> IO () serve conf@AppConfig {..} = do shutdownSignal <- newEmptyMVar - putStrLn $ ("Listening on port " :: Text) <> show configPort + putStrLn $ "Listening on port " <> show configPort let shutdown = putErrLn ("Broadcaster connection is dead" :: Text) >> putMVar shutdownSignal () ctx <- mkContext conf shutdown @@ -31,13 +31,13 @@ serve conf@AppConfig {..} = do app = postgresWsMiddleware ctx $ logStdout $ maybe dummyApp staticApp' configPath case (configCertificateFile, configKeyFile) of - (Just certificate, Just key) -> runTLS (tlsSettings (toS certificate) (toS key)) appSettings app + (Just certificate, Just key) -> runTLS (tlsSettings (unpack certificate) (unpack key)) appSettings app _ -> runSettings appSettings app die "Shutting down server..." where staticApp' :: Text -> Application - staticApp' = staticApp . defaultFileServerSettings . toS + staticApp' = staticApp . defaultFileServerSettings . unpack dummyApp :: Application dummyApp _ respond = respond $ responseLBS status200 [("Content-Type", "text/plain")] "Hello, Web!" diff --git a/test/BroadcastSpec.hs b/test/BroadcastSpec.hs index 1fa376a..393a3c6 100644 --- a/test/BroadcastSpec.hs +++ b/test/BroadcastSpec.hs @@ -1,11 +1,9 @@ module BroadcastSpec (spec) where -import Protolude +import APrelude import Control.Concurrent.STM.TQueue - -import Test.Hspec - import PostgresWebsockets.Broadcast +import Test.Hspec spec :: Spec spec = do @@ -13,18 +11,27 @@ spec = do it "opens a separate thread for a producer function" $ do output <- newTQueueIO :: IO (TQueue ThreadId) - void $ liftIO $ newMultiplexer (\_-> do - tid <- myThreadId - atomically $ writeTQueue output tid - ) (\_ -> return ()) + void $ + liftIO $ + newMultiplexer + ( \_ -> do + tid <- myThreadId + atomically $ writeTQueue output tid + ) + (\_ -> return ()) outMsg <- atomically $ readTQueue output myThreadId `shouldNotReturn` outMsg describe "relayMessages" $ it "relays a single message from producer to 1 listener on 1 test channel" $ do output <- newTQueueIO :: IO (TQueue Message) - multi <- liftIO $ newMultiplexer (\msgs-> - atomically $ writeTQueue msgs (Message "test" "payload")) (\_ -> return ()) + multi <- + liftIO $ + newMultiplexer + ( \msgs -> + atomically $ writeTQueue msgs (Message "test" "payload") + ) + (\_ -> return ()) void $ onMessage multi "test" $ atomically . writeTQueue output liftIO $ relayMessages multi diff --git a/test/ClaimsSpec.hs b/test/ClaimsSpec.hs index 358a7ca..9d5cdd9 100644 --- a/test/ClaimsSpec.hs +++ b/test/ClaimsSpec.hs @@ -1,12 +1,12 @@ module ClaimsSpec (spec) where -import Protolude - -import Test.Hspec -import Data.Aeson (Value (..), toJSON) +import APrelude +import Data.Aeson (Value (..), toJSON) import qualified Data.Aeson.KeyMap as JSON import Data.Time.Clock -import PostgresWebsockets.Claims +import PostgresWebsockets.Claims +import Test.Hspec +import Prelude secret :: ByteString secret = "reallyreallyreallyreallyverysafe" @@ -16,39 +16,60 @@ spec = describe "validate claims" $ do it "should invalidate an expired token" $ do time <- getCurrentTime - validateClaims Nothing secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" time + validateClaims + Nothing + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" + time `shouldReturn` Left "Token expired" it "request any channel from a token that does not have channels or channel claims should succeed" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciJ9.jL5SsRFegNUlbBm8_okhHSujqLcKKZdDglfdqNl1_rY" time - `shouldReturn` Right (["test"], "r", JSON.fromList[("mode",String "r")]) + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciJ9.jL5SsRFegNUlbBm8_okhHSujqLcKKZdDglfdqNl1_rY" + time + `shouldReturn` Right (["test"], "r", JSON.fromList [("mode", String "r")]) it "requesting a channel that is set by and old style channel claim should work" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" time - `shouldReturn` Right (["test"], "r", JSON.fromList[("mode",String "r"),("channel",String "test")]) + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" + time + `shouldReturn` Right (["test"], "r", JSON.fromList [("mode", String "r"), ("channel", String "test")]) it "no requesting channel should return all channels in the token" $ do time <- getCurrentTime - validateClaims Nothing secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJhbm90aGVyIHRlc3QiXX0.b9N8J8tPOPIxxFj5WJ7sWrmcL8ib63i8eirsRZTM9N0" time - `shouldReturn` Right (["test", "another test"], "r", JSON.fromList[("mode",String "r"),("channels", toJSON["test"::Text, "another test"::Text] ) ]) + validateClaims + Nothing + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJhbm90aGVyIHRlc3QiXX0.b9N8J8tPOPIxxFj5WJ7sWrmcL8ib63i8eirsRZTM9N0" + time + `shouldReturn` Right (["test", "another test"], "r", JSON.fromList [("mode", String "r"), ("channels", toJSON ["test" :: Text, "another test" :: Text])]) it "requesting a channel from the channels claim shoud return only the requested channel" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" time - `shouldReturn` Right (["test"], "r", JSON.fromList[("mode",String "r"),("channels", toJSON ["test"::Text, "test2"] )]) + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" + time + `shouldReturn` Right (["test"], "r", JSON.fromList [("mode", String "r"), ("channels", toJSON ["test" :: Text, "test2"])]) it "requesting a channel not from the channels claim shoud error" $ do time <- getCurrentTime - validateClaims (Just "notAllowed") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" time - `shouldReturn` Left "No allowed channels" + validateClaims + (Just "notAllowed") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" + time + `shouldReturn` Left "No allowed channels" it "requesting a channel with no mode fails" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjaGFubmVscyI6WyJ0ZXN0IiwidGVzdDIiXX0.akC1PEYk2DEZtLP2XjC6qXOGZJejmPx49qv-VeEtQYQ" time + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjaGFubmVscyI6WyJ0ZXN0IiwidGVzdDIiXX0.akC1PEYk2DEZtLP2XjC6qXOGZJejmPx49qv-VeEtQYQ" + time `shouldReturn` Left "Missing mode" diff --git a/test/HasqlBroadcastSpec.hs b/test/HasqlBroadcastSpec.hs index 1cbdf67..8d00156 100644 --- a/test/HasqlBroadcastSpec.hs +++ b/test/HasqlBroadcastSpec.hs @@ -1,25 +1,23 @@ module HasqlBroadcastSpec (spec) where -import Protolude - -import Data.Function (id) -import Test.Hspec +import APrelude +import Hasql.Notifications import PostgresWebsockets.Broadcast import PostgresWebsockets.HasqlBroadcast -import Hasql.Notifications +import Test.Hspec spec :: Spec spec = describe "newHasqlBroadcaster" $ do - let newConnection connStr = - either (panic . show) id - <$> acquire connStr + let newConnection connStr = + either (panic . showText) id + <$> acquire connStr - it "relay messages sent to the appropriate database channel" $ do - multi <- either (panic .show) id <$> newHasqlBroadcasterOrError (pure ()) "postgres-websockets" "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" - msg <- liftIO newEmptyMVar - onMessage multi "test" $ putMVar msg + it "relay messages sent to the appropriate database channel" $ do + multi <- either (panic . showText) id <$> newHasqlBroadcasterOrError (pure ()) "postgres-websockets" "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" + msg <- liftIO newEmptyMVar + onMessage multi "test" $ putMVar msg - con <- newConnection "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" - void $ notify con (toPgIdentifier "postgres-websockets") "{\"channel\": \"test\", \"payload\": \"hello there\"}" + con <- newConnection "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" + void $ notify con (toPgIdentifier "postgres-websockets") "{\"channel\": \"test\", \"payload\": \"hello there\"}" - readMVar msg `shouldReturn` Message "test" "{\"channel\": \"test\", \"payload\": \"hello there\"}" + readMVar msg `shouldReturn` Message "test" "{\"channel\": \"test\", \"payload\": \"hello there\"}" diff --git a/test/ServerSpec.hs b/test/ServerSpec.hs index 626921e..6666043 100644 --- a/test/ServerSpec.hs +++ b/test/ServerSpec.hs @@ -1,12 +1,12 @@ module ServerSpec (spec) where +import APrelude import Control.Lens import Data.Aeson.Lens import Network.Socket (withSocketsDo) import qualified Network.WebSockets as WS import PostgresWebsockets import PostgresWebsockets.Config -import Protolude import Test.Hspec testServerConfig :: AppConfig @@ -46,7 +46,7 @@ sendWsData uri msg = WS.runClient "127.0.0.1" (configPort testServerConfig) - (toS uri) + (unpack uri) (`WS.sendTextData` msg) testChannel :: Text @@ -67,7 +67,7 @@ waitForWsData uri = do WS.runClient "127.0.0.1" (configPort testServerConfig) - (toS uri) + (unpack uri) ( \c -> do m <- WS.receiveData c putMVar msg m @@ -84,7 +84,7 @@ waitForMultipleWsData messageCount uri = do WS.runClient "127.0.0.1" (configPort testServerConfig) - (toS uri) + (unpack uri) ( \c -> do m <- replicateM messageCount (WS.receiveData c) putMVar msg m From f4798062283cfd632fc5c6c4d613be53c13345ff Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Sat, 19 Apr 2025 15:00:55 -0400 Subject: [PATCH 07/11] Remove redundant parenthesis --- app/Main.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index fcc41bd..d98eb29 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -11,8 +11,8 @@ main = do hSetBuffering stderr NoBuffering putStrLn $ - ("postgres-websockets ") - <> (unpack prettyVersion) + "postgres-websockets " + <> unpack prettyVersion <> " / Connects websockets to PostgreSQL asynchronous notifications." conf <- loadConfig From f82330cad9d69a557b7472617806faa75aeb6ff6 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Sat, 19 Apr 2025 15:05:13 -0400 Subject: [PATCH 08/11] Add bounds to mtl and async --- postgres-websockets.cabal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index f77b2a2..536d3a9 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -50,8 +50,8 @@ library , http-types >= 0.12.3 && < 0.13 , jose >= 0.11 && < 0.12 , lens >= 5.2.3 && < 5.4 - , mtl - , async + , mtl >=2.3.1 && <2.4 + , async >=2.2.5 && <2.3 , postgresql-libpq >= 0.10.0 && < 0.12 , retry >= 0.8.1.0 && < 0.10 , stm >= 2.5.0.0 && < 2.6 From 577a2f0bf4750c682c2c3c63e37d90e0ec6f95eb Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Wed, 14 May 2025 17:36:44 -0400 Subject: [PATCH 09/11] It seems that without the notifySession being open in the middleware, the runtime detects the waiting on the MVar as an infinite wait. Which although true, it's useful to relay read messages and the connection will be closed by warp when the client disconects. This should fix [#105] --- src/PostgresWebsockets/Middleware.hs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/PostgresWebsockets/Middleware.hs b/src/PostgresWebsockets/Middleware.hs index e0e06d9..bd00ec0 100644 --- a/src/PostgresWebsockets/Middleware.hs +++ b/src/PostgresWebsockets/Middleware.hs @@ -111,8 +111,7 @@ wsApp Context {..} pendingConn = when (hasWrite mode) $ notifySession conn sendNotification chs - waitForever <- newEmptyMVar - void $ takeMVar waitForever + void $ forever $ threadDelay maxBound -- Having both channel and claims as parameters seem redundant -- But it allows the function to ignore the claims structure and the source From 25f7fc7588b96f213ac08254bde472df6663e354 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Thu, 15 May 2025 19:19:50 -0400 Subject: [PATCH 10/11] Remove unnecessary imports --- src/PostgresWebsockets/Config.hs | 2 +- src/PostgresWebsockets/HasqlBroadcast.hs | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/PostgresWebsockets/Config.hs b/src/PostgresWebsockets/Config.hs index 7215e56..ea471de 100644 --- a/src/PostgresWebsockets/Config.hs +++ b/src/PostgresWebsockets/Config.hs @@ -18,7 +18,7 @@ import APrelude import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import Data.String (IsString (..)) -import Data.Text (intercalate, pack, replace, strip, stripPrefix) +import Data.Text (intercalate, replace, strip, stripPrefix) import Data.Version (versionBranch) import Env import Network.Wai.Handler.Warp diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs index 9bb39a2..3a3e440 100644 --- a/src/PostgresWebsockets/HasqlBroadcast.hs +++ b/src/PostgresWebsockets/HasqlBroadcast.hs @@ -21,8 +21,6 @@ import Data.Aeson (Value (..), decode) import qualified Data.Aeson.Key as Key import qualified Data.Aeson.KeyMap as JSON import Data.Either.Combinators (mapBoth) -import Data.Function (id) -import GHC.Show import Hasql.Connection import qualified Hasql.Decoders as HD import qualified Hasql.Encoders as HE From 51792c6056516d8db6fe2e468f680ddb23820735 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Thu, 15 May 2025 20:49:37 -0400 Subject: [PATCH 11/11] Bump patch version nuumber --- postgres-websockets.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 536d3a9..ec7231c 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -1,6 +1,6 @@ cabal-version: 3.0 name: postgres-websockets -version: 0.11.2.3 +version: 0.11.2.4 synopsis: Middleware to map LISTEN/NOTIFY messages to Websockets description: WAI middleware that adds websockets capabilites on top of PostgreSQL's asynchronous notifications using LISTEN and NOTIFY commands. Fully functioning server included. homepage: https://github.com/diogob/postgres-websockets#readme