{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
module Tokstyle.C.Linter.Cast (analyse) where
import Control.Monad (unless, zipWithM_)
import Data.Functor.Identity (Identity)
import qualified Data.Map as Map
import Language.C.Analysis.AstAnalysis (ExprSide (..), defaultMD,
tExpr)
import Language.C.Analysis.ConstEval (constEval, intValue)
import Language.C.Analysis.DefTable (lookupTag)
import Language.C.Analysis.SemError (typeMismatch)
import Language.C.Analysis.SemRep (EnumType (..),
EnumTypeRef (..),
Enumerator (..), GlobalDecls,
IntType (..), TagDef (..),
Type (..), TypeName (..),
TypeQuals (..), noTypeQuals)
import Language.C.Analysis.TravMonad (MonadTrav, Trav, TravT,
getDefTable, recordError,
throwTravError)
import Language.C.Analysis.TypeUtils (canonicalType, sameType)
import Language.C.Data.Error (userErr)
import Language.C.Data.Ident (Ident (..))
import Language.C.Pretty (pretty)
import Language.C.Syntax.AST (CConstant (..), CExpr,
CExpression (..), annotation)
import Language.C.Syntax.Constants (CInteger (..))
import qualified Tokstyle.C.Env as Env
import Tokstyle.C.Env (Env)
import Tokstyle.C.Patterns
import Tokstyle.C.TraverseAst (AstActions (..), astActions,
traverseAst)
import Tokstyle.C.TravUtils (getJust)
-- | Check that two corresponding enumerators carry the same integer value.
--
-- Each enumerator's value expression is constant-evaluated (the left one
-- first, then the right one). If either cannot be reduced to an integer,
-- the traversal is aborted with a "could not determine" error. If the two
-- values differ, the traversal is aborted with a 'typeMismatch' that points
-- at both value expressions, paired with @leftTy@ and @rightTy@.
sameEnum :: MonadTrav m => Type -> Type -> (Ident, CExpr) -> (Ident, CExpr) -> m ()
sameEnum leftTy rightTy (leftId, leftExpr) (rightId, rightExpr) = do
    leftVal <- valueOf leftExpr
    rightVal <- valueOf rightExpr
    if leftVal == rightVal
        then return ()
        else throwTravError $ typeMismatch (message leftVal rightVal)
            (annotation leftExpr, leftTy)
            (annotation rightExpr, rightTy)
  where
    -- Fold an enumerator's value expression down to an integer, or fail.
    valueOf expr = do
        folded <- constEval defaultMD Map.empty expr
        getJust failMsg (intValue folded)

    failMsg :: String
    failMsg = "invalid cast: could not determine enumerator values"

    -- Renders as "<name> = <value>".
    describe ident val = show (pretty ident) <> " = " <> show val

    message leftVal rightVal =
        "invalid cast: enumerator value for `" <> describe leftId leftVal
        <> "` does not match `" <> describe rightId rightVal <> "`"
-- | Check a cast between two enum types.
--
-- The enumerators of both types are looked up with 'enumerators' on their
-- canonical types. The two enums must declare the same number of
-- enumerators. Enumerators are then compared pairwise, in declaration order,
-- with 'sameEnum'. Any violation aborts the traversal via 'throwTravError'.
--
-- The expression argument is currently unused.
checkEnumCast :: MonadTrav m => Type -> Type -> CExpr -> m ()
checkEnumCast castTy exprTy _ = do
    castEnums <- enumerators (canonicalType castTy)
    exprEnums <- enumerators (canonicalType exprTy)
    unless (length castEnums == length exprEnums) $
        throwTravError $ userErr $
            "enum types `" <> show (pretty castTy)
            <> "` and `" <> show (pretty exprTy)
            -- Fixed wording: previously "have different a number of enumerators".
            <> "` have a different number of enumerators"
    zipWithM_ (sameEnum castTy exprTy) castEnums exprEnums
-- | Look up the enumerators of an enum type in the current definition table.
--
-- Returns each enumerator's name paired with its value expression, in the
-- order they appear in the enum definition. Aborts the traversal with a
-- user error in two cases:
--
-- * the type is not a direct @enum@ type reference;
-- * the tag does not resolve to an enum definition in the 'DefTable'.
enumerators :: MonadTrav m => Type -> m [(Ident, CExpr)]
enumerators ty = case ty of
    DirectType (TyEnum (EnumTypeRef name _)) _ _ -> do
        entry <- lookupTag name <$> getDefTable
        case entry of
            Just (Right (EnumDef (EnumType _ enums _ _))) ->
                return [ (ident, value) | Enumerator ident value _ _ <- enums ]
            _ ->
                failWith ("couldn't find enum type `" <> show (pretty name) <> "`")
    _ ->
        failWith ("invalid enum type `" <> show (pretty ty) <> "`")
  where
    failWith msg = throwTravError (userErr msg)
-- | Drop type qualifiers from a type.
--
-- Qualifiers are replaced by 'noTypeQuals' on pointer, direct and array
-- types. The function recurses into pointee and element types. Attributes
-- and array sizes are kept. Any other kind of type is returned unchanged.
unqual :: Type -> Type
unqual ty = case ty of
    PtrType pointee _ attrs       -> PtrType (unqual pointee) noTypeQuals attrs
    DirectType name _ attrs       -> DirectType name noTypeQuals attrs
    ArrayType elemTy size _ attrs -> ArrayType (unqual elemTy) size noTypeQuals attrs
    _                             -> ty
-- | Check whether a C cast from @exprTy'@ (the operand's type) to @castTy'@
-- (the cast's type) is allowed. If it is not, a 'typeMismatch' error is
-- recorded.
--
-- The char/uint8_t rule is tested on the types as given, before
-- 'canonicalType' is applied. Every other rule is tested on the canonical
-- types.
--
-- Enum-to-enum casts are delegated to 'checkEnumCast', whose failures abort
-- the traversal. All other rejections use 'recordError', so the traversal
-- continues.
--
-- @e@ is the cast's operand. It is used by the null-pointer rule and the
-- enum check, and to locate the reported error.
checkCast :: MonadTrav m => Type -> Type -> CExpr -> m ()
checkCast castTy' exprTy' e
    | isCharOrUint8T castTy' && isCharOrUint8T exprTy' = return ()
    | otherwise = check (canonicalType castTy') (canonicalType exprTy')
  where
    -- Pattern synonyms from Tokstyle.C.Patterns; by their names, presumably
    -- char/uint8_t pointers and arrays (TODO confirm).
    isCharOrUint8T :: Type -> Bool
    isCharOrUint8T ty = case ty of
        TY_char_ptr    -> True
        TY_char_arr    -> True
        TY_uint8_t_ptr -> True
        TY_uint8_t_arr -> True
        _              -> False

    -- Equations are tried top to bottom; a failing guard falls through to
    -- the next equation, so the order below is significant.
    check castTy exprTy | sameType castTy exprTy = return ()
    -- Pointer to pointer whose pointees differ only in qualifiers.
    check (PtrType castPointee _ _) (PtrType exprPointee _ _)
        | sameType (unqual castPointee) (unqual exprPointee) = return ()
    -- Array to pointer to its element type, ignoring qualifiers.
    check (PtrType castPointee _ _) (ArrayType elemTy _ _ _)
        | sameType (unqual castPointee) (unqual elemTy) = return ()
    -- Cast to void: explicitly discarding a value.
    check (DirectType TyVoid _ _) _ = return ()
    -- Any pointer to or from the void-pointer type.
    check PtrType{} TY_void_ptr = return ()
    check TY_void_ptr PtrType{} = return ()
    -- The literal 0 may be cast to any pointer type.
    check PtrType{} _ | isNullPtr e = return ()
    -- Socket address pointers from the storage variant (pattern names
    -- suggest sockaddr*, sockaddr_in*, sockaddr_in6* from sockaddr_storage*).
    check TY_sockaddr_ptr TY_sockaddr_storage_ptr = return ()
    check TY_sockaddr_in_ptr TY_sockaddr_storage_ptr = return ()
    check TY_sockaddr_in6_ptr TY_sockaddr_storage_ptr = return ()
    check castTy exprTy | isNumeric castTy && isNumeric exprTy = return ()
    check castTy exprTy | isIntegral castTy && isEnum exprTy = return ()
    check castTy exprTy | isEnum castTy && isEnum exprTy = checkEnumCast castTy exprTy e
    -- Project-specific exemption for a pointer to a pointer to the
    -- "Messenger" typedef (NOTE(review): reason not visible here).
    check (PtrType (PtrType (TY_typedef "Messenger") _ _) _ _) _ = return ()
    -- Anything may be cast to a pointer to the void-pointer type.
    check (PtrType TY_void_ptr _ _) _ = return ()
    check castTy exprTy | isEnum castTy && isIntegral exprTy = return ()
    check _ _ =
        -- Each side of the mismatch carries its own type. Previously the
        -- cast type was recorded for both sides.
        recordError $ typeMismatch
            ("disallowed cast from " <> show (pretty exprTy') <> " to " <> show (pretty castTy'))
            (annotation e, castTy')
            (annotation e, exprTy')
-- | Whether an expression is the integer constant literal @0@.
--
-- Only a bare integer constant is recognised. Any other expression,
-- including a cast such as @(void *)0@, yields 'False'.
isNullPtr :: CExpression a -> Bool
isNullPtr expr = case expr of
    CConst (CIntConst (CInteger 0 _ _) _) -> True
    _                                     -> False
-- | Contexts in which casts are not checked.
--
-- Each entry has the form @"call:<function name>"@. This is the form of the
-- context string 'linter' pushes when it enters a call to a named function.
-- A cast is skipped when the first entry of the current context list is one
-- of these.
exemptions :: [String]
exemptions =
    [ "call:getsockopt"
    , "call:setsockopt"
    , "call:bs_list_add"
    , "call:bs_list_remove"
    , "call:bs_list_find"
    , "call:random_bytes"
    , "call:randombytes"
    ]
-- | AST actions implementing the cast linter.
--
-- For every cast expression:
--
-- * the types of the whole cast and of its operand are computed with 'tExpr';
-- * the pair is passed to 'checkCast', unless the first entry of the
--   environment's context list is one of the 'exemptions'.
--
-- For every call whose callee is a plain identifier, @"call:<name>"@ is
-- pushed onto the context with 'Env.pushCtx'. It is popped again after the
-- call's children have been traversed.
--
-- The children (@act@) are always traversed.
linter :: AstActions (TravT Env Identity)
linter = astActions
    { doExpr = \node act -> case node of
        cast@(CCast _ e _) -> do
            castTy <- tExpr [] RValue cast
            exprTy <- tExpr [] RValue e
            ctx <- Env.getCtx
            unless (isExempt ctx) $
                checkCast castTy exprTy e
            act
        CCall (CVar (Ident fname _ _) _) _ _ -> do
            Env.pushCtx $ "call:" <> fname
            act
            Env.popCtx
        _ -> act
    }
  where
    -- Only the first context entry decides exemption (presumably the
    -- innermost one, if 'Env.pushCtx' prepends — TODO confirm). An empty
    -- context list is treated as not exempt, instead of crashing as 'head'
    -- would.
    isExempt (innermost:_) = innermost `elem` exemptions
    isExempt []            = False
-- | Run the cast linter over all global declarations, using 'traverseAst'
-- with the 'linter' actions.
--
-- Kept point-free on purpose: under the module's @Strict@ extension an
-- explicit argument would be forced as soon as 'analyse' is applied.
analyse :: GlobalDecls -> Trav Env ()
analyse = traverseAst linter