{-# LANGUAGE CPP #-}
module Data.ProtoLens.Encoding (
    encodeMessage,
    buildMessage,
    decodeMessage,
    parseMessage,
    decodeMessageOrDie,
    -- ** Delimited messages
    buildMessageDelimited,
    parseMessageDelimited,
    decodeMessageDelimitedH,
    ) where

import System.IO (Handle)

import Data.ProtoLens.Message (Message(..))
import Data.ProtoLens.Encoding.Bytes (Parser, Builder)
import qualified Data.ProtoLens.Encoding.Bytes as Bytes

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (runExceptT, ExceptT(..))
import qualified Data.ByteString as B
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup ((<>))
#endif

-- | Decode a message from its wire format.  Returns 'Left' if the decoding
-- fails.
decodeMessage :: Message msg => B.ByteString -> Either String msg
decodeMessage :: forall msg. Message msg => ByteString -> Either String msg
decodeMessage = Parser msg -> ByteString -> Either String msg
forall a. Parser a -> ByteString -> Either String a
Bytes.runParser Parser msg
forall msg. Message msg => Parser msg
parseMessage

-- | Decode a message from its wire format.  Throws an error if the decoding
-- fails.
decodeMessageOrDie :: Message msg => B.ByteString -> msg
decodeMessageOrDie :: forall msg. Message msg => ByteString -> msg
decodeMessageOrDie ByteString
bs = case ByteString -> Either String msg
forall msg. Message msg => ByteString -> Either String msg
decodeMessage ByteString
bs of
    Left String
e -> String -> msg
forall a. HasCallStack => String -> a
error (String -> msg) -> String -> msg
forall a b. (a -> b) -> a -> b
$ String
"decodeMessageOrDie: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
e
    Right msg
x -> msg
x

-- | Encode a message to the wire format as a strict 'ByteString'.
encodeMessage :: Message msg => msg -> B.ByteString
encodeMessage :: forall msg. Message msg => msg -> ByteString
encodeMessage = Builder -> ByteString
Bytes.runBuilder (Builder -> ByteString) -> (msg -> Builder) -> msg -> ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. msg -> Builder
forall msg. Message msg => msg -> Builder
buildMessage

-- | Encode a message to the wire format, prefixed by its size as a VarInt,
-- as part of a 'Builder'.
--
-- This can be used to build up streams of messages in the size-delimited
-- format expected by some protocols.
buildMessageDelimited :: Message msg => msg -> Builder
buildMessageDelimited :: forall msg. Message msg => msg -> Builder
buildMessageDelimited msg
msg =
    let b :: ByteString
b = msg -> ByteString
forall msg. Message msg => msg -> ByteString
encodeMessage msg
msg
    in Word64 -> Builder
Bytes.putVarInt (Int -> Word64
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Word64) -> Int -> Word64
forall a b. (a -> b) -> a -> b
$ ByteString -> Int
B.length ByteString
b) Builder -> Builder -> Builder
forall a. Semigroup a => a -> a -> a
<> ByteString -> Builder
Bytes.putBytes ByteString
b

parseMessageDelimited :: Message msg => Parser msg
parseMessageDelimited :: forall msg. Message msg => Parser msg
parseMessageDelimited = do
    Word64
len <- Parser Word64
Bytes.getVarInt
    ByteString
bytes <- Int -> Parser ByteString
Bytes.getBytes (Int -> Parser ByteString) -> Int -> Parser ByteString
forall a b. (a -> b) -> a -> b
$ Word64 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Word64
len
    (String -> Parser msg)
-> (msg -> Parser msg) -> Either String msg -> Parser msg
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String -> Parser msg
forall a. String -> Parser a
forall (m :: * -> *) a. MonadFail m => String -> m a
fail msg -> Parser msg
forall a. a -> Parser a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either String msg -> Parser msg)
-> Either String msg -> Parser msg
forall a b. (a -> b) -> a -> b
$ ByteString -> Either String msg
forall msg. Message msg => ByteString -> Either String msg
decodeMessage ByteString
bytes

-- | Same as @decodeMessage@ but for delimited messages read through a Handle
decodeMessageDelimitedH :: Message msg => Handle -> IO (Either String msg)
decodeMessageDelimitedH :: forall msg. Message msg => Handle -> IO (Either String msg)
decodeMessageDelimitedH Handle
h = ExceptT String IO msg -> IO (Either String msg)
forall e (m :: * -> *) a. ExceptT e m a -> m (Either e a)
runExceptT (ExceptT String IO msg -> IO (Either String msg))
-> ExceptT String IO msg -> IO (Either String msg)
forall a b. (a -> b) -> a -> b
$
    Handle -> ExceptT String IO Word64
Bytes.getVarIntH Handle
h ExceptT String IO Word64
-> (Word64 -> ExceptT String IO ByteString)
-> ExceptT String IO ByteString
forall a b.
ExceptT String IO a
-> (a -> ExceptT String IO b) -> ExceptT String IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
    IO ByteString -> ExceptT String IO ByteString
forall a. IO a -> ExceptT String IO a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO ByteString -> ExceptT String IO ByteString)
-> (Word64 -> IO ByteString)
-> Word64
-> ExceptT String IO ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Handle -> Int -> IO ByteString
B.hGet Handle
h (Int -> IO ByteString)
-> (Word64 -> Int) -> Word64 -> IO ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Word64 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral ExceptT String IO ByteString
-> (ByteString -> ExceptT String IO msg) -> ExceptT String IO msg
forall a b.
ExceptT String IO a
-> (a -> ExceptT String IO b) -> ExceptT String IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
    IO (Either String msg) -> ExceptT String IO msg
forall e (m :: * -> *) a. m (Either e a) -> ExceptT e m a
ExceptT (IO (Either String msg) -> ExceptT String IO msg)
-> (ByteString -> IO (Either String msg))
-> ByteString
-> ExceptT String IO msg
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Either String msg -> IO (Either String msg)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Either String msg -> IO (Either String msg))
-> (ByteString -> Either String msg)
-> ByteString
-> IO (Either String msg)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ByteString -> Either String msg
forall msg. Message msg => ByteString -> Either String msg
decodeMessage