1 {-# LANGUAGE FlexibleContexts      #-}
    2 {-# LANGUAGE MultiParamTypeClasses #-}
    3 {-# LANGUAGE OverloadedStrings     #-}
    4 {-# LANGUAGE RecordWildCards       #-}
    5 {-# LANGUAGE TemplateHaskell       #-}
    6 
    7 module Network.Craze.Internal where
    8 
    9 import           Control.Exception (SomeException)
   10 import           Control.Monad     (forM, when)
   11 import           Data.Map.Lazy     (Map)
   12 import qualified Data.Map.Lazy     as M
   13 import           Data.Monoid       ((<>))
   14 
   15 import           Control.Concurrent.Async.Lifted (Async, async, asyncThreadId,
   16                                                   cancel, waitAnyCatch)
   17 import           Control.Concurrent.Lifted       (threadDelay)
   18 import           Control.Lens                    (at, makeLenses, use, (&),
   19                                                   (.~), (?=))
   20 import           Control.Monad.State             (MonadState)
   21 import           Control.Monad.Trans             (MonadIO, liftIO)
   22 import           Data.Text                       (Text)
   23 import qualified Data.Text                       as T
   24 import qualified Data.Text.IO                    as TIO
   25 import           Network.Curl                    (CurlBuffer, CurlHeader,
   26                                                   CurlResponse_,
   27                                                   curlGetResponse_)
   28 
   29 import Network.Craze.Types
   30 
-- | Map from the 'Async' handle of each request to the 'ClientState' of the
-- client that issued it. 'makeClientMap' builds it with one entry per
-- provider.
type ClientMap ht bt a = Map (Async (CurlResponse_ ht bt)) (ClientState a)
   32 
-- | What the race tracks for a single client: the options it was started
-- with and how its request has fared so far.
data ClientState a = ClientState
  { _csOptions :: ProviderOptions
    -- ^ Options obtained from the client's provider; 'makeClient' passes the
    -- same options to 'performGet'.
  , _csStatus  :: ClientStatus a
    -- ^ Current status of the client. 'makeClient' starts it as 'Pending';
    -- the @markAs*@ functions overwrite it.
  }
   37 
-- | State threaded through a race; read and updated by 'waitForOne'.
data RaceState ht bt a = RaceState
  { _rsClientMap  :: ClientMap ht bt a
    -- ^ Every client in the race, keyed by its request handle.
  , _rsChecker    :: RacerChecker a
    -- ^ Predicate applied to a handled result; a result that passes it wins
    -- the race.
  , _rsHandler    :: RacerHandler ht bt a
    -- ^ IO action that turns a finished response into a result.
  , _rsDebug      :: Bool
    -- ^ When 'True', 'waitForOne' prints the thread id of the client it
    -- returns.
  , _rsReturnLast :: Bool
    -- ^ When 'True', a result that fails the checker is still returned if no
    -- other client is pending any more.
  }
   45 
-- Template Haskell: generate lenses for the underscore-prefixed record fields
-- above (e.g. csStatus, rsClientMap, rsDebug), which the rest of this module
-- uses to read and update the state.
makeLenses ''ClientState
makeLenses ''RaceState
   48 
   49 extractStatuses :: RaceState ht bt a -> [(Text, ClientStatus a)]
   50 extractStatuses RaceState{..} = M.elems $ makeTuple <$>  _rsClientMap
   51   where
   52     makeTuple :: ClientState a -> (Text, ClientStatus a)
   53     makeTuple ClientState{..} = (poTag _csOptions, _csStatus)
   54 
   55 makeRaceState
   56   :: (CurlHeader ht, CurlBuffer bt, MonadIO m)
   57   => Text
   58   -> Racer ht bt a
   59   -> m (RaceState ht bt a)
   60 makeRaceState url Racer{..} = do
   61   providerMap <- makeClientMap url racerProviders
   62 
   63   pure $ RaceState
   64     providerMap
   65     racerChecker
   66     racerHandler
   67     racerDebug
   68     racerReturnLast
   69 
   70 makeClientMap
   71   :: (CurlHeader ht, CurlBuffer bt, MonadIO m)
   72   => Text
   73   -> [RacerProvider]
   74   -> m (ClientMap ht bt a)
   75 makeClientMap url providers = M.fromList <$> forM providers (makeClient url)
   76 
   77 makeClient
   78   :: (CurlHeader ht, CurlBuffer bt, MonadIO m)
   79   => Text
   80   -> RacerProvider
   81   -> m (Async (CurlResponse_ ht bt), ClientState a)
   82 makeClient url provider = liftIO $ do
   83   options <- provider
   84   future <- async $ performGet url options
   85 
   86   pure (future, ClientState options Pending)
   87 
   88 performGet
   89   :: (CurlHeader ht, CurlBuffer bt)
   90   => Text
   91   -> ProviderOptions
   92   -> IO (CurlResponse_ ht bt)
   93 performGet url ProviderOptions{..} = do
   94   case poDelay of
   95     Nothing -> pure ()
   96     Just delay -> threadDelay delay
   97 
   98   curlGetResponse_ (T.unpack url) poOptions
   99 
  100 cancelAll :: MonadIO m => [Async a] -> m ()
  101 cancelAll = liftIO . mapM_ (async . cancel)
  102 
  103 cancelRemaining
  104   :: (MonadIO m, MonadState (RaceState ht bt a) m)
  105   => m ()
  106 cancelRemaining = do
  107   remaining <- onlyPending <$> use rsClientMap
  108 
  109   cancelAll $ M.keys remaining
  110 
  111 identifier :: Async (CurlResponse_ ht bt) -> ProviderOptions -> Text
  112 identifier a o = poTag o <> ":" <> (T.pack . show . asyncThreadId $ a)
  113 
  114 onlyPending :: ClientMap ht bt a -> ClientMap ht bt a
  115 onlyPending = M.filter (isPending . _csStatus)
  116 
  117 isPending :: ClientStatus a -> Bool
  118 isPending Pending = True
  119 isPending _ = False
  120 
  121 markAsSuccessful
  122   :: (MonadState (RaceState ht bt a) m)
  123   => Async (CurlResponse_ ht bt)
  124   -> a
  125   -> m ()
  126 markAsSuccessful key result = do
  127   maybePrevious <- use $ rsClientMap . at key
  128 
  129   case maybePrevious of
  130     Just previous -> (rsClientMap . at key)
  131       ?= (previous & csStatus .~ Successful result)
  132     Nothing -> pure ()
  133 
  134 markAsFailure
  135   :: (MonadState (RaceState ht bt a) m)
  136   => Async (CurlResponse_ ht bt)
  137   -> a
  138   -> m ()
  139 markAsFailure key result = do
  140   maybePrevious <- use $ rsClientMap . at key
  141 
  142   case maybePrevious of
  143     Just previous -> (rsClientMap . at key)
  144       ?= (previous & csStatus .~ Failed result)
  145     Nothing -> pure ()
  146 
  147 markAsErrored
  148   :: (MonadState (RaceState ht bt a) m)
  149   => Async (CurlResponse_ ht bt)
  150   -> SomeException
  151   -> m ()
  152 markAsErrored key result = do
  153   maybePrevious <- use $ rsClientMap . at key
  154 
  155   case maybePrevious of
  156     Just previous -> (rsClientMap . at key)
  157       ?= (previous & csStatus .~ Errored result)
  158     Nothing -> pure ()
  159 
  160 waitForOne
  161   :: (Eq a, MonadIO m, MonadState (RaceState ht bt a) m)
  162   => m (Maybe (Async (CurlResponse_ ht bt), a))
  163 waitForOne = do
  164   debug <- use rsDebug
  165   providerMap <- use rsClientMap
  166 
  167   let asyncs = _csOptions <$> onlyPending providerMap
  168 
  169   if null asyncs then pure Nothing else do
  170     winner <- liftIO $ waitAnyCatch (M.keys asyncs)
  171 
  172     case winner of
  173       (as, Right a) -> do
  174         handler <- use rsHandler
  175         check <- use rsChecker
  176         returnLast <- use rsReturnLast
  177         result <- liftIO $ handler a
  178 
  179         if check result then do
  180           markAsSuccessful as result
  181           cancelRemaining
  182 
  183           when debug . liftIO $ do
  184             TIO.putStr "[racer] Winner: "
  185             print (asyncThreadId as)
  186 
  187           pure $ Just (as, result)
  188           else do
  189             markAsFailure as result
  190 
  191             remaining <- M.keys . onlyPending <$> use rsClientMap
  192 
  193             if returnLast && null remaining
  194               then do
  195                 when debug . liftIO $ do
  196                   TIO.putStr "[racer] Reached last. Returning: "
  197                   print (asyncThreadId as)
  198 
  199                 pure $ Just (as, result)
  200               else waitForOne
  201       (as, Left ex) -> markAsErrored as ex >> waitForOne