{-# LANGUAGE CPP #-}

#include "HsNetDef.h"

module Network.Socket.Name (
    getPeerName
  , getSocketName
  , socketPort
  , socketPortSafe
  ) where

import Foreign.Marshal.Utils (with)

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Types

-- | Getting peer's socket address (a wrapper around @getpeername(2)@).
--
-- The concrete address type is chosen by the caller through its
-- 'SocketAddress' instance. A @-1@ result from the C call is turned into
-- an exception by 'throwSocketErrorIfMinus1Retry_' (which, judging by its
-- name, also retries the call in some cases — see its definition).
getPeerName :: SocketAddress sa => Socket -> IO sa
getPeerName s =
    withNewSocketAddress $ \ptr sz ->
      -- The C API takes the buffer length by reference, so marshal it.
      with (fromIntegral sz) $ \int_star -> withFdSocket s $ \fd -> do
        throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerName" $
          c_getpeername fd ptr int_star
        -- The length written back into int_star is not consulted:
        -- 'peekSocketAddress' decodes directly from the buffer. This matches
        -- 'getSocketName', so the former dead read of it is dropped.
        peekSocketAddress ptr

-- | Getting my socket address: the local address the socket is bound to,
-- obtained via @getsockname(2)@.
--
-- The result type is selected by the caller's 'SocketAddress' instance.
-- A @-1@ return from the C call is raised as an exception by
-- 'throwSocketErrorIfMinus1Retry_'.
getSocketName :: SocketAddress sa => Socket -> IO sa
getSocketName s = withNewSocketAddress query
  where
    -- Fill the freshly allocated address buffer for this socket's
    -- descriptor, then decode it.
    query addrPtr addrLen =
      with (fromIntegral addrLen) $ \lenPtr ->
        withFdSocket s $ \fd -> do
          throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketName"
            (c_getsockname fd addrPtr lenPtr)
          peekSocketAddress addrPtr

-- Raw bindings to the C library's getpeername / getsockname. Arguments are
-- the socket descriptor, the address buffer, and a pointer to the buffer
-- length. The Haskell wrappers treat a -1 result as failure. These are
-- 'unsafe' foreign calls; CALLCONV comes from HsNetDef.h.
foreign import CALLCONV unsafe "getpeername"
    c_getpeername :: CInt -> Ptr addr -> Ptr CInt -> IO CInt

foreign import CALLCONV unsafe "getsockname"
    c_getsockname :: CInt -> Ptr addr -> Ptr CInt -> IO CInt

-- ---------------------------------------------------------------------------
-- socketPort
--
-- The local port number the given socket is currently bound to can be
-- determined by calling 'socketPort'. This is generally only useful when
-- 'bind' was given @aNY_PORT@, i.e. when the system chose the port.

-- | Getting the port of socket.
--   `IOError` is thrown if a port is not available, i.e. when the local
--   address is neither 'SockAddrInet' nor 'SockAddrInet6'. Errors raised
--   by 'getSocketName' itself propagate unchanged.
socketPort :: Socket          -- ^ Connected & Bound Socket
           -> IO PortNumber   -- ^ Port Number of Socket
socketPort s = getSocketName s >>= portOf
  where
    portOf (SockAddrInet port _)      = return port
    portOf (SockAddrInet6 port _ _ _) = return port
    portOf _                          =
      ioError (userError "Network.Socket.socketPort: AF_UNIX not supported.")

-- ---------------------------------------------------------------------------
-- socketPortSafe
-- | Getting the port of socket.
--
-- Returns 'Nothing' when the local address is neither 'SockAddrInet' nor
-- 'SockAddrInet6'. Note that exceptions from 'getSocketName' (a failing
-- @getsockname@ call) are still thrown; only the address-family case is
-- reported through 'Maybe'.
socketPortSafe :: Socket                -- ^ Connected & Bound Socket
               -> IO (Maybe PortNumber) -- ^ Port Number of Socket
socketPortSafe s = fmap inetPort (getSocketName s)
  where
    inetPort (SockAddrInet port _)      = Just port
    inetPort (SockAddrInet6 port _ _ _) = Just port
    inetPort _                          = Nothing