{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.Socket.BufferPool.Recv (
receive
, makeRecvN
) where
import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString(..), unsafeCreate)
import Data.IORef
import Foreign.C.Error (eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.Ptr (Ptr, castPtr)
import GHC.Conc (threadWaitRead)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (Fd(..))
#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Socket.BufferPool.Windows
#endif
import Network.Socket.BufferPool.Types
import Network.Socket.BufferPool.Buffer
-- | A 'Recv' action that reads from @sock@ into buffers supplied by @pool@.
--
-- 'withBufferPool' hands us a buffer pointer and its capacity; we receive
-- into it with 'tryReceive' on the socket's raw descriptor and report the
-- number of bytes read back to 'withBufferPool' (which presumably uses that
-- count to slice the returned 'ByteString' — see
-- "Network.Socket.BufferPool.Buffer").
receive :: Socket -> BufferPool -> Recv
receive sock pool = withBufferPool pool fill
  where
    -- Receive into @ptr@ (capacity @size@) and return the byte count.
    fill ptr size = do
#if MIN_VERSION_network(3,1,0)
      withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
        fd <- fdSocket sock
#else
        let fd = fdSocket sock
#endif
        let capacity = fromIntegral size :: CSize
        received <- tryReceive fd ptr capacity
        return (fromIntegral received)
-- | Receive up to @size@ bytes from descriptor @sock@ into @ptr@.
--
-- Returns the byte count reported by the underlying receive call (0
-- presumably meaning the peer closed the connection — confirm against the
-- platform's @recv@ semantics).
--
-- Error handling:
--
-- * @EAGAIN@ / @EWOULDBLOCK@: nothing is available yet, so the current
--   Haskell thread blocks in 'threadWaitRead' until the descriptor is
--   readable, then retries.
--
-- * @EINTR@: the call was interrupted by a signal before any data arrived;
--   it is retried immediately.
--
-- * Anything else is raised as an 'IOError' via 'throwErrno'.
tryReceive :: CInt -> Buffer -> CSize -> IO CInt
tryReceive sock ptr size = go
  where
    go :: IO CInt
    go = do
#ifdef mingw32_HOST_OS
      bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "tryReceive" (FD sock 1) (castPtr ptr) 0 size
#else
      bytes <- c_recv sock (castPtr ptr) size 0
#endif
      if bytes /= -1
        then return bytes
        else do
          errno <- getErrno
          -- POSIX allows EAGAIN and EWOULDBLOCK to be distinct values, so
          -- check both before treating the failure as "would block".
          if errno == eAGAIN || errno == eWOULDBLOCK
            then do
              threadWaitRead (Fd sock)
              go
            else if errno == eINTR
              -- Interrupted by a signal: no data was consumed, just retry.
              then go
              else throwErrno "tryReceive"
-- | Build a 'RecvN' from an initial buffer @bs0@ and a chunk source @recv@.
--
-- @bs0@ seeds a private cache held in a fresh 'IORef'. The returned function
-- serves bytes from that cache before calling @recv@, and stores any surplus
-- back into the cache for the next call (see 'recvN').
makeRecvN :: ByteString -> Recv -> IO RecvN
makeRecvN bs0 recv = fmap (`recvN` recv) (newIORef bs0)
-- | Return the next @size@ bytes of the stream.
--
-- Bytes cached in @ref@ are used first; @recv@ is called only if more are
-- needed. Whatever is left over after taking @size@ bytes is written back
-- into @ref@. If @recv@ signals end of input (an empty chunk) before enough
-- bytes arrive, an empty 'ByteString' is returned and the cache is cleared
-- (see 'tryRecvN').
recvN :: IORef ByteString -> Recv -> RecvN
recvN ref recv size =
    readIORef ref >>= \cached ->
    tryRecvN cached size recv >>= \(chunk, rest) ->
    writeIORef ref rest >> return chunk
-- | Collect exactly @siz0@ bytes, starting from the already-buffered prefix
-- @init0@ and calling @recv@ for further chunks as needed.
--
-- Returns @(result, leftover)@, where @leftover@ is whatever was received
-- beyond the first @siz0@ bytes.
--
-- * If @init0@ already holds enough bytes, @recv@ is not called.
--
-- * If @recv@ yields an empty chunk before enough bytes have arrived, the
--   pair @(empty, empty)@ is returned and the bytes gathered so far are
--   dropped.
tryRecvN :: ByteString -> Int -> IO ByteString -> IO (ByteString, ByteString)
tryRecvN init0 siz0 recv
  | len0 >= siz0 = return (BS.splitAt siz0 init0)
  | otherwise    = collect [init0] (siz0 - len0)
  where
    len0 :: Int
    len0 = BS.length init0

    -- @acc@ holds the chunks received so far in reverse order; @need@ is the
    -- number of bytes still missing, which is always positive here.
    collect :: [ByteString] -> Int -> IO (ByteString, ByteString)
    collect acc need = recv >>= step
      where
        step chunk
          | n == 0    = return (BS.empty, BS.empty)
          | n < need  = collect (chunk : acc) (need - n)
          | otherwise =
              let (taken, rest) = BS.splitAt need chunk
              in  return (concatN siz0 (reverse (taken : acc)), rest)
          where
            n = BS.length chunk
-- | Allocate a 'ByteString' of exactly @total@ bytes and fill it with the
-- given chunks, laid out back to back in list order.
--
-- NOTE(review): 'unsafeCreate' does no bounds checking. This assumes the
-- chunk lengths sum to @total@; 'tryRecvN' builds its chunk list that way.
concatN :: Int -> [ByteString] -> ByteString
concatN total bss0 = unsafeCreate total (fillFrom bss0)
  where
    -- Write each chunk with 'copy', continuing from the pointer it returns
    -- (presumably the address just past the bytes it wrote).
    fillFrom :: [ByteString] -> Buffer -> IO ()
    fillFrom []             _   = return ()
    fillFrom (chunk : more) dst = copy dst chunk >>= fillFrom more
#ifndef mingw32_HOST_OS
-- | Direct binding to the C @recv(2)@ call: descriptor, destination buffer,
-- buffer length, flags. Returns the byte count, or -1 with @errno@ set.
--
-- The call is imported @unsafe@, which is only appropriate if it never
-- blocks. 'tryReceive' handles @EAGAIN@ by waiting in 'threadWaitRead', which
-- suggests the socket is in non-blocking mode — NOTE(review): confirm.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif