{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Encrypted tokens/tickets to keep state in the client side.
module Crypto.Token (
    -- * Configuration
    Config,
    defaultConfig,
    interval,
    tokenLifetime,
    threadName,

    -- * Token manager
    TokenManager,
    spawnTokenManager,
    killTokenManager,

    -- * Encryption and decryption
    encryptToken,
    decryptToken,
) where

import Control.Concurrent
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types (AEADMode (..), AuthTag (..))
import qualified Crypto.Cipher.Types as C
import Crypto.Error (maybeCryptoError, throwCryptoError)
import Crypto.Random (getRandomBytes)
import Data.Array.IO
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.IORef as I
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.Conc.Sync (labelThread)
import Network.ByteOrder

----------------------------------------------------------------

-- | Slot number of a secret in the token manager's ring of secrets.
type Index = Word16

-- | Per-secret encryption counter; XOR-ed into the IV to form the nonce.
type Counter = Word64

-- | Configuration for token manager.
--
-- Durations are expressed in seconds.
data Config = Config
    { -- | The interval in seconds at which a new secret is generated and
      -- the oldest one is removed.
      interval :: Int
    , -- | The token lifetime in seconds, that is, tokens can be decrypted
      -- in this period.
      tokenLifetime :: Int
    , -- | Label attached to the background thread that rotates secrets.
      threadName :: String
    }
    deriving (Show, Eq)

-- | Default configuration: a new secret is generated every 30 minutes
-- (1,800 seconds) and the token lifetime is 2 hours (7,200 seconds).
--
-- >>> defaultConfig
-- Config {interval = 1800, tokenLifetime = 7200, threadName = "Crypto token manager"}
defaultConfig :: Config
defaultConfig =
    Config
        { interval = 1800
        , tokenLifetime = 7200
        , threadName = "Crypto token manager"
        }

----------------------------------------------------------------

-- fixme: mask

-- | The abstract data type for token manager.
--
-- The constructor is not exported: use 'spawnTokenManager' to create one
-- and 'killTokenManager' to stop it.
data TokenManager = TokenManager
    { -- | Random mask XOR-ed with the header of every token.
      headerMask :: Header
    , -- | Returns the current secret together with its slot index.
      getEncryptSecret :: IO (Secret, Index)
    , -- | Returns the secret stored in the given slot (the index is reduced
      -- modulo the number of slots).
      getDecryptSecret :: Index -> IO Secret
    , -- | The background thread that rotates secrets.
      threadId :: ThreadId
    }

-- | Spawning a token manager.
--
-- The manager keeps a ring of @tokenLifetime `div` interval@ secrets and a
-- background thread (labelled with 'threadName') that replaces the oldest
-- secret with a fresh one every 'interval' seconds.
--
-- The number of slots is clamped to @1 .. 65535@. With 0 slots the 'Index'
-- upper bound @0 - 1@ would wrap to 65535. With 65536 or more slots it would
-- also wrap. In both cases 'readSecret' would then compute @`mod` 0@ and
-- fail on every encryption and decryption.
--
-- Throws an 'IOError' if 'interval' is not positive. Otherwise this is a
-- division by zero, or a busy loop when the delay is negative.
spawnTokenManager :: Config -> IO TokenManager
spawnTokenManager Config{..}
    | interval <= 0 =
        ioError $ userError "spawnTokenManager: interval must be positive"
    | otherwise = do
        emp <- emptySecret
        let slots = max 1 (min maxSlots (tokenLifetime `div` interval))
            lim = fromIntegral slots :: Index
        arr <- newArray (0, lim - 1) emp
        ent <- generateSecret
        writeArray arr 0 ent
        ref <- I.newIORef 0
        tid <- forkIO $ loop arr ref
        labelThread tid threadName
        msk <- newHeaderMask
        return $ TokenManager msk (readCurrentSecret arr ref) (readSecret arr) tid
  where
    -- Largest slot count whose upper bound (maxSlots - 1) plus one still
    -- fits in an 'Index' without wrapping to zero.
    maxSlots :: Int
    maxSlots = fromIntegral (maxBound :: Index)

    -- Rotate secrets forever, once per 'interval' seconds.
    loop :: IOArray Index Secret -> I.IORef Index -> IO a
    loop arr ref = do
        threadDelay (interval * 1000000)
        update arr ref
        loop arr ref

    -- Overwrite the slot after the current one (i.e. the oldest secret)
    -- with a fresh secret. The slot is written before the index is
    -- published, so readers of the current index never see the slot
    -- before it has been written.
    update :: IOArray Index Secret -> I.IORef Index -> IO ()
    update arr ref = do
        idx0 <- I.readIORef ref
        (_, n) <- getBounds arr
        let idx = (idx0 + 1) `mod` (n + 1)
        sec <- generateSecret
        writeArray arr idx sec
        I.writeIORef ref idx

-- | Killing a token manager.
--
-- Stops the background thread that rotates secrets (via 'killThread').
killTokenManager :: TokenManager -> IO ()
killTokenManager = killThread . threadId

----------------------------------------------------------------

-- | Reads the secret in slot @i@. The index is reduced modulo the number of
-- slots, so any 'Index' (including one decoded from a token) selects an
-- existing slot.
readSecret :: IOArray Index Secret -> Index -> IO Secret
readSecret secrets i = do
    size <- (+ 1) . snd <$> getBounds secrets
    readArray secrets (i `mod` size)

-- | Returns the secret of the current slot together with that slot's index.
readCurrentSecret :: IOArray Index Secret -> I.IORef Index -> IO (Secret, Index)
readCurrentSecret arr ref =
    I.readIORef ref >>= \cur ->
        (\sec -> (sec, cur)) <$> readSecret arr cur

----------------------------------------------------------------

-- | One entry of the ring of secrets.
data Secret = Secret
    { -- | Random IV; XOR-ed with the counter to build each nonce.
      secretIV :: ByteString
    , -- | AES-256 key.
      secretKey :: ByteString
    , -- | Next counter value to use when encrypting with this secret.
      secretCounter :: I.IORef Counter
    }

-- | A placeholder secret with an empty IV and key and a fresh zero counter.
-- Used to fill ring slots that have not been populated yet.
emptySecret :: IO Secret
emptySecret = do
    counterRef <- I.newIORef 0
    return
        Secret
            { secretIV = BS.empty
            , secretKey = BS.empty
            , secretCounter = counterRef
            }

-- | Creates a secret with a fresh random IV and key and a zero counter.
generateSecret :: IO Secret
generateSecret = do
    iv <- genIV
    key <- genKey
    counterRef <- I.newIORef 0
    return $ Secret iv key counterRef

-- | Generates 'keyLength' random bytes to be used as an AES-256 key.
genKey :: IO ByteString
genKey = (getRandomBytes :: Int -> IO ByteString) keyLength

-- | Generates 'ivLength' random bytes to be used as a secret's IV.
genIV :: IO ByteString
genIV = (getRandomBytes :: Int -> IO ByteString) ivLength

----------------------------------------------------------------

-- | Length in bytes of a secret's IV. 'makeNonce' writes a 64-bit counter
-- into a buffer of this size, so it must be at least 8.
ivLength :: Int
ivLength = 8

-- | AES-256 key length in bytes (256 bits).
keyLength :: Int
keyLength = 256 `div` 8

-- | Encoded size in bytes of the 'Index' field of a header.
indexLength :: Int
indexLength = 2

-- | Encoded size in bytes of the 'Counter' field of a header.
counterLength :: Int
counterLength = 8

-- | AES-GCM authentication tag length in bytes (128 bits).
tagLength :: Int
tagLength = 128 `div` 8

----------------------------------------------------------------

-- | Token header: the slot of the secret used for encryption and the counter
-- used to build the nonce. It is stored XOR-ed with the manager's header
-- mask at the front of every token.
data Header = Header
    { -- | Slot of the secret in the ring.
      headerIndex :: Index
    , -- | Counter value the nonce was built from.
      headerCounter :: Counter
    }

-- | Serialises a header as its 16-bit index followed by its 64-bit counter
-- ('indexLength' + 'counterLength' bytes).
encodeHeader :: Header -> IO ByteString
encodeHeader hdr =
    withWriteBuffer (indexLength + counterLength) $ \wbuf ->
        write16 wbuf (headerIndex hdr) >> write64 wbuf (headerCounter hdr)

-- | Parses a header in the layout written by 'encodeHeader'.
decodeHeader :: ByteString -> IO Header
decodeHeader bs = withReadBuffer bs $ \rbuf -> do
    idx <- read16 rbuf
    cnt <- read64 rbuf
    return $ Header idx cnt

-- | Creates a random header, used as the XOR mask for token headers.
newHeaderMask :: IO Header
newHeaderMask =
    (getRandomBytes (indexLength + counterLength) :: IO ByteString)
        >>= decodeHeader

----------------------------------------------------------------

-- | Field-wise XOR of two headers. XOR-ing twice with the same mask gives
-- back the original header.
xorHeader :: Header -> Header -> Header
xorHeader x y =
    Header
        (headerIndex x `xor` headerIndex y)
        (headerCounter x `xor` headerCounter y)

-- | Builds a token: the encoded header (slot index and counter), XOR-ed with
-- the manager's mask, followed by the ciphertext.
addHeader :: TokenManager -> Index -> Counter -> ByteString -> IO ByteString
addHeader mgr idx counter cipher =
    (`BS.append` cipher)
        <$> encodeHeader (headerMask mgr `xorHeader` Header idx counter)

-- | Splits a token into its unmasked slot index, its counter and the
-- remaining ciphertext. Returns 'Nothing' when the token is too short to
-- contain a header.
delHeader
    :: TokenManager -> ByteString -> IO (Maybe (Index, Counter, ByteString))
delHeader mgr token
    | BS.length token < hdrLen = return Nothing
    | otherwise = do
        masked <- decodeHeader hdrbin
        let Header idx counter = headerMask mgr `xorHeader` masked
        return $ Just (idx, counter, cipher)
  where
    hdrLen = indexLength + counterLength
    (hdrbin, cipher) = BS.splitAt hdrLen token

-- | Encrypting a target value to get a token.
--
-- The token is the masked header (slot index and counter) followed by the
-- AES-256-GCM ciphertext and its authentication tag.
encryptToken
    :: TokenManager
    -> ByteString
    -> IO ByteString
encryptToken mgr plain = do
    (secret, idx) <- getEncryptSecret mgr
    encrypt secret plain >>= uncurry (addHeader mgr idx)

-- | Encrypts with the given secret. It atomically takes the secret's next
-- counter value, builds the nonce from it, and returns that counter
-- together with the ciphertext and tag.
encrypt
    :: Secret
    -> ByteString
    -> IO (Counter, ByteString)
encrypt secret plain = do
    n <- I.atomicModifyIORef' (secretCounter secret) $ \c -> (c + 1, c)
    nonce <- makeNonce n (secretIV secret)
    return (n, aes256gcmEncrypt plain (secretKey secret) nonce)

-- | Decrypting a token to get a target value.
--
-- Returns 'Nothing' if the token is too short or fails to decrypt or
-- authenticate.
decryptToken
    :: TokenManager
    -> ByteString
    -> IO (Maybe ByteString)
decryptToken mgr token =
    delHeader mgr token >>= maybe (return Nothing) decryptWith
  where
    decryptWith (idx, counter, cipher) = do
        secret <- getDecryptSecret mgr idx
        decrypt secret counter cipher

-- | Decrypts ciphertext-plus-tag with the given secret. The nonce is
-- rebuilt from the counter carried in the token header.
decrypt
    :: Secret
    -> Counter
    -> ByteString
    -> IO (Maybe ByteString)
decrypt secret counter cipher =
    aes256gcmDecrypt cipher (secretKey secret)
        <$> makeNonce counter (secretIV secret)

-- | Builds a nonce by XOR-ing the IV with an 'ivLength'-byte buffer that
-- holds the counter as a 'Word64' in host byte order.
makeNonce :: Counter -> ByteString -> IO ByteString
makeNonce counter iv = BA.xor iv <$> counterBytes
  where
    counterBytes = BS.create ivLength $ \p -> poke (castPtr p) counter

----------------------------------------------------------------

-- | Additional authenticated data passed to AES-GCM; always empty.
constantAdditionalData :: ByteString
constantAdditionalData = ""

-- | AES-256-GCM encryption.
--
-- Returns the ciphertext followed by the 'tagLength'-byte authentication
-- tag. Throws (via 'throwCryptoError') if the cipher rejects the key or
-- the nonce.
aes256gcmEncrypt
    :: ByteString
    -- ^ Plain text
    -> ByteString
    -- ^ Key
    -> ByteString
    -- ^ Nonce
    -> ByteString
aes256gcmEncrypt plain key nonce = BS.append ctext (BA.convert tagBytes)
  where
    aes = throwCryptoError (C.cipherInit key) :: AES256
    aead = throwCryptoError (C.aeadInit AEAD_GCM aes nonce)
    (AuthTag tagBytes, ctext) =
        C.aeadSimpleEncrypt aead constantAdditionalData plain tagLength

-- | AES-256-GCM decryption of a ciphertext followed by its
-- 'tagLength'-byte authentication tag, as produced by 'aes256gcmEncrypt'.
--
-- Returns 'Nothing' in three cases:
--
-- * the input is shorter than a full tag;
-- * the cipher rejects the key or the nonce;
-- * authentication fails.
aes256gcmDecrypt
    :: ByteString
    -- ^ Cipher text followed by the authentication tag
    -> ByteString
    -- ^ Key
    -> ByteString
    -- ^ Nonce
    -> Maybe ByteString
aes256gcmDecrypt ctexttag key nonce
    -- cryptonite's 'C.aeadSimpleDecrypt' derives the expected tag length
    -- from the tag it is handed. Without this guard, a short input would be
    -- checked against a truncated tag. An input with no bytes at all would
    -- authenticate as the empty plaintext.
    | BS.length ctexttag < tagLength = Nothing
    | otherwise = do
        aes <- maybeCryptoError $ C.cipherInit key :: Maybe AES256
        aead <- maybeCryptoError $ C.aeadInit AEAD_GCM aes nonce
        let (ctext, tag) = BS.splitAt (BS.length ctexttag - tagLength) ctexttag
            authtag = AuthTag $ BA.convert tag
        C.aeadSimpleDecrypt aead constantAdditionalData ctext authtag