{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | The Ollama plugin for GHC
module GHC.Plugin.OllamaHoles (plugin) where

import Control.Monad (unless, when)
import Data.Char (isSpace)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Text.IO qualified as T
import GHC.Plugins hiding ((<>))
import GHC.Tc.Types
import GHC.Tc.Types.Constraint (Hole (..), ctLocEnv, ctLocSpan)
import GHC.Tc.Utils.Monad (getGblEnv, newTcRef)
import Ollama (GenerateOps (..))
import Ollama qualified

-- | Base request options for 'Ollama.generate', derived from
-- 'Ollama.defaultGenerateOps' with an empty prompt and an empty model name.
-- Both fields are meant to be overwritten for each request.
genOps :: Ollama.GenerateOps
genOps = Ollama.defaultGenerateOps{prompt = "", modelName = ""}

-- | Prompt sent to the model for each typed hole.
--
-- The @{...}@ markers are substituted by 'replacePlaceholders'. The hole
-- description placeholders each occupy their own line, followed by a blank
-- line and then the instructions. @{numexpr}@ is the maximum number of
-- expressions requested. Every line, including the last, ends in a newline.
promptTemplate :: Text
promptTemplate =
    T.unlines
        [ "You are a typed-hole plugin within GHC, the Glasgow Haskell Compiler."
        , "You are given a hole in a Haskell program, and you need to fill it in."
        , "The hole is represented by the following information:"
        , "{module}"
        , "{location}"
        , "{imports}"
        , "{hole_var}"
        , "{hole_type}"
        , "{relevant_constraints}"
        , "{local_env}"
        , "{global_env}"
        , "{candidate_fits}"
        , ""
        , "Provide one or more Haskell expressions that could fill this hole."
        , "This means coming up with an expression of the correct type that satisfies the constraints."
        , "Pay special attention to the type of the hole, specifically whether it is a function."
        , "Make sure you synthesize an expression that matches the type of the hole."
        , "Output ONLY the raw Haskell expression(s), one per line."
        , "Do not include explanations, introductions, or any surrounding text."
        , "Feel free to include any other functions from the list of imports to generate more complicated expressions."
        , "Output a maximum of {numexpr} expressions."
        ]

-- | Ollama plugin for GHC.
--
-- Registers a hole-fit plugin. For every typed hole it fills 'promptTemplate'
-- with a description of the hole (module, location, imports, hole name and
-- type, relevant constraints, local and global bindings, and GHC's own
-- candidate fits), sends it to the configured Ollama model, and reports each
-- line kept by 'preProcess' as a raw hole fit. If generation fails, or the
-- hole carries no 'Hole', GHC's candidate fits are returned unchanged.
--
-- Options are given with @-fplugin-opt=GHC.Plugin.OllamaHoles:<option>@ and
-- parsed by 'parseFlags'.
plugin :: Plugin
plugin =
    defaultPlugin
        { holeFitPlugin = \opts ->
            Just $
                HoleFitPluginR
                    { hfPluginInit = newTcRef ()
                    , hfPluginStop = \_ -> return ()
                    , hfPluginRun =
                        const
                            HoleFitPlugin
                                { candPlugin = \_ c -> return c -- Don't filter candidates
                                , fitPlugin = ollamaFits (parseFlags opts)
                                }
                    }
        }
  where
    -- Fits for a single hole, obtained by querying the configured model.
    ollamaFits :: Flags -> TypedHole -> [HoleFit] -> TcM [HoleFit]
    ollamaFits Flags{..} hole fits = do
        dflags <- getDynFlags
        gbl_env <- getGblEnv
        let render :: Outputable a => a -> String
            render = showSDoc dflags . ppr
            mod_name = moduleNameString $ moduleName $ tcg_mod gbl_env
            imported = moduleEnvKeys $ imp_mods $ tcg_imports gbl_env
        liftIO $ do
            ensureModelAvailable model_name
            when debug $ T.putStrLn "--- Ollama Plugin: Hole Found ---"
            case th_hole hole of
                Nothing -> return fits
                Just h -> do
                    let loc = hole_loc h
                        prompt' =
                            replacePlaceholders
                                promptTemplate
                                [ ("{module}", "Module: " <> mod_name)
                                -- Render the span itself: rendering a Maybe would prefix "Just ".
                                , ("{location}", "Location: " <> render (ctLocSpan loc))
                                , ("{imports}", "Imports: " <> render imported)
                                -- The hole's name already includes its leading underscore
                                -- ("_" or "_foo"), so none is prepended here.
                                , ("{hole_var}", "Hole variable: " <> occNameString (occName (hole_occ h)))
                                , ("{hole_type}", "Hole type: " <> render (hole_ty h))
                                , ("{relevant_constraints}", "Relevant constraints: " <> render (th_relevant_cts hole))
                                , ("{local_env}", "Local environment (bindings): " <> render (tcl_rdr (ctLocEnv loc)))
                                , ("{global_env}", "Global environment (bindings): " <> render (tcg_binds gbl_env))
                                , ("{candidate_fits}", "Candidate fits: " <> render fits)
                                , ("{numexpr}", show num_expr)
                                ]
                    res <- Ollama.generate genOps{prompt = prompt', modelName = model_name}
                    case res of
                        Left err -> do
                            when debug $
                                putStrLn $
                                    "Ollama plugin failed to generate a response.\n" <> err
                            -- Return the original fits without modification
                            return fits
                        Right rsp -> do
                            let reply = Ollama.response_ rsp
                            when debug $ do
                                T.putStrLn $ "--- Ollama Plugin: Prompt ---\n" <> prompt'
                                T.putStrLn $ "--- Ollama Plugin: Response ---\n" <> reply
                            -- Each surviving line of the reply becomes one raw hole fit.
                            return $ map (RawHoleFit . text . T.unpack) (preProcess (T.lines reply))

-- | Check that the requested model is installed in the local Ollama
-- instance, and abort with instructions when it is not. When 'Ollama.list'
-- yields 'Nothing', a notice is printed and no check is made.
ensureModelAvailable :: Text -> IO ()
ensureModelAvailable wanted = do
    available_models <- Ollama.list
    case available_models of
        Nothing -> T.putStrLn "--- Ollama plugin: No models available.--"
        Just (Ollama.Models models) ->
            unless (any (modelNameMatches wanted . Ollama.name) models) $
                error $
                    "--- Ollama plugin: Model "
                        <> T.unpack wanted
                        <> " not found. "
                        <> "Use `ollama pull` to download the model, or specify another model using "
                        <> "`-fplugin-opt=GHC.Plugin.OllamaHoles:model=<model_name>` ---"

-- | Whether an installed model name satisfies the requested one.
--
-- Ollama lists installed models with an explicit tag (e.g. @llama3:latest@),
-- while a request without a tag (@llama3@) refers to the @latest@ tag. Both
-- spellings are therefore accepted.
modelNameMatches :: Text -> Text -> Bool
modelNameMatches wanted installed =
    installed == wanted
        || (not (T.any (== ':') wanted) && installed == wanted <> ":latest")

-- | Clean up the model's reply, given as a list of lines, before each line
-- becomes a hole fit.
--
-- Drops two kinds of line:
--
-- * lines that are empty or contain only whitespace;
-- * Markdown code-fence lines, which start with three backticks, possibly
--   after leading whitespace (e.g. @```haskell@ or an indented closing fence).
--
-- All other lines are kept unchanged and in their original order.
preProcess :: [Text] -> [Text]
preProcess = filter (not . isNoise)
  where
    isNoise :: Text -> Bool
    isNoise ln =
        T.all isSpace ln
            -- Fences may be indented, e.g. when the model nests them in a list.
            || "```" `T.isPrefixOf` T.dropWhile isSpace ln

-- | Plugin settings, built from the @-fplugin-opt@ options by 'parseFlags'.
data Flags = Flags
    { model_name :: Text
    -- ^ Name of the Ollama model to query.
    , num_expr :: Int
    -- ^ Maximum number of expressions the model is asked to produce.
    , debug :: Bool
    -- ^ Whether to print the prompt and the model's response.
    }

-- | Settings used as the starting point by 'parseFlags': the
-- @gemma3:27b-it-qat@ model, at most 5 expressions, and debug output off.
defaultFlags :: Flags
defaultFlags = Flags{model_name = "gemma3:27b-it-qat", num_expr = 5, debug = False}

-- | Parse the plugin's command-line options, as passed with
-- @-fplugin-opt=GHC.Plugin.OllamaHoles:<option>@.
--
-- Options are applied left to right on top of 'defaultFlags', so a later
-- occurrence of an option overrides an earlier one:
--
-- * @model=<name>@: the Ollama model to query.
-- * @debug=<bool>@: @true@ or @false@ (case-insensitive, surrounding
--   whitespace ignored); enables debug output.
-- * @n=<count>@: the maximum number of expressions to request; must be a
--   positive integer.
--
-- Unrecognised options are skipped and parsing continues with the remaining
-- options. An invalid value for @debug@ or @n@ raises an error naming the
-- offending option when that field is used.
parseFlags :: [CommandLineOption] -> Flags
parseFlags = go defaultFlags
  where
    go :: Flags -> [CommandLineOption] -> Flags
    go flags [] = flags
    go flags (opt : opts) = go (applyOption flags (T.pack opt)) opts

    applyOption :: Flags -> Text -> Flags
    applyOption flags opt
        | Just value <- T.stripPrefix "model=" opt = flags{model_name = value}
        | Just value <- T.stripPrefix "debug=" opt = flags{debug = parseBool value}
        | Just value <- T.stripPrefix "n=" opt = flags{num_expr = parseCount value}
        | otherwise = flags -- unknown option: ignore it but keep parsing the rest
      where
        invalid :: String -> a
        invalid expected =
            error $
                "--- Ollama plugin: invalid plugin option `"
                    <> T.unpack opt
                    <> "` (expected "
                    <> expected
                    <> ") ---"

        parseBool :: Text -> Bool
        parseBool value = case T.toLower (T.strip value) of
            "true" -> True
            "false" -> False
            _ -> invalid "debug=true or debug=false"

        parseCount :: Text -> Int
        parseCount value = case reads (T.unpack value) of
            [(count, rest)] | all isSpace rest, count > 0 -> count
            _ -> invalid "n=<positive integer>"

-- | Replace every occurrence of each placeholder in a template with its value.
--
-- Substitution is a single left-to-right pass over the template. Inserted
-- values are never scanned again, so a value that happens to contain a
-- placeholder (for example pretty-printed user code containing
-- @"{numexpr}"@) is inserted verbatim. When two placeholders start at the
-- same position, the one listed first wins. Empty placeholders are ignored,
-- since they would match everywhere.
replacePlaceholders :: Text -> [(Text, String)] -> Text
replacePlaceholders template substitutions = T.concat (go template)
  where
    table :: [(Text, Text)]
    table = [(key, T.pack value) | (key, value) <- substitutions, not (T.null key)]

    -- Emit the text before the nearest placeholder, then its value, then
    -- continue with the template text that follows the placeholder.
    go :: Text -> [Text]
    go rest = case nearest rest of
        Nothing -> [rest]
        Just (before, after, key, value) ->
            before : value : go (T.drop (T.length key) after)

    -- The placeholder occurrence that starts earliest in the given text.
    nearest :: Text -> Maybe (Text, Text, Text, Text)
    nearest rest = foldr closer Nothing occurrences
      where
        occurrences =
            [ (before, after, key, value)
            | (key, value) <- table
            , let (before, after) = T.breakOn key rest
            , not (T.null after)
            ]
        -- foldr visits earlier entries last, so '<=' lets them win ties.
        closer hit Nothing = Just hit
        closer hit@(b1, _, _, _) (Just best@(b2, _, _, _))
            | T.length b1 <= T.length b2 = Just hit
            | otherwise = Just best