module Hasql.Generate.Internal.Introspect
    ( ColumnInfo (..)
    , introspectColumns
    , introspectEnumLabels
    , introspectPrimaryKey
    ) where

----------------------------------------------------------------------------------------------------

import           Control.Exception         ( throwIO )
import           Control.Monad             ( return )

import           Data.Bool                 ( Bool (..), otherwise )
import qualified Data.ByteString           as BS
import qualified Data.ByteString.Char8     as BS8
import           Data.Eq                   ( (==) )
import           Data.Function             ( ($) )
import           Data.Functor              ( (<$>) )
import           Data.Maybe                ( Maybe (..), maybe )
import           Data.Semigroup            ( (<>) )
import           Data.String               ( String )

import qualified Database.PostgreSQL.LibPQ as PQ

import           Prelude
    ( Applicative (pure)
    , userError
    , (+)
    )

import           System.IO                 ( IO )

----------------------------------------------------------------------------------------------------

-- | Catalog metadata for one table column, as produced by 'introspectColumns'.
data ColumnInfo
  = ColumnInfo
      { colName       :: String
        -- ^ Column name (@pg_attribute.attname@).
      , colPgSchema   :: String
        -- ^ Schema of the column's /type/ (the @nspname@ of the type's
        -- namespace), not the schema of the table.
      , colPgType     :: String
        -- ^ Type name (@pg_type.typname@).
      , colIsEnum     :: Bool
        -- ^ 'True' when the type is an enum (@typtype@ is @e@).
      , colNotNull    :: Bool
        -- ^ 'True' when the column has a NOT NULL constraint (@attnotnull@).
      , colHasDefault :: Bool
        -- ^ 'True' when the introspection query reports a default for the
        -- column (see 'columnSQL' for exactly what counts as one).
      }

----------------------------------------------------------------------------------------------------

-- | Fetch metadata for the columns of table @table@ in schema @schema@,
-- ordered by column position (@attnum@).
--
-- Only columns whose type passes the filter in 'columnSQL' (base and enum
-- types) are returned; columns of other kinds of type are silently left out.
-- An unknown table yields @[]@ rather than an error. Both names are compared
-- verbatim with the catalog (@nspname@ / @relname@), so no identifier
-- case-folding or quote handling is applied.
--
-- Fails with an 'IOError' if libpq returns no result or the query does not
-- complete with tuples (see 'withQueryResult').
introspectColumns :: PQ.Connection -> String -> String -> IO [ColumnInfo]
introspectColumns conn schema table = do
  result <- PQ.execParams conn columnSQL queryArgs PQ.Text
  withQueryResult "column" result parseColumnRows
  where
    -- $1 = schema, $2 = table, both sent as text parameters
    queryArgs = [textParam schema, textParam table]

----------------------------------------------------------------------------------------------------

-- | Fetch the primary-key column names of table @table@ in schema @schema@,
-- ordered by their position within the primary-key index.
--
-- Returns @[]@ when the table has no primary key, or when no such table
-- exists (the query simply matches no rows). Names are compared verbatim with
-- the catalog.
--
-- Fails with an 'IOError' if libpq returns no result or the query does not
-- complete with tuples (see 'withQueryResult').
introspectPrimaryKey :: PQ.Connection -> String -> String -> IO [String]
introspectPrimaryKey conn schema table = do
  let schemaArg = textParam schema
      tableArg  = textParam table
  result <- PQ.execParams conn primaryKeySQL [schemaArg, tableArg] PQ.Text
  withQueryResult "primary key" result parsePkRows

----------------------------------------------------------------------------------------------------

-- | Fetch the labels of enum type @typeName@ in schema @schema@, in
-- declaration order (@enumsortorder@).
--
-- Returns @[]@ when no enum of that name exists in the schema. Names are
-- compared verbatim with the catalog (@nspname@ / @typname@).
--
-- Fails with an 'IOError' if libpq returns no result or the query does not
-- complete with tuples (see 'withQueryResult').
introspectEnumLabels :: PQ.Connection -> String -> String -> IO [String]
introspectEnumLabels conn schema typeName = do
  result <-
    PQ.execParams conn enumLabelSQL [textParam schema, textParam typeName] PQ.Text
  withQueryResult "enum label" result parseEnumRows

----------------------------------------------------------------------------------------------------

-- | SQL for 'introspectColumns'. Parameters: @$1@ = table schema, @$2@ =
-- table name. Returns one row per column, ordered by @attnum@. The SELECT list
-- is read by position in 'parseColumnRows':
--
--   * 0: @tn.nspname@, the schema of the column's /type/ (not of the table)
--   * 1: @a.attname@, the column name
--   * 2: @t.typname@, the type name
--   * 3: @t.typtype@, @b@ (base) or @e@ (enum)
--   * 4: @a.attnotnull@
--   * 5: has-default flag, true when @atthasdef@ is set (a default or
--        generation expression exists in @pg_attrdef@) or when the column is
--        an identity column (@attidentity@ is @a@ or @d@)
--
-- Identity columns (@GENERATED ... AS IDENTITY@) have no @pg_attrdef@ entry,
-- so @atthasdef@ alone reports them as having no default even though
-- PostgreSQL supplies their value on INSERT. @attidentity@ exists since
-- PostgreSQL 10.
--
-- Only base and enum types pass the @typtype@ filter. Columns whose type is
-- of any other kind (for example domains, composites or ranges) are omitted
-- from the result.
columnSQL :: BS.ByteString
columnSQL =
  "SELECT tn.nspname, a.attname, t.typname, t.typtype, a.attnotnull, \
  \(a.atthasdef OR a.attidentity IN ('a', 'd')) \
  \FROM pg_catalog.pg_attribute a \
  \JOIN pg_catalog.pg_class c ON c.oid = a.attrelid \
  \JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace \
  \JOIN pg_catalog.pg_type t ON t.oid = a.atttypid \
  \JOIN pg_catalog.pg_namespace tn ON tn.oid = t.typnamespace \
  \WHERE n.nspname = $1 \
  \  AND c.relname = $2 \
  \  AND a.attnum > 0 \
  \  AND NOT a.attisdropped \
  \  AND t.typtype IN ('b', 'e') \
  \  AND t.typcategory NOT IN ('C', 'P', 'X') \
  \ORDER BY a.attnum"

-- | SQL for 'introspectPrimaryKey'. Parameters: @$1@ = table schema, @$2@ =
-- table name. Returns one row per primary-key column (@attname@), ordered by
-- that column's position in the index key (@indkey@).
primaryKeySQL :: BS.ByteString
primaryKeySQL =
     "SELECT a.attname "
  <> "FROM pg_catalog.pg_index i "
  <> "JOIN pg_catalog.pg_class c ON c.oid = i.indrelid "
  <> "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
  <> "JOIN pg_catalog.pg_attribute a ON a.attrelid = i.indrelid "
  <> "  AND a.attnum = ANY(i.indkey) "
  <> "WHERE n.nspname = $1 "
  <> "  AND c.relname = $2 "
  <> "  AND i.indisprimary "
  <> "ORDER BY array_position(i.indkey, a.attnum)"

-- | SQL for 'introspectEnumLabels'. Parameters: @$1@ = schema of the enum
-- type, @$2@ = enum type name. Returns one row per label (@enumlabel@),
-- ordered by @enumsortorder@.
enumLabelSQL :: BS.ByteString
enumLabelSQL =
     "SELECT e.enumlabel "
  <> "FROM pg_catalog.pg_enum e "
  <> "JOIN pg_catalog.pg_type t ON t.oid = e.enumtypid "
  <> "JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace "
  <> "WHERE n.nspname = $1 "
  <> "  AND t.typname = $2 "
  <> "ORDER BY e.enumsortorder"

----------------------------------------------------------------------------------------------------

-- | Decode the result of 'columnSQL' into one 'ColumnInfo' per row, in row
-- order.
--
-- Result columns are read by position, so this must stay in step with the
-- SELECT list of 'columnSQL':
--
--   * 0: schema of the column's type, stored in 'colPgSchema'
--   * 1: column name, stored in 'colName'
--   * 2: type name, stored in 'colPgType'
--   * 3: @typtype@; exactly @"e"@ sets 'colIsEnum'
--   * 4: not-null flag, stored in 'colNotNull'
--   * 5: has-default flag, stored in 'colHasDefault'
--
-- A SQL @NULL@ in a text field decodes as @""@; a @NULL@ flag decodes as
-- 'False'. Text is decoded with 'BS8.unpack', i.e. one 'Char' per byte.
parseColumnRows :: PQ.Result -> IO [ColumnInfo]
parseColumnRows res = do
  rowCount <- PQ.ntuples res
  mapRows rowCount decodeRow
  where
    decodeRow row = do
      let cell = PQ.getvalue res row
          textCell col = maybe "" BS8.unpack <$> cell col
          flagCell col = maybe False parseBool <$> cell col
      typeSchema <- textCell 0
      columnName <- textCell 1
      typeName   <- textCell 2
      typeKind   <- cell 3
      notNull    <- flagCell 4
      hasDefault <- flagCell 5
      pure $ ColumnInfo
        { colName       = columnName
        , colPgSchema   = typeSchema
        , colPgType     = typeName
        , colIsEnum     = typeKind == Just "e"
        , colNotNull    = notNull
        , colHasDefault = hasDefault
        }

-- | Decode the result of 'primaryKeySQL' (a single text column) into column
-- names, in row order. A SQL @NULL@ decodes as @""@.
parsePkRows :: PQ.Result -> IO [String]
parsePkRows res = do
  count <- PQ.ntuples res
  mapRows count $ \row -> maybe "" BS8.unpack <$> PQ.getvalue res row 0

-- | Decode the result of 'enumLabelSQL' (a single text column) into enum
-- labels, in row order. A SQL @NULL@ decodes as @""@.
parseEnumRows :: PQ.Result -> IO [String]
parseEnumRows res = do
  total <- PQ.ntuples res
  mapRows total labelAt
  where
    labelAt row = do
      raw <- PQ.getvalue res row 0
      pure (maybe "" BS8.unpack raw)

-- | Interpret a PostgreSQL boolean rendered in text format. Exactly @"t"@ is
-- 'True'; anything else (including @"f"@ and the empty string) is 'False'.
parseBool :: BS.ByteString -> Bool
parseBool = (== "t")

----------------------------------------------------------------------------------------------------

-- | Build a libpq query parameter of type @text@ (OID 25) in text format.
--
-- NOTE(review): 'BS8.pack' keeps only the low 8 bits of each 'Char', so
-- characters above U+00FF are not sent as UTF-8. Confirm whether non-ASCII
-- schema, table or type names need to be supported.
textParam :: String -> Maybe (PQ.Oid, BS.ByteString, PQ.Format)
textParam str = Just (textOid, BS8.pack str, PQ.Text)
  where
    textOid = PQ.Oid 25

-- | Validate a libpq query result and hand it to a row parser.
--
-- Throws an 'IOError' (built with 'userError') in two cases:
--
--   * libpq returned no result at all;
--   * the result status is anything other than 'PQ.TuplesOk'. The server's
--     error message (or @"unknown error"@ if none) is appended to the
--     exception text.
--
-- @label@ names the introspection step in those messages.
withQueryResult :: String -> Maybe PQ.Result -> (PQ.Result -> IO a) -> IO a
withQueryResult label mResult parse = maybe noResult checked mResult
  where
    prefix = "hasql-generate: " <> label

    noResult =
      throwIO (userError (prefix <> " introspection query returned no result"))

    checked res = do
      status <- PQ.resultStatus res
      if status == PQ.TuplesOk
        then parse res
        else do
          msg <- maybe "unknown error" BS8.unpack <$> PQ.resultErrorMessage res
          throwIO (userError (prefix <> " introspection failed: " <> msg))

-- | Run @fn@ on each row index @0, 1, ..., nRows - 1@ in ascending order and
-- collect the results in the same order.
--
-- The loop stops only when the index equals @nRows@, so @nRows@ must be
-- non-negative. Row counts returned by 'PQ.ntuples' always are.
mapRows :: PQ.Row -> (PQ.Row -> IO a) -> IO [a]
mapRows nRows fn = collectFrom (PQ.Row 0)
  where
    collectFrom row@(PQ.Row n)
      | row == nRows = pure []
      | otherwise = do
          x <- fn row
          (x :) <$> collectFrom (PQ.Row (n + 1))