{-# LANGUAGE BangPatterns, ForeignFunctionInterface, CPP #-}
module ByteStringUtils
       (unsafeWithInternals, unpackPSfromUTF8, gzReadFilePS, mmapFilePS,
        gzWriteFilePS, gzWriteFilePSs, ifHeadThenTail, dropSpace,
        breakSpace, linesPS, unlinesPS, hashPS, breakFirstPS, breakLastPS,
        substrPS, readIntPS, is_funky, fromHex2PS, fromPS2Hex,
        betweenLinesPS, break_after_nth_newline, break_before_nth_newline,
        intercalate)
       where
import Prelude hiding (catch)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Internal as BI
import Data.ByteString (intercalate, uncons)
import Data.ByteString.Internal (fromForeignPtr)
import Control.Exception (catch)
import System.IO
import System.IO.Unsafe (unsafePerformIO)
import Foreign.Storable (peekElemOff, peek)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (mallocArray, peekArray, advancePtr)
import Foreign.C.Types (CInt)
import Data.Bits (rotateL)
import Data.Char (chr, ord, isSpace)
import Data.Word (Word8)
import Data.Int (Int32)
import Control.Monad (when)
import Foreign.Ptr (nullPtr)
import Foreign.ForeignPtr (ForeignPtr)
import Foreign.Ptr (plusPtr, Ptr)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.ForeignPtr (addForeignPtrFinalizer)
import Foreign.Ptr (FunPtr)
import qualified Data.ByteString.Lazy as BL
import qualified Codec.Compression.GZip as GZ
import Foreign.C.String (CString, withCString)
import System.IO.MMap (mmapFileByteString)
import System.Mem (performGC)
import System.Posix.Files (fileSize, getSymbolicLinkStatus)

-- | C hook called with a buffer pointer and a NUL-terminated label.
-- NOTE(review): presumably records the allocation for leak debugging;
-- confirm against fpstring.h.
foreign import ccall unsafe "static fpstring.h debug_alloc"
               debug_alloc :: Ptr a -> CString -> IO ()

-- | Address of the C @debug_free@ routine, usable as a 'ForeignPtr'
-- finalizer.
foreign import ccall unsafe "static fpstring.h & debug_free"
               debug_free :: FunPtr (Ptr a -> IO ())

-- | Report the buffer behind @fp@ to 'debug_alloc' under the label @n@ and
-- attach 'debug_free' as a finalizer of @fp@.
--
-- A second, catch-all equation (@debugForeignPtr _ _ = return ()@) used to
-- follow this one; it could never be selected because this equation already
-- matches every argument, so it has been dropped.  Runtime behaviour is
-- unchanged.
debugForeignPtr :: ForeignPtr a -> String -> IO ()
debugForeignPtr fp n =
  withCString n $ \cname ->
    withForeignPtr fp $ \p -> do
      debug_alloc p cname
      addForeignPtrFinalizer debug_free fp

-- | Run an action on the raw bytes of a 'B.ByteString'.
--
-- The action receives a pointer to the first byte of the string (the
-- underlying buffer pointer advanced by the string's offset) and the string's
-- length.  The pointer is only kept alive for the duration of the action.
unsafeWithInternals ::
                    B.ByteString -> (Ptr Word8 -> Int -> IO a) -> IO a
unsafeWithInternals ps action =
  withForeignPtr buffer $ \base -> action (base `plusPtr` offset) size
  where
    (buffer, offset, size) = BI.toForeignPtr ps

-- | Skip leading whitespace (as judged by 'isSpace'), then parse an 'Int'
-- with 'BC.readInt'.  Returns the number together with the unparsed
-- remainder, or 'Nothing' when no integer follows the whitespace.
readIntPS :: B.ByteString -> Maybe (Int, B.ByteString)
readIntPS input = BC.readInt (BC.dropWhile isSpace input)

-- | Decode a 'B.ByteString' into a 'String' using the C routine
-- 'utf8_to_ints'.
--
-- An empty input yields @""@ without calling into C.  Otherwise a scratch
-- array of one 'Int' per input byte is allocated, filled by 'utf8_to_ints',
-- and each resulting integer is turned into a 'Char' with 'chr'.  A negative
-- return value from the C routine is reported with @error "Bad UTF8!"@.
--
-- The scratch buffer is released on both paths; previously the error path
-- raised before reaching 'free' and leaked the buffer.
unpackPSfromUTF8 :: B.ByteString -> String
unpackPSfromUTF8 ps
  = case BI.toForeignPtr ps of
        (_, _, 0) -> ""
        (x, s, l) -> unsafePerformIO $
                       withForeignPtr x $
                         \ p ->
                           do outbuf <- mallocArray l
                              lout <- fromIntegral `fmap`
                                        utf8_to_ints outbuf (p `plusPtr` s) (fromIntegral l)
                              if lout < 0
                                then do
                                  -- Release the scratch buffer before raising.
                                  free outbuf
                                  error "Bad UTF8!"
                                else do
                                  -- peekArray copies the values out, so the
                                  -- buffer can be freed before the lazy 'map'
                                  -- is consumed.
                                  str <- map (chr . fromIntegral) `fmap` peekArray lout outbuf
                                  free outbuf
                                  return str

-- | C decoder: output array, input bytes, input length.  The caller above
-- treats a negative result as failure and a non-negative result as the
-- number of integers written.
foreign import ccall unsafe "static fpstring.h utf8_to_ints"
               utf8_to_ints :: Ptr Int -> Ptr Word8 -> CInt -> IO CInt

{-# INLINE ifHeadThenTail #-}

-- | If the string starts with the byte @c@, return the rest of the string;
-- otherwise (including for an empty string) return 'Nothing'.
ifHeadThenTail :: Word8 -> B.ByteString -> Maybe B.ByteString
ifHeadThenTail c s
  | Just (firstByte, rest) <- uncons s, firstByte == c = Just rest
  | otherwise = Nothing

-- | True exactly for the four bytes 32 (space), 9 (tab), 10 (line feed) and
-- 13 (carriage return).
isSpaceWord8 :: Word8 -> Bool
isSpaceWord8 w = w `elem` [32, 9, 10, 13]

{-# INLINE isSpaceWord8 #-}

-- | Starting at index @n@, scan the bytes at @ptr@ and return the index of
-- the first byte for which 'isSpaceWord8' is false.  The scan stops at the
-- limit @m@ (exclusive); if every byte up to there is a space, the index
-- reached (@m@, or @n@ when @n >= m@ already) is returned.
firstnonspace :: Ptr Word8 -> Int -> Int -> IO Int
firstnonspace !buf !ix !limit
  | ix < limit = peekElemOff buf ix >>= step
  | otherwise = return ix
  where
    step byte
      | isSpaceWord8 byte = firstnonspace buf (ix + 1) limit
      | otherwise = return ix

-- | Starting at index @n@, scan the bytes at @ptr@ and return the index of
-- the first byte for which 'isSpaceWord8' is true.  The scan stops at the
-- limit @m@ (exclusive); if no space is found, the index reached (@m@, or
-- @n@ when @n >= m@ already) is returned.
firstspace :: Ptr Word8 -> Int -> Int -> IO Int
firstspace !buf !ix !limit
  | ix < limit = peekElemOff buf ix >>= step
  | otherwise = return ix
  where
    step byte
      | isSpaceWord8 byte = return ix
      | otherwise = firstspace buf (ix + 1) limit

-- | Remove the leading run of bytes accepted by 'isSpaceWord8'.
--
-- The result shares the input's buffer; when the whole input is whitespace
-- (or empty) the canonical 'B.empty' is returned instead.
dropSpace :: B.ByteString -> B.ByteString
dropSpace (BI.PS buffer offset size) =
  BI.inlinePerformIO . withForeignPtr buffer $ \base -> do
    skipped <- firstnonspace (base `plusPtr` offset) 0 size
    return $! remainder skipped
  where
    remainder k
      | k == size = B.empty
      | otherwise = BI.PS buffer (offset + k) (size - k)

{-# INLINE dropSpace #-}

-- | Split a string at its first whitespace byte (per 'isSpaceWord8').
--
-- The first component is everything before that byte, the second starts at
-- it.  When the input begins with whitespace (or is empty) the first
-- component is 'B.empty'; when it contains no whitespace the second is
-- 'B.empty'.  Non-empty pieces share the input's buffer.
breakSpace :: B.ByteString -> (B.ByteString, B.ByteString)
breakSpace (BI.PS buffer offset size) =
  BI.inlinePerformIO . withForeignPtr buffer $ \base -> do
    cut <- firstspace (base `plusPtr` offset) 0 size
    return $! splitAtIndex cut
  where
    whole = BI.PS buffer offset size
    splitAtIndex k
      | k == 0 = (B.empty, whole)
      | k == size = (whole, B.empty)
      | otherwise =
          (BI.PS buffer offset k, BI.PS buffer (offset + k) (size - k))

{-# INLINE breakSpace #-}

{-# INLINE is_funky #-}

-- | Ask the C routine 'has_funky_char' about the bytes of the string and
-- report whether it answered with a non-zero value.
-- NOTE(review): what counts as a "funky" byte is decided in fpstring.h and
-- is not visible here.
is_funky :: B.ByteString -> Bool
is_funky ps =
  unsafePerformIO . withForeignPtr buffer $ \base -> do
    verdict <- has_funky_char (base `plusPtr` offset) (fromIntegral size)
    return (verdict /= 0)
  where
    (buffer, offset, size) = BI.toForeignPtr ps

-- | C predicate over a byte range (start pointer, length).
foreign import ccall unsafe "fpstring.h has_funky_char"
               has_funky_char :: Ptr Word8 -> CInt -> IO CInt

{-# INLINE hashPS #-}

-- | Hash the bytes of a string with 'hash', giving a 32-bit value.
hashPS :: B.ByteString -> Int32
hashPS ps =
  unsafePerformIO . withForeignPtr buffer $ \base ->
    hash (base `plusPtr` offset) size
  where
    (buffer, offset, size) = BI.toForeignPtr ps

-- | Fold @len@ bytes starting at @ptr@ into a 32-bit hash.
--
-- Starting from 0, each step rotates the accumulator left by 8 bits and adds
-- the next byte.  A length of 0 yields 0 without reading memory.
hash :: Ptr Word8 -> Int -> IO Int32
hash ptr len = go 0 ptr len
  where
    go :: Int32 -> Ptr Word8 -> Int -> IO Int32
    go acc _ 0 = return acc
    go acc cursor remaining = do
      byte <- peek cursor
      let acc' = fromIntegral byte + rotateL acc 8
      -- Force the accumulator each step so no chain of thunks builds up.
      acc' `seq` go acc' (cursor `advancePtr` 1) (remaining - 1)

{-# INLINE substrPS #-}

-- | Index of the first occurrence of @tok@ inside @str@, if any.
--
-- An empty @tok@ matches at index 0.  The search locates the first byte of
-- @tok@, checks whether the rest of @tok@ follows it, and otherwise retries
-- on the text after that byte, adding the skipped length to the result.
substrPS :: B.ByteString -> B.ByteString -> Maybe Int
substrPS tok str
  | B.null tok = Just 0
  | B.length tok > B.length str = Nothing
  | otherwise =
      case BC.elemIndex (BC.head tok) str of
        Nothing -> Nothing
        Just pos
          | tokTail `B.isPrefixOf` after -> Just pos
          | otherwise -> fmap (+ (pos + 1)) (substrPS tok after)
          where
            tokTail = B.tail tok
            after = B.drop (pos + 1) str

{-# INLINE breakFirstPS #-}

-- | Split around the first occurrence of the character @c@.
--
-- Returns the text before and the text after that occurrence (the character
-- itself is dropped), or 'Nothing' when @c@ does not occur.
breakFirstPS ::
             Char -> B.ByteString -> Maybe (B.ByteString, B.ByteString)
breakFirstPS c p = fmap around (BC.elemIndex c p)
  where
    around at = (B.take at p, B.drop (at + 1) p)

{-# INLINE breakLastPS #-}

-- | Split around the last occurrence of the character @c@.
--
-- Returns the text before and the text after that occurrence (the character
-- itself is dropped), or 'Nothing' when @c@ does not occur.
breakLastPS ::
            Char -> B.ByteString -> Maybe (B.ByteString, B.ByteString)
breakLastPS c p = fmap around (BC.elemIndexEnd c p)
  where
    around at = (B.take at p, B.drop (at + 1) p)

{-# INLINE linesPS #-}

-- | Split on @'\n'@.  Unlike a bare 'BC.split', an empty input produces a
-- single empty line rather than no lines at all.
linesPS :: B.ByteString -> [B.ByteString]
linesPS ps =
  case BC.split '\n' ps of
    -- 'BC.split' yields no pieces only for the empty string.
    [] -> [B.empty]
    pieces -> pieces

-- | Join lines with a single @'\n'@ between consecutive elements and no
-- trailing newline.  The empty list gives the empty string.
unlinesPS :: [B.ByteString] -> B.ByteString
unlinesPS = B.intercalate (BC.singleton '\n')

{-# INLINE unlinesPS #-}

-- | Binding to zlib's @gzopen@: file path and mode string in, opaque stream
-- handle out.  Callers in this module treat a null result as failure.
foreign import ccall unsafe "static zlib.h gzopen" c_gzopen ::
               CString -> CString -> IO (Ptr ())

-- | Binding to zlib's @gzclose@ for a handle obtained from 'c_gzopen'.
-- Declared with an @IO ()@ result, so any status the C function returns is
-- not visible to Haskell callers.
foreign import ccall unsafe "static zlib.h gzclose" c_gzclose ::
               Ptr () -> IO ()

-- | Binding to zlib's @gzread@: stream handle, destination buffer and
-- requested byte count.  Callers in this module compare the result with the
-- requested count to detect a short read.
foreign import ccall unsafe "static zlib.h gzread" c_gzread ::
               Ptr () -> Ptr Word8 -> CInt -> IO CInt

-- | Binding to zlib's @gzwrite@: stream handle, source buffer and byte
-- count.  Callers in this module compare the result with the requested
-- count to detect a short write.
foreign import ccall unsafe "static zlib.h gzwrite" c_gzwrite ::
               Ptr () -> Ptr Word8 -> CInt -> IO CInt

-- | Read a file into a strict 'B.ByteString', decompressing it when it is
-- gzip-compressed.
--
-- The first two bytes are compared with the gzip magic (@\"\\US\\139\"@, i.e.
-- 0x1f 0x8b).  Files without the magic are handed to 'mmapFilePS' unchanged.
-- For gzip files the last four bytes are read as a little-endian integer
-- (the gzip size trailer) and used as the decompression buffer size; this is
-- only a sizing hint for the decompressor.
--
-- The previous body was two alternative implementations fused together (a
-- Haskell-zlib path followed by a C @gzread@ path) and did not type-check.
-- This version keeps the Haskell-zlib path with a strict file read, and uses
-- 'withBinaryFile' so the probe handle is closed even if reading the header
-- or seeking fails.
gzReadFilePS :: FilePath -> IO B.ByteString
gzReadFilePS f = do
  gzLength <- withBinaryFile f ReadMode probe
  case gzLength of
    Nothing -> mmapFilePS f
    Just len -> do
      let decompress =
            GZ.decompressWith
              GZ.defaultDecompressParams {GZ.decompressBufferSize = len}
      -- Strict read: the file is fully consumed and closed before
      -- decompression starts.
      compressed <- B.readFile f
      return (B.concat (BL.toChunks (decompress (BL.fromChunks [compressed]))))
  where
    -- Returns the trailer length for gzip files, Nothing for anything else.
    probe :: Handle -> IO (Maybe Int)
    probe h = do
      header <- B.hGet h 2
      if header /= BC.pack "\US\139"
        then return Nothing
        else do
          hSeek h SeekFromEnd (-4)
          Just `fmap` hGetLittleEndInt h

-- | Read four characters from the handle and combine their code points as a
-- little-endian integer: the first character is the least significant
-- byte, the fourth the most significant.
hGetLittleEndInt :: Handle -> IO Int
hGetLittleEndInt h = do
  digits <- sequence (replicate 4 (ord `fmap` hGetChar h))
  -- Right fold so the first byte read ends up with weight 1.
  return (foldr (\digit higher -> digit + 256 * higher) 0 digits)

-- | Write a single strict 'B.ByteString' to a file via 'gzWriteFilePSs'.
gzWriteFilePS :: FilePath -> B.ByteString -> IO ()
gzWriteFilePS f = gzWriteFilePSs f . (: [])

-- | Gzip-compress the concatenation of the given chunks and write the result
-- to the file @f@.
--
-- The previous body was two alternative implementations fused together
-- (@GZ.compress@ followed by a C @gzopen@/@gzwrite@ loop applied as if it
-- were an argument) and did not type-check.  This version keeps the
-- Haskell-zlib path.
gzWriteFilePSs :: FilePath -> [B.ByteString] -> IO ()
gzWriteFilePSs f pss = BL.writeFile f (GZ.compress (BL.fromChunks pss))

-- | Write one strict 'B.ByteString' to an open zlib stream with
-- 'c_gzwrite'.
--
-- Empty strings are skipped without calling into C.  If the number of bytes
-- reported as written differs from the string's length the action fails.
gzWriteToGzf :: Ptr () -> B.ByteString -> IO ()
gzWriteToGzf gzf ps
  | size == 0 = return ()
  | otherwise = do
      written <- withForeignPtr buffer $ \base ->
        c_gzwrite gzf (base `plusPtr` offset) (fromIntegral size)
      when (fromIntegral written /= size) $ fail $ "problem in gzWriteToGzf"
  where
    (buffer, offset, size) = BI.toForeignPtr ps

-- | Map a file into memory as a strict 'B.ByteString'.
--
-- If the first mapping attempt raises, the file's size is looked up with
-- 'getSymbolicLinkStatus': a size of 0 yields 'B.empty'; otherwise a garbage
-- collection is forced with 'performGC' and the mapping is attempted once
-- more (an exception from the retry propagates to the caller).
--
-- A second equation (@mmapFilePS = B.readFile@) used to follow this
-- definition.  Its arity differs from the first equation, which the
-- compiler rejects, so it has been removed; the unnecessary
-- @x <- ...; return x@ wrapper is gone as well.
--
-- NOTE(review): the handler ignores which exception was raised; whether
-- that is intended for every failure kind should be confirmed.
mmapFilePS :: FilePath -> IO B.ByteString
mmapFilePS f =
  mmapFileByteString f Nothing `catch`
    (\ _ ->
       do size <- fileSize `fmap` getSymbolicLinkStatus f
          if size == 0
            then return B.empty
            else performGC >> mmapFileByteString f Nothing)

-- | C routine taking a destination buffer, a source buffer and the source
-- length.
-- NOTE(review): presumably writes two hex digits per source byte, given the
-- @2 * length@ destination allocated below; confirm against fpstring.h.
foreign import ccall unsafe "static fpstring.h conv_to_hex"
               conv_to_hex :: Ptr Word8 -> Ptr Word8 -> CInt -> IO ()

-- | Convert a string to its hexadecimal form using 'conv_to_hex'.  The
-- result is allocated with twice the length of the input.
fromPS2Hex :: B.ByteString -> B.ByteString
fromPS2Hex ps =
  BI.unsafeCreate (2 * size) $ \dst ->
    withForeignPtr buffer $ \src ->
      conv_to_hex dst (src `plusPtr` offset) (fromIntegral size)
  where
    (buffer, offset, size) = BI.toForeignPtr ps

-- | C routine taking a destination buffer, a source buffer and a count.
-- NOTE(review): the count passed below is the number of output bytes (half
-- the input length); confirm the expected meaning against fpstring.h.
foreign import ccall unsafe "static fpstring.h conv_from_hex"
               conv_from_hex :: Ptr Word8 -> Ptr Word8 -> CInt -> IO ()

-- | Convert a hexadecimal string back to raw bytes using 'conv_from_hex'.
-- The result has half the input's length, rounded down, so a trailing odd
-- character is not represented in the output.
fromHex2PS :: B.ByteString -> B.ByteString
fromHex2PS ps =
  BI.unsafeCreate outSize $ \dst ->
    withForeignPtr buffer $ \src ->
      conv_from_hex dst (src `plusPtr` offset) (fromIntegral outSize)
  where
    (buffer, offset, size) = BI.toForeignPtr ps
    outSize = size `div` 2

-- | Extract the text lying between two marker lines.
--
-- The input is split on @'\n'@.  The result is the run of bytes that starts
-- right after the first line equal to @start@ and stops right before the
-- next later line equal to @end@, including the newline that terminates each
-- enclosed line.  It is a slice of @ps@, not a copy.
--
-- Returns 'Nothing' when there is no @start@ line, when @start@ is the last
-- line, or when no @end@ line follows it.  If @end@ follows @start@
-- immediately the result is @Just@ the empty string.
--
-- The slice is located by summing line lengths rather than by subtracting
-- the internal buffer offsets reported by @toForeignPtr@.  Those offsets are
-- always 0 in newer bytestring releases, which would make the subtraction
-- yield an empty result.
betweenLinesPS ::
               B.ByteString ->
                 B.ByteString -> B.ByteString -> Maybe (B.ByteString)
betweenLinesPS start end ps =
  -- Splitting the empty string gives no lines at all, which falls through
  -- to Nothing just as a single empty line would.
  case break (== start) (BC.split '\n' ps) of
    (before, _ : body@(_ : _)) ->
      case break (== end) body of
        (enclosed, _ : _) ->
          let -- Bytes occupied by some lines, one newline each included.
              extent = sum . map ((+ 1) . B.length)
              -- Where the line following the start marker begins in ps.
              offset = extent before + B.length start + 1
          in Just (B.take (extent enclosed) (B.drop offset ps))
        _ -> Nothing
    _ -> Nothing

-- | Split a string immediately after its @n@-th newline.
--
-- The first component ends with that newline and the second holds the rest.
-- With @n == 0@ the split is at the very beginning.  'Nothing' is returned
-- when the string contains fewer than @n@ newlines, and also for negative
-- @n@.
break_after_nth_newline ::
                        Int -> B.ByteString -> Maybe (B.ByteString, B.ByteString)
break_after_nth_newline n the_ps = go n 0
  where
    -- remaining: newlines still to pass; consumed: bytes already skipped.
    go :: Int -> Int -> Maybe (B.ByteString, B.ByteString)
    go 0 consumed = Just (B.splitAt consumed the_ps)
    go remaining consumed =
      case BC.elemIndex '\n' (B.drop consumed the_ps) of
        Nothing -> Nothing
        Just i -> go (remaining - 1) (consumed + i + 1)

-- | Split a string immediately before a newline, counting newlines from
-- zero: @n == 0@ splits before the first newline, @n == 1@ before the
-- second, and so on.
--
-- The second component starts with that newline.  When the string has too
-- few newlines (or @n@ is negative) the whole string is returned as the
-- first component and the second is empty.
break_before_nth_newline ::
                         Int -> B.ByteString -> (B.ByteString, B.ByteString)
break_before_nth_newline n the_ps = go n 0
  where
    -- toSkip: newlines still to pass over; consumed: bytes already skipped.
    go :: Int -> Int -> (B.ByteString, B.ByteString)
    go toSkip consumed =
      case BC.elemIndex '\n' (B.drop consumed the_ps) of
        Nothing -> (the_ps, B.empty)
        Just i
          | toSkip == 0 -> B.splitAt (consumed + i) the_ps
          | otherwise -> go (toSkip - 1) (consumed + i + 1)