-- | A module that contains exception-safe equivalents of @inline-c@ QuasiQuoters.

{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE QuasiQuotes #-}

module Language.C.Inline.Cpp.Exception
  ( CppException(..)
  , CppExceptionPtr
  , toSomeException
  , throwBlock
  , tryBlock
  , catchBlock
  , tryBlockQuoteExp
  ) where

import           Control.Exception.Safe
import qualified Data.ByteString.Unsafe as BS (unsafePackMallocCString)
import           Data.ByteString (ByteString)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Internal as C
import qualified Language.C.Inline.Cpp as Cpp
import           Language.C.Inline.Cpp (AbstractCppExceptionPtr)
import           Language.Haskell.TH
import           Language.Haskell.TH.Quote
import           Foreign
import           Foreign.C
import           System.IO.Unsafe(unsafePerformIO)

C.context Cpp.cppCtx
C.include "HaskellException.hxx"

-- | An exception thrown in C++ code.
data CppException
  = CppStdException CppExceptionPtr ByteString (Maybe ByteString)
  | CppHaskellException SomeException
  | CppNonStdException CppExceptionPtr (Maybe ByteString)

instance Show CppException where
  showsPrec :: Int -> CppException -> ShowS
showsPrec Int
p (CppStdException CppExceptionPtr
_ ByteString
msg Maybe ByteString
typ) = Bool -> ShowS -> ShowS
showParen (Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
11) (String -> ShowS
showString String
"CppStdException e " ShowS -> ShowS -> ShowS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> ByteString -> ShowS
forall a. Show a => Int -> a -> ShowS
showsPrec Int
11 ByteString
msg ShowS -> ShowS -> ShowS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Maybe ByteString -> ShowS
forall a. Show a => Int -> a -> ShowS
showsPrec Int
11 Maybe ByteString
typ)
  showsPrec Int
p (CppHaskellException SomeException
e) = Bool -> ShowS -> ShowS
showParen (Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
11) (String -> ShowS
showString String
"CppHaskellException " ShowS -> ShowS -> ShowS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> SomeException -> ShowS
forall a. Show a => Int -> a -> ShowS
showsPrec Int
11 SomeException
e)
  showsPrec Int
p (CppNonStdException CppExceptionPtr
_ Maybe ByteString
typ) = Bool -> ShowS -> ShowS
showParen (Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
11) (String -> ShowS
showString String
"CppOtherException e " ShowS -> ShowS -> ShowS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Maybe ByteString -> ShowS
forall a. Show a => Int -> a -> ShowS
showsPrec Int
11 Maybe ByteString
typ)

instance Exception CppException where
  displayException :: CppException -> String
displayException (CppStdException CppExceptionPtr
_ ByteString
msg Maybe ByteString
_typ) = ByteString -> String
bsToChars ByteString
msg
  displayException (CppHaskellException SomeException
e) = SomeException -> String
forall e. Exception e => e -> String
displayException SomeException
e
  displayException (CppNonStdException CppExceptionPtr
_ (Just ByteString
typ)) = String
"exception: Exception of type " String -> ShowS
forall a. Semigroup a => a -> a -> a
<> ByteString -> String
bsToChars ByteString
typ
  displayException (CppNonStdException CppExceptionPtr
_ Maybe ByteString
Nothing) = String
"exception: Non-std exception of unknown type"

type CppExceptionPtr = ForeignPtr AbstractCppExceptionPtr

-- | This converts a plain pointer to a managed object.
--
-- The pointer must have been created with @new@. The returned 'CppExceptionPtr'
-- will @delete@ it when it is garbage collected, so you must not @delete@ it
-- on your own. This function is called "unsafe" because it is not memory safe
-- by itself, but safe when used correctly; similar to for example
-- 'BS.unsafePackMallocCString'.
unsafeFromNewCppExceptionPtr :: Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr
unsafeFromNewCppExceptionPtr :: Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr
unsafeFromNewCppExceptionPtr = FinalizerPtr AbstractCppExceptionPtr
-> Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr
forall a. FinalizerPtr a -> Ptr a -> IO (ForeignPtr a)
newForeignPtr FinalizerPtr AbstractCppExceptionPtr
finalizeAbstractCppExceptionPtr

finalizeAbstractCppExceptionPtr :: FinalizerPtr AbstractCppExceptionPtr
{-# NOINLINE finalizeAbstractCppExceptionPtr #-}
finalizeAbstractCppExceptionPtr :: FinalizerPtr AbstractCppExceptionPtr
finalizeAbstractCppExceptionPtr =
  IO (FinalizerPtr AbstractCppExceptionPtr)
-> FinalizerPtr AbstractCppExceptionPtr
forall a. IO a -> a
unsafePerformIO
    IO (FinalizerPtr AbstractCppExceptionPtr)
[C.exp|
      void (*)(std::exception_ptr *) {
        [](std::exception_ptr *v){ delete v; }
      }|]

-- | Like 'toException' but unwrap 'CppHaskellException'
toSomeException :: CppException -> SomeException
toSomeException :: CppException -> SomeException
toSomeException (CppHaskellException SomeException
e) = SomeException
e
toSomeException CppException
x = CppException -> SomeException
forall e. Exception e => e -> SomeException
toException CppException
x

-- NOTE: Other C++ exception types (std::runtime_error etc) could be distinguished like this in the future.
pattern ExTypeNoException :: CInt
pattern $mExTypeNoException :: forall {r}. CInt -> ((# #) -> r) -> ((# #) -> r) -> r
$bExTypeNoException :: CInt
ExTypeNoException = 0

pattern ExTypeStdException :: CInt
pattern $mExTypeStdException :: forall {r}. CInt -> ((# #) -> r) -> ((# #) -> r) -> r
$bExTypeStdException :: CInt
ExTypeStdException = 1

pattern ExTypeHaskellException :: CInt
pattern $mExTypeHaskellException :: forall {r}. CInt -> ((# #) -> r) -> ((# #) -> r) -> r
$bExTypeHaskellException :: CInt
ExTypeHaskellException = 2

pattern ExTypeOtherException :: CInt
pattern $mExTypeOtherException :: forall {r}. CInt -> ((# #) -> r) -> ((# #) -> r) -> r
$bExTypeOtherException :: CInt
ExTypeOtherException = 3


handleForeignCatch :: (Ptr (Ptr ()) -> IO a) -> IO (Either CppException a)
handleForeignCatch :: forall a. (Ptr (Ptr ()) -> IO a) -> IO (Either CppException a)
handleForeignCatch Ptr (Ptr ()) -> IO a
cont =
  Int
-> Int
-> (Ptr (Ptr ()) -> IO (Either CppException a))
-> IO (Either CppException a)
forall a b. Int -> Int -> (Ptr a -> IO b) -> IO b
allocaBytesAligned (Ptr () -> Int
forall a. Storable a => a -> Int
sizeOf (Ptr ()
forall a. HasCallStack => a
undefined :: Ptr ()) Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
5) (Ptr () -> Int
forall a. Storable a => a -> Int
alignment (Ptr ()
forall a. HasCallStack => a
undefined :: Ptr ())) ((Ptr (Ptr ()) -> IO (Either CppException a))
 -> IO (Either CppException a))
-> (Ptr (Ptr ()) -> IO (Either CppException a))
-> IO (Either CppException a)
forall a b. (a -> b) -> a -> b
$ \Ptr (Ptr ())
basePtr -> do
    let ptrSize :: Int
ptrSize         = Ptr () -> Int
forall a. Storable a => a -> Int
sizeOf (Ptr ()
forall a. HasCallStack => a
undefined :: Ptr ())
        exTypePtr :: Ptr CInt
exTypePtr       = Ptr (Ptr ()) -> Ptr CInt
forall a b. Ptr a -> Ptr b
castPtr Ptr (Ptr ())
basePtr :: Ptr CInt
        msgCStringPtr :: Ptr CString
msgCStringPtr   = Ptr Any -> Ptr CString
forall a b. Ptr a -> Ptr b
castPtr (Ptr (Ptr ())
basePtr Ptr (Ptr ()) -> Int -> Ptr Any
forall a b. Ptr a -> Int -> Ptr b
`plusPtr` Int
ptrSize) :: Ptr CString
        typCStringPtr :: Ptr CString
typCStringPtr   = Ptr Any -> Ptr CString
forall a b. Ptr a -> Ptr b
castPtr (Ptr (Ptr ())
basePtr Ptr (Ptr ()) -> Int -> Ptr Any
forall a b. Ptr a -> Int -> Ptr b
`plusPtr` (Int
ptrSizeInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
2))  :: Ptr CString
        exPtr :: Ptr (Ptr AbstractCppExceptionPtr)
exPtr           = Ptr Any -> Ptr (Ptr AbstractCppExceptionPtr)
forall a b. Ptr a -> Ptr b
castPtr (Ptr (Ptr ())
basePtr Ptr (Ptr ()) -> Int -> Ptr Any
forall a b. Ptr a -> Int -> Ptr b
`plusPtr` (Int
ptrSizeInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3))  :: Ptr (Ptr AbstractCppExceptionPtr)
        haskellExPtrPtr :: Ptr (Ptr ())
haskellExPtrPtr = Ptr Any -> Ptr (Ptr ())
forall a b. Ptr a -> Ptr b
castPtr (Ptr (Ptr ())
basePtr Ptr (Ptr ()) -> Int -> Ptr Any
forall a b. Ptr a -> Int -> Ptr b
`plusPtr` (Int
ptrSizeInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
4)) :: Ptr (Ptr ())
    -- we need to mask this entire block because the C++ allocates the
    -- string for the exception message and we need to make sure that
    -- we free it (see the @free@ below). The foreign code would not be
    -- preemptable anyway, so I do not think this loses us anything.
    IO (Either CppException a) -> IO (Either CppException a)
forall (m :: * -> *) a. (HasCallStack, MonadMask m) => m a -> m a
mask_ (IO (Either CppException a) -> IO (Either CppException a))
-> IO (Either CppException a) -> IO (Either CppException a)
forall a b. (a -> b) -> a -> b
$ do
      res <- Ptr (Ptr ()) -> IO a
cont Ptr (Ptr ())
basePtr
      exType <- peek exTypePtr
      case exType of
        CInt
ExTypeNoException -> Either CppException a -> IO (Either CppException a)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (a -> Either CppException a
forall a b. b -> Either a b
Right a
res)
        CInt
ExTypeStdException -> do
          ex <- Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr
unsafeFromNewCppExceptionPtr (Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr)
-> IO (Ptr AbstractCppExceptionPtr) -> IO CppExceptionPtr
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr (Ptr AbstractCppExceptionPtr)
-> IO (Ptr AbstractCppExceptionPtr)
forall a. Storable a => Ptr a -> IO a
peek Ptr (Ptr AbstractCppExceptionPtr)
exPtr

          -- BS.unsafePackMallocCString: safe because setMessageOfStdException
          -- (invoked via tryBlockQuoteExp) sets msgCStringPtr to a newly
          -- malloced string.
          errMsg <- BS.unsafePackMallocCString =<< peek msgCStringPtr

          -- BS.unsafePackMallocCString: safe because currentExceptionTypeName
          -- returns a newly malloced string
          mbExcType <- maybePeek BS.unsafePackMallocCString =<< peek typCStringPtr

          return (Left (CppStdException ex errMsg mbExcType))
        CInt
ExTypeHaskellException -> do
          haskellExPtr <- Ptr (Ptr ()) -> IO (Ptr ())
forall a. Storable a => Ptr a -> IO a
peek Ptr (Ptr ())
haskellExPtrPtr
          stablePtr <- [C.block| void * {
              return (static_cast<HaskellException *>($(void *haskellExPtr)))->haskellExceptionStablePtr->stablePtr;
            } |]
          someExc <- deRefStablePtr (castPtrToStablePtr stablePtr)
          [C.block| void{
              delete static_cast<HaskellException *>($(void *haskellExPtr));
            } |]
          return (Left (CppHaskellException someExc))
        CInt
ExTypeOtherException -> do
          ex <- Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr
unsafeFromNewCppExceptionPtr (Ptr AbstractCppExceptionPtr -> IO CppExceptionPtr)
-> IO (Ptr AbstractCppExceptionPtr) -> IO CppExceptionPtr
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr (Ptr AbstractCppExceptionPtr)
-> IO (Ptr AbstractCppExceptionPtr)
forall a. Storable a => Ptr a -> IO a
peek Ptr (Ptr AbstractCppExceptionPtr)
exPtr

          -- BS.unsafePackMallocCString: safe because currentExceptionTypeName
          -- returns a newly malloced string
          mbExcType <- maybePeek BS.unsafePackMallocCString =<< peek typCStringPtr

          return (Left (CppNonStdException ex mbExcType)) :: IO (Either CppException a)
        CInt
_ -> String -> IO (Either CppException a)
forall a. HasCallStack => String -> a
error String
"Unexpected C++ exception type."

-- | Like 'tryBlock', but will throw unwrapped 'CppHaskellException's or other 'CppException's rather than returning
-- them in an 'Either'
throwBlock :: QuasiQuoter
throwBlock :: QuasiQuoter
throwBlock = QuasiQuoter
  { quoteExp :: String -> Q Exp
quoteExp = \String
blockStr -> do
      [e| either (throwIO . toSomeException) return =<< $(QuasiQuoter -> String -> Q Exp
tryBlockQuoteExp QuasiQuoter
C.block String
blockStr) |]
  , quotePat :: String -> Q Pat
quotePat = String -> Q Pat
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  , quoteType :: String -> Q Type
quoteType = String -> Q Type
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  , quoteDec :: String -> Q [Dec]
quoteDec = String -> Q [Dec]
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  } where
      unsupported :: p -> m a
unsupported p
_ = String -> m a
forall a. String -> m a
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Unsupported quasiquotation."

-- | Variant of 'throwBlock' for blocks which return 'void'.
catchBlock :: QuasiQuoter
catchBlock :: QuasiQuoter
catchBlock = QuasiQuoter
  { quoteExp :: String -> Q Exp
quoteExp = \String
blockStr -> QuasiQuoter -> String -> Q Exp
quoteExp QuasiQuoter
throwBlock (String
"void {" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
blockStr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"}")
  , quotePat :: String -> Q Pat
quotePat = String -> Q Pat
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  , quoteType :: String -> Q Type
quoteType = String -> Q Type
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  , quoteDec :: String -> Q [Dec]
quoteDec = String -> Q [Dec]
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  } where
      unsupported :: p -> m a
unsupported p
_ = String -> m a
forall a. String -> m a
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Unsupported quasiquotation."

exceptionalValue :: String -> String
exceptionalValue :: ShowS
exceptionalValue String
typeStr =
  case String
typeStr of
    String
"void" -> String
""
    String
"char" -> String
"0"
    String
"short" -> String
"0"
    String
"long" -> String
"0"
    String
"int" -> String
"0"
    String
"int8_t" -> String
"0"
    String
"int16_t" -> String
"0"
    String
"int32_t" -> String
"0"
    String
"int64_t" -> String
"0"
    String
"uint8_t" -> String
"0"
    String
"uint16_t" -> String
"0"
    String
"uint32_t" -> String
"0"
    String
"uint64_t" -> String
"0"
    String
"float" -> String
"0"
    String
"double" -> String
"0"
    String
"bool" -> String
"0"
    String
"signed char" -> String
"0"
    String
"signed short" -> String
"0"
    String
"signed int" -> String
"0"
    String
"signed long" -> String
"0"
    String
"unsigned char" -> String
"0"
    String
"unsigned short" -> String
"0"
    String
"unsigned int" -> String
"0"
    String
"unsigned long" -> String
"0"
    String
"size_t" -> String
"0"
    String
"wchar_t" -> String
"0"
    String
"ptrdiff_t" -> String
"0"
    String
"sig_atomic_t" -> String
"0"
    String
"intptr_t" -> String
"0"
    String
"uintptr_t" -> String
"0"
    String
"intmax_t" -> String
"0"
    String
"uintmax_t" -> String
"0"
    String
"clock_t" -> String
"0"
    String
"time_t" -> String
"0"
    String
"useconds_t" -> String
"0"
    String
"suseconds_t" -> String
"0"
    String
"FILE" -> String
"0"
    String
"fpos_t" -> String
"0"
    String
"jmp_buf" -> String
"0"
    String
_ -> String
"{}"

tryBlockQuoteExp :: QuasiQuoter -> String -> Q Exp
tryBlockQuoteExp :: QuasiQuoter -> String -> Q Exp
tryBlockQuoteExp QuasiQuoter
block String
blockStr = do
  let (String
ty, String
body, Int
bodyLineShift) = String -> (String, String, Int)
C.splitTypedC String
blockStr
  _ <- String -> Q [Dec]
C.include String
"HaskellException.hxx"
  basePtrVarName <- newName "basePtr"
  there <- location
  let inlineCStr = [String] -> String
unlines
        [ String
ty String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" {"
        , String
"  void** __inline_c_cpp_base_ptr__ = $(void** " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Name -> String
nameBase Name
basePtrVarName String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
");"
        , String
"  int* __inline_c_cpp_exception_type__ = (int*)__inline_c_cpp_base_ptr__;"
        , String
"  const char** __inline_c_cpp_error_message__ = (const char**)(__inline_c_cpp_base_ptr__ + 1);"
        , String
"  const char** __inline_c_cpp_error_typ__ = (const char**)(__inline_c_cpp_base_ptr__ + 2);"
        , String
"  std::exception_ptr** __inline_c_cpp_exception_ptr__ = (std::exception_ptr**)(__inline_c_cpp_base_ptr__ + 3);"
        , String
"  HaskellException** __inline_c_cpp_haskellexception__ = (HaskellException**)(__inline_c_cpp_base_ptr__ + 4);"
        , String
"  *__inline_c_cpp_exception_type__ = 0;"
        , String
"  try {"
        , Loc -> String
C.lineDirective (Int -> Loc -> Loc
C.shiftLines (Int
bodyLineShift Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) Loc
there)
        , String
body
        , Loc -> String
C.lineDirective $(C.here)
        , String
"  } catch (const HaskellException &e) {"
        , String
"    *__inline_c_cpp_exception_type__ = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ CInt -> String
forall a. Show a => a -> String
show CInt
ExTypeHaskellException String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";"
        , String
"    *__inline_c_cpp_haskellexception__ = new HaskellException(e);"
        , String
"    return " String -> ShowS
forall a. [a] -> [a] -> [a]
++ ShowS
exceptionalValue String
ty String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";"
        , String
"  } catch (const std::exception &e) {"
        , String
"    *__inline_c_cpp_exception_ptr__ = new std::exception_ptr(std::current_exception());"
        , String
"    *__inline_c_cpp_exception_type__ = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ CInt -> String
forall a. Show a => a -> String
show CInt
ExTypeStdException String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";"
        , String
"    setMessageOfStdException(e, __inline_c_cpp_error_message__, __inline_c_cpp_error_typ__);"
        , String
"    return " String -> ShowS
forall a. [a] -> [a] -> [a]
++ ShowS
exceptionalValue String
ty String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";"
        , String
"  } catch (...) {"
        , String
"    *__inline_c_cpp_exception_ptr__ = new std::exception_ptr(std::current_exception());"
        , String
"    *__inline_c_cpp_exception_type__ = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ CInt -> String
forall a. Show a => a -> String
show CInt
ExTypeOtherException String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";"
        , String
"    setCppExceptionType(__inline_c_cpp_error_typ__);"
        , String
"    return " String -> ShowS
forall a. [a] -> [a] -> [a]
++ ShowS
exceptionalValue String
ty String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";"
        , String
"  }"
        , String
"}"
        ]
  [e| handleForeignCatch $ \ $(varP basePtrVarName) -> $(quoteExp block inlineCStr) |]

-- | Similar to `C.block`, but C++ exceptions will be caught and the result is (Either CppException value). The return type must be void or constructible with @{}@.
-- Using this will automatically include @exception@, @cstring@ and @cstdlib@.
tryBlock :: QuasiQuoter
tryBlock :: QuasiQuoter
tryBlock = QuasiQuoter
  { quoteExp :: String -> Q Exp
quoteExp = QuasiQuoter -> String -> Q Exp
tryBlockQuoteExp QuasiQuoter
C.block
  , quotePat :: String -> Q Pat
quotePat = String -> Q Pat
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  , quoteType :: String -> Q Type
quoteType = String -> Q Type
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  , quoteDec :: String -> Q [Dec]
quoteDec = String -> Q [Dec]
forall {m :: * -> *} {p} {a}. MonadFail m => p -> m a
unsupported
  } where
      unsupported :: p -> m a
unsupported p
_ = String -> m a
forall a. String -> m a
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Unsupported quasiquotation."

bsToChars :: ByteString -> String
bsToChars :: ByteString -> String
bsToChars = Text -> String
T.unpack (Text -> String) -> (ByteString -> Text) -> ByteString -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. OnDecodeError -> ByteString -> Text
T.decodeUtf8With OnDecodeError
T.lenientDecode