{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module GHC.Plugin.OllamaHoles (plugin) where
import Control.Monad (unless, when)
import Data.Char (isSpace)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Text.IO qualified as T
import GHC.Plugins hiding ((<>))
import GHC.Tc.Types
import GHC.Tc.Types.Constraint (Hole (..), ctLocEnv, ctLocSpan)
import GHC.Tc.Utils.Monad (getGblEnv, newTcRef)
import Ollama (GenerateOps (..))
import Ollama qualified
import Text.Read (readMaybe)
-- | Base Ollama generation request: the library defaults with the prompt and
-- model name blanked out. 'plugin' fills in both through a record update
-- before calling 'Ollama.generate'.
genOps :: Ollama.GenerateOps
genOps = Ollama.defaultGenerateOps{prompt = "", modelName = ""}
-- | Prompt sent to the model for every typed hole.
--
-- The brace-delimited markers (@{module}@, @{location}@, @{imports}@,
-- @{hole_var}@, @{hole_type}@, @{relevant_constraints}@, @{local_env}@,
-- @{global_env}@, @{candidate_fits}@ and @{numexpr}@) are filled in by
-- 'replacePlaceholders' before the prompt is submitted.
promptTemplate :: Text
promptTemplate =
    T.concat
        [ "You are a typed-hole plugin within GHC, the Glasgow Haskell Compiler.\n"
        , "You are given a hole in a Haskell program, and you need to fill it in.\n"
        , "The hole is represented by the following information:\n"
        , "{module}\n{location}\n{imports}\n{hole_var}\n{hole_type}\n{relevant_constraints}\n{local_env}\n{global_env}\n{candidate_fits}\n\n"
        , "Provide one or more Haskell expressions that could fill this hole.\n"
        , "This means coming up with an expression of the correct type that satisfies the constraints.\n"
        , "Pay special attention to the type of the hole, specifically whether it is a function.\n"
        , "Make sure you synthesize an expression that matches the type of the hole.\n"
        , "Output ONLY the raw Haskell expression(s), one per line.\n"
        , "Do not include explanations, introductions, or any surrounding text.\n"
        , "Feel free to include any other functions from the list of imports to generate more complicated expressions.\n"
        , -- Spelling fixed here: the prompt previously said "expresssions".
          "Output a maximum of {numexpr} expressions.\n"
        ]
-- | Plugin entry point: installs a hole-fit plugin that asks an Ollama model
-- for expressions that could fill each typed hole.
--
-- The plugin options (see 'parseFlags') select the model, the maximum number
-- of expressions requested, and whether diagnostics are printed. Candidate
-- filtering is left untouched; only the list of fits is rewritten.
plugin :: Plugin
plugin = defaultPlugin{holeFitPlugin = Just . holeFitter}
  where
    -- The plugin keeps no state between holes, so the reference holds unit.
    holeFitter :: [CommandLineOption] -> HoleFitPluginR
    holeFitter opts =
        HoleFitPluginR
            { hfPluginInit = newTcRef ()
            , hfPluginStop = \_ -> pure ()
            , hfPluginRun = \_ ->
                HoleFitPlugin
                    { candPlugin = \_ candidates -> pure candidates
                    , fitPlugin = suggestFits (parseFlags opts)
                    }
            }

    -- Describe the hole to the model and turn every usable line of its reply
    -- into a raw hole fit. The fits GHC already found are returned unchanged
    -- when there is no hole information or when generation fails.
    suggestFits :: Flags -> TypedHole -> [HoleFit] -> TcM [HoleFit]
    suggestFits flags hole fits = do
        -- Lazy pattern: the flags are only forced where they are first used.
        let Flags{model_name = modelId, num_expr = maxExprs, debug = verbose} = flags
        dflags <- getDynFlags
        gblEnv <- getGblEnv
        let render :: SDoc -> String
            render = showSDoc dflags
            moduleLine = "Module: " <> moduleNameString (moduleName (tcg_mod gblEnv))
            locationLine =
                "Location: " <> render (ppr (fmap (ctLocSpan . hole_loc) (th_hole hole)))
            importsLine =
                "Imports: " <> render (ppr (moduleEnvKeys (imp_mods (tcg_imports gblEnv))))
        liftIO $ do
            installed <- Ollama.list
            case installed of
                Just (Ollama.Models models) ->
                    -- A missing model aborts with an explanatory error.
                    unless (modelId `elem` map Ollama.name models) $
                        error
                            ( "--- Ollama plugin: Model "
                                <> T.unpack modelId
                                <> " not found. "
                                <> "Use `ollama pull` to download the model, or specify another model using "
                                <> "`-fplugin-opt=GHC.Plugin.OllamaHoles:model=<model_name>` ---"
                            )
                Nothing -> T.putStrLn "--- Ollama plugin: No models available.--"
            when verbose $ T.putStrLn "--- Ollama Plugin: Hole Found ---"
            case th_hole hole of
                Nothing -> pure fits
                Just h -> do
                    let lclEnv = ctLocEnv (hole_loc h)
                        substitutions :: [(Text, String)]
                        substitutions =
                            [ ("{module}", moduleLine)
                            , ("{location}", locationLine)
                            , ("{imports}", importsLine)
                            , ("{hole_var}", "Hole variable: _" <> occNameString (occName (hole_occ h)))
                            , ("{hole_type}", "Hole type: " <> render (ppr (hole_ty h)))
                            , ("{relevant_constraints}", "Relevant constraints: " <> render (ppr (th_relevant_cts hole)))
                            , ("{local_env}", "Local environment (bindings): " <> render (ppr (tcl_rdr lclEnv)))
                            , ("{global_env}", "Global environment (bindings): " <> render (ppr (tcg_binds gblEnv)))
                            , ("{candidate_fits}", "Candidate fits: " <> render (ppr fits))
                            , ("{numexpr}", show maxExprs)
                            ]
                        fullPrompt = replacePlaceholders promptTemplate substitutions
                    outcome <- Ollama.generate genOps{prompt = fullPrompt, modelName = modelId}
                    case outcome of
                        Left failure -> do
                            -- Generation failures are only reported in debug mode.
                            when verbose $
                                putStrLn ("Ollama plugin failed to generate a response.\n" <> failure)
                            pure fits
                        Right reply -> do
                            let answer = Ollama.response_ reply
                            when verbose $ do
                                T.putStrLn ("--- Ollama Plugin: Prompt ---\n" <> fullPrompt)
                                T.putStrLn ("--- Ollama Plugin: Response ---\n" <> answer)
                            -- The model's suggestions replace the fits GHC found.
                            pure [RawHoleFit (text (T.unpack ln)) | ln <- preProcess (T.lines answer)]
-- | Clean up the lines of a model response before they become hole fits.
--
-- Drops lines that are empty or whitespace-only, and lines that open or close
-- a Markdown code fence. A fence is recognised even when it is indented:
-- three backticks can never begin a Haskell expression, so this cannot drop a
-- legitimate suggestion. All remaining lines are returned unchanged and in
-- their original order.
preProcess :: [Text] -> [Text]
preProcess = filter (not . isNoise)
  where
    isNoise :: Text -> Bool
    isNoise ln =
        -- 'T.all' is True for the empty text, so blank lines are covered too.
        T.all isSpace ln || "```" `T.isPrefixOf` T.stripStart ln
-- | Settings of the plugin, as produced by 'parseFlags' from the plugin's
-- command-line options.
data Flags = Flags
    { model_name :: Text
    -- ^ Name of the Ollama model that is asked for hole fits.
    , num_expr :: Int
    -- ^ Maximum number of expressions requested in the prompt.
    , debug :: Bool
    -- ^ Whether the plugin prints its prompt, the response and failures.
    }
-- | Settings used when no plugin options are given: the
-- @gemma3:27b-it-qat@ model, at most five expressions, debug output off.
defaultFlags :: Flags
defaultFlags =
    Flags
        { debug = False
        , num_expr = 5
        , model_name = "gemma3:27b-it-qat"
        }
-- | Build the plugin settings from its command-line options, starting from
-- 'defaultFlags'.
--
-- Recognised options:
--
-- * @model=NAME@ selects the Ollama model.
-- * @debug=BOOL@ toggles diagnostic output; accepts @True@/@False@ in any
--   letter case.
-- * @n=INT@ sets the maximum number of expressions to request.
--
-- Options are applied left to right, so a later occurrence overrides an
-- earlier one. Unrecognised options and options whose value does not parse
-- are skipped: they neither abort compilation nor stop the remaining options
-- from being processed.
parseFlags :: [CommandLineOption] -> Flags
parseFlags = go defaultFlags
  where
    go :: Flags -> [CommandLineOption] -> Flags
    go flags [] = flags
    go flags (opt : rest) = go (applyOption flags (T.pack opt)) rest

    applyOption :: Flags -> Text -> Flags
    applyOption flags opt
        | Just name <- T.stripPrefix "model=" opt = flags{model_name = name}
        | Just raw <- T.stripPrefix "debug=" opt =
            maybe flags (\enabled -> flags{debug = enabled}) (parseBool raw)
        | Just raw <- T.stripPrefix "n=" opt =
            maybe flags (\count -> flags{num_expr = count}) (readMaybe (T.unpack raw))
        | otherwise = flags

    -- Try the derived 'Read' syntax first so every previously valid value
    -- still works, then fall back to a case-insensitive match.
    parseBool :: Text -> Maybe Bool
    parseBool raw =
        case readMaybe (T.unpack raw) of
            Just enabled -> Just enabled
            Nothing ->
                case T.toLower (T.strip raw) of
                    "true" -> Just True
                    "false" -> Just False
                    _ -> Nothing
-- | Fill in a template: every occurrence of a placeholder (first component of
-- a pair) is replaced by the corresponding value (second component).
--
-- The template is expanded in a single left-to-right pass, so text inserted
-- for one placeholder is never scanned again. A value that happens to contain
-- the name of another placeholder (for example program text mentioning
-- @{numexpr}@) is therefore inserted verbatim. When several placeholders
-- match at the same position, the one listed first wins. Pairs with an empty
-- placeholder are ignored.
replacePlaceholders :: Text -> [(Text, String)] -> Text
replacePlaceholders template substitutions = T.concat (expand template)
  where
    -- An empty key cannot be searched for ('T.breakOn' rejects it), so such
    -- pairs are dropped up front.
    table :: [(Text, Text)]
    table = [(key, T.pack value) | (key, value) <- substitutions, not (T.null key)]

    -- Emit the text before the next placeholder, its value, and continue
    -- with what follows. Terminates because every hit consumes a non-empty key.
    expand :: Text -> [Text]
    expand rest =
        case earliestHit rest of
            Nothing -> [rest]
            Just (before, value, after) -> before : value : expand after

    -- The placeholder occurring earliest in the given text, as
    -- (text before it, replacement value, text after it).
    earliestHit :: Text -> Maybe (Text, Text, Text)
    earliestHit rest = pick Nothing table
      where
        pick best [] = best
        pick best ((key, value) : others)
            | T.null found = pick best others
            | Just (bestBefore, _, _) <- best
            , T.length bestBefore <= T.length before =
                pick best others
            | otherwise =
                pick (Just (before, value, T.drop (T.length key) found)) others
          where
            (before, found) = T.breakOn key rest