{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
module Web.Cookie
(
SetCookie
, setCookieName
, setCookieValue
, setCookiePath
, setCookieExpires
, setCookieMaxAge
, setCookieDomain
, setCookieHttpOnly
, setCookieSecure
, setCookieSameSite
, SameSiteOption
, sameSiteLax
, sameSiteStrict
, sameSiteNone
, parseSetCookie
, renderSetCookie
, renderSetCookieBS
, defaultSetCookie
, def
, Cookies
, parseCookies
, renderCookies
, renderCookiesBS
, CookiesText
, parseCookiesText
, renderCookiesText
, expiresFormat
, formatCookieExpires
, parseCookieExpires
) where
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import Data.Char (toLower, isDigit)
import Data.ByteString.Builder (Builder, byteString, char8, toLazyByteString)
import Data.ByteString.Builder.Extra (byteStringCopy)
#if !(MIN_VERSION_base(4,8,0))
import Data.Monoid (mempty, mappend, mconcat)
#endif
import Data.Word (Word8)
import Data.Ratio (numerator, denominator)
import Data.Time (UTCTime (UTCTime), toGregorian, fromGregorian, formatTime, parseTimeM, defaultTimeLocale)
import Data.Time.Clock (DiffTime, secondsToDiffTime)
import Control.Arrow (first, (***))
import Data.Text (Text)
import Data.Text.Encoding (encodeUtf8Builder, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Maybe (isJust)
import Data.Default.Class (Default (def))
import Control.DeepSeq (NFData (rnf))
-- | Cookie name/value pairs decoded to 'Text', as produced by
-- 'parseCookiesText' and consumed by 'renderCookiesText'.
type CookiesText = [(Text, Text)]
-- | Like 'parseCookies', but decodes every name and value as UTF-8.
--
-- Decoding is lenient ('lenientDecode'), so malformed UTF-8 in the input
-- does not raise an error.
parseCookiesText :: S.ByteString -> CookiesText
parseCookiesText bs =
    [ (decode k, decode v) | (k, v) <- parseCookies bs ]
  where
    decode = decodeUtf8With lenientDecode
-- | Render textual cookie pairs, encoding names and values as UTF-8.
-- Pairs are joined exactly as in 'renderCookies'.
renderCookiesText :: CookiesText -> Builder
renderCookiesText cookies =
    renderCookiesBuilder
        [ (encodeUtf8Builder k, encodeUtf8Builder v) | (k, v) <- cookies ]
-- | Raw byte-string cookie name/value pairs, as produced by 'parseCookies'
-- and consumed by 'renderCookies'.
type Cookies = [(S.ByteString, S.ByteString)]
-- | Parse a cookie string of the form @a=1;b=2@ into name/value pairs.
--
-- The input is split on @;@. Each piece is then split on its first @=@ (see
-- 'parseCookie'), so leading spaces are stripped from names but values are
-- kept verbatim. A trailing segment consisting only of spaces (as in
-- @a=b; @) yields an empty @(\"\", \"\")@ pair.
parseCookies :: S.ByteString -> Cookies
parseCookies input
    | S.null input = []
    | otherwise = parseCookie chunk : parseCookies rest
  where
    -- 59 is ASCII ';'
    (chunk, rest) = breakDiscard 59 input
-- | Split a single @name=value@ piece at its first @=@.
--
-- Leading spaces (0x20) are removed from the name. The value is everything
-- after the first @=@ and is empty when there is no @=@.
parseCookie :: S.ByteString -> (S.ByteString, S.ByteString)
parseCookie s = (S.dropWhile (== 32) rawKey, S.drop 1 rest)
  where
    -- 61 is ASCII '='; 'rest' still starts with it, hence the drop above.
    (rawKey, rest) = S.break (== 61) s
-- | Split at the first occurrence of byte @w@, dropping that byte.
--
-- If @w@ does not occur, the first component is the whole input and the
-- second is empty.
breakDiscard :: Word8 -> S.ByteString -> (S.ByteString, S.ByteString)
breakDiscard w = fmap (S.drop 1) . S.break (== w)
-- | A cookie name and value, each already rendered to a 'Builder'.
-- Used internally by 'renderCookiesBuilder' and 'renderCookie'.
type CookieBuilder = (Builder, Builder)
-- | Render cookies as @name=value@ pieces joined by a bare @;@ (no space).
-- An empty list renders to nothing.
renderCookiesBuilder :: [CookieBuilder] -> Builder
renderCookiesBuilder [] = mempty
renderCookiesBuilder (firstCookie : others) =
    mconcat (renderCookie firstCookie
             : [ char8 ';' `mappend` renderCookie c | c <- others ])
-- | Render one cookie as @name=value@.
renderCookie :: CookieBuilder -> Builder
renderCookie (name, val) = mconcat [name, char8 '=', val]
-- | Render raw cookie pairs as @a=1;b=2@. Names and values are emitted
-- verbatim, without quoting or escaping.
renderCookies :: Cookies -> Builder
renderCookies cookies =
    renderCookiesBuilder [ (byteString k, byteString v) | (k, v) <- cookies ]
-- | 'renderCookies' materialised as a strict 'S.ByteString'.
renderCookiesBS :: Cookies -> S.ByteString
renderCookiesBS cookies = L.toStrict (toLazyByteString (renderCookies cookies))
-- | The contents of a @Set-Cookie@ header: the cookie's name and value plus
-- its optional attributes. Build one by updating 'defaultSetCookie'; convert
-- with 'renderSetCookie' / 'parseSetCookie'.
data SetCookie = SetCookie
    { setCookieName :: S.ByteString
      -- ^ Cookie name (the text before the first @=@).
    , setCookieValue :: S.ByteString
      -- ^ Cookie value. Rendered verbatim: no quoting or escaping is applied.
    , setCookiePath :: Maybe S.ByteString
      -- ^ Rendered as @; Path=...@ when present.
    , setCookieExpires :: Maybe UTCTime
      -- ^ Rendered as @; Expires=...@ using 'formatCookieExpires'.
    , setCookieMaxAge :: Maybe DiffTime
      -- ^ Rendered as @; Max-Age=...@ in whole seconds (fraction floored).
    , setCookieDomain :: Maybe S.ByteString
      -- ^ Rendered as @; Domain=...@ when present.
    , setCookieHttpOnly :: Bool
      -- ^ When 'True', @; HttpOnly@ is rendered.
    , setCookieSecure :: Bool
      -- ^ When 'True', @; Secure@ is rendered.
    , setCookieSameSite :: Maybe SameSiteOption
      -- ^ Rendered as @; SameSite=Lax|Strict|None@; 'Nothing' omits it.
    }
    deriving (Eq, Show)
-- | Value of the @SameSite@ cookie attribute.
--
-- The constructors are not exported from this module; use 'sameSiteLax',
-- 'sameSiteStrict' or 'sameSiteNone' to obtain a value.
data SameSiteOption = Lax
                    | Strict
                    | None
    deriving (Show, Eq)
-- | All constructors are nullary, so matching one is full evaluation.
instance NFData SameSiteOption where
    rnf Lax = ()
    rnf Strict = ()
    rnf None = ()
-- | @SameSite=Lax@.
sameSiteLax :: SameSiteOption
sameSiteLax = Lax
-- | @SameSite=Strict@.
sameSiteStrict :: SameSiteOption
sameSiteStrict = Strict
-- | @SameSite=None@.
sameSiteNone :: SameSiteOption
sameSiteNone = None
-- | Strict 'S.ByteString's are fully evaluated once in WHNF, so 'seq' is
-- enough for them (and avoids needing an 'NFData' instance for
-- 'S.ByteString'). Every other field goes through its own 'rnf'.
instance NFData SetCookie where
    rnf sc =
        setCookieName sc `seq`
        setCookieValue sc `seq`
        forceMaybeBS (setCookiePath sc) `seq`
        rnf (setCookieExpires sc) `seq`
        rnf (setCookieMaxAge sc) `seq`
        forceMaybeBS (setCookieDomain sc) `seq`
        rnf (setCookieHttpOnly sc) `seq`
        rnf (setCookieSecure sc) `seq`
        rnf (setCookieSameSite sc)
      where
        forceMaybeBS :: Maybe S.ByteString -> ()
        forceMaybeBS = maybe () (`seq` ())
-- | 'def' is 'defaultSetCookie'.
instance Default SetCookie where
    def = defaultSetCookie
-- | A placeholder cookie named @name@ with value @value@ and no attributes
-- (no path, expiry, max-age or domain; HttpOnly and Secure off; no
-- SameSite). Intended to be customised with record update syntax.
defaultSetCookie :: SetCookie
defaultSetCookie = SetCookie
    { setCookieName = "name"
    , setCookieValue = "value"
    , setCookiePath = Nothing
    , setCookieExpires = Nothing
    , setCookieMaxAge = Nothing
    , setCookieDomain = Nothing
    , setCookieHttpOnly = False
    , setCookieSecure = False
    , setCookieSameSite = Nothing
    }
-- | Render a 'SetCookie' as a @Set-Cookie@ header value.
--
-- The output is @name=value@ followed by each present attribute in a fixed
-- order: Path, Expires, Max-Age, Domain, HttpOnly, Secure, SameSite. Absent
-- attributes ('Nothing' / 'False') are omitted entirely. Name, value, path
-- and domain are written verbatim.
renderSetCookie :: SetCookie -> Builder
renderSetCookie sc = mconcat
    [ byteString (setCookieName sc)
    , char8 '='
    , byteString (setCookieValue sc)
    , attr "; Path=" byteString (setCookiePath sc)
    , attr "; Expires=" (byteString . formatCookieExpires) (setCookieExpires sc)
    , attr "; Max-Age=" (byteString . formatCookieMaxAge) (setCookieMaxAge sc)
    , attr "; Domain=" byteString (setCookieDomain sc)
    , enabled "; HttpOnly" (setCookieHttpOnly sc)
    , enabled "; Secure" (setCookieSecure sc)
    , maybe mempty (byteStringCopy . sameSiteAttr) (setCookieSameSite sc)
    ]
  where
    -- Literal prefix plus rendered value, or nothing when the field is absent.
    attr prefix render =
        maybe mempty (\v -> byteStringCopy prefix `mappend` render v)
    -- Bare flag attribute, emitted only when switched on.
    enabled label isOn = if isOn then byteStringCopy label else mempty
    sameSiteAttr Lax = "; SameSite=Lax"
    sameSiteAttr Strict = "; SameSite=Strict"
    sameSiteAttr None = "; SameSite=None"
-- | 'renderSetCookie' materialised as a strict 'S.ByteString'.
renderSetCookieBS :: SetCookie -> S.ByteString
renderSetCookieBS sc = L.toStrict (toLazyByteString (renderSetCookie sc))
-- | Parse a @Set-Cookie@ header value.
--
-- The input is split on @;@ and each segment is split on its first @=@,
-- after stripping leading spaces. The first segment is the cookie's
-- name/value. Attribute names in the remaining segments are matched
-- case-insensitively. Unrecognised attributes are ignored, as are
-- unparseable Expires / Max-Age / SameSite values, which become 'Nothing'.
-- The SameSite value is also compared case-insensitively (so @samesite=lax@
-- is accepted), as RFC 6265bis requires.
--
-- Never fails: empty input yields an empty name and value and no attributes.
parseSetCookie :: S.ByteString -> SetCookie
parseSetCookie a = SetCookie
    { setCookieName = name
    , setCookieValue = value
    , setCookiePath = lookup "path" flags
    , setCookieExpires =
        lookup "expires" flags >>= parseCookieExpires
    , setCookieMaxAge =
        lookup "max-age" flags >>= parseCookieMaxAge
    , setCookieDomain = lookup "domain" flags
    , setCookieHttpOnly = isJust $ lookup "httponly" flags
    , setCookieSecure = isJust $ lookup "secure" flags
    , setCookieSameSite = lookup "samesite" flags >>= parseSameSite
    }
  where
    -- 59 is ASCII ';'
    pairs = map (parsePair . dropSpace) $ S.split 59 a
    -- Total replacement for head/tail: 'S.split' returns [] only for empty
    -- input, which we treat as an empty name and value.
    ((name, value), rawFlags) = case pairs of
        p : ps -> (p, ps)
        [] -> ((S8.empty, S8.empty), [])
    flags = map (first (S8.map toLower)) rawFlags
    -- 61 is ASCII '='
    parsePair = breakDiscard 61
    dropSpace = S.dropWhile (== 32)
    parseSameSite v = case S8.map toLower v of
        "lax" -> Just Lax
        "strict" -> Just Strict
        "none" -> Just None
        _ -> Nothing
-- | 'formatTime' / 'parseTimeM' pattern used for the Expires attribute,
-- e.g. @Wed, 21-Oct-2015 07:28:00 GMT@ (date parts joined by dashes).
-- @GMT@ is a literal suffix, and the time values involved are 'UTCTime'.
expiresFormat :: String
expiresFormat = "%a, %d-%b-%Y %X GMT"
-- | Format a 'UTCTime' for the Expires attribute using 'expiresFormat'.
formatCookieExpires :: UTCTime -> S.ByteString
formatCookieExpires t = S8.pack (formatTime defaultTimeLocale expiresFormat t)
-- | Parse the value of an Expires attribute.
--
-- Two layouts are tried, in order:
--
-- * 'expiresFormat' (dash-separated, e.g. @Wed, 21-Oct-2015 07:28:00 GMT@),
--   which is what 'formatCookieExpires' produces;
-- * the RFC 1123 layout with spaces (e.g. @Wed, 21 Oct 2015 07:28:00 GMT@),
--   which is RFC 6265's @sane-cookie-date@. Previously this was rejected.
--
-- Two-digit years are widened afterwards: 70-99 become 1970-1999 and 0-69
-- become 2000-2069. Returns 'Nothing' if neither layout matches.
parseCookieExpires :: S.ByteString -> Maybe UTCTime
parseCookieExpires bs =
    fmap fuzzYear $ maybe (parseWith rfc1123Format) Just (parseWith expiresFormat)
  where
    str = S8.unpack bs
    parseWith fmt = parseTimeM True defaultTimeLocale fmt str
    rfc1123Format :: String
    rfc1123Format = "%a, %d %b %Y %X GMT"
    fuzzYear orig@(UTCTime day diff)
        | year >= 70 && year <= 99 = shiftYear 1900
        | year >= 0 && year <= 69 = shiftYear 2000
        | otherwise = orig
      where
        (year, month, dom) = toGregorian day
        shiftYear offset = UTCTime (fromGregorian (year + offset) month dom) diff
-- | Render a Max-Age duration as a decimal number of whole seconds. Any
-- fractional part is floored, so @1.5@ gives @1@ and @-1.5@ gives @-2@.
formatCookieMaxAge :: DiffTime -> S.ByteString
formatCookieMaxAge t =
    let r = toRational t
        wholeSeconds = numerator r `div` denominator r
    in S8.pack (show wholeSeconds)
-- | Parse a Max-Age value as a whole number of seconds.
--
-- Only a non-empty run of ASCII digits is accepted. Anything else returns
-- 'Nothing', including an empty value (e.g. from a bare @Max-Age@ attribute)
-- and negative numbers.
parseCookieMaxAge :: S.ByteString -> Maybe DiffTime
parseCookieMaxAge bs
    -- The emptiness check is essential: @all isDigit ""@ is True, and
    -- @read ""@ would throw instead of returning Nothing.
    | not (S.null bs) && S8.all isDigit bs =
        Just $ secondsToDiffTime $ read (S8.unpack bs)
    | otherwise = Nothing