{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict            #-}
module Tokstyle.C.Linter.Cast (analyse) where

import           Control.Monad                   (unless, zipWithM_)
import           Data.Functor.Identity           (Identity)
import qualified Data.Map                        as Map
import           Language.C.Analysis.AstAnalysis (ExprSide (..), defaultMD,
                                                  tExpr)
import           Language.C.Analysis.ConstEval   (constEval, intValue)
import           Language.C.Analysis.DefTable    (lookupTag)
import           Language.C.Analysis.SemError    (typeMismatch)
import           Language.C.Analysis.SemRep      (EnumType (..),
                                                  EnumTypeRef (..),
                                                  Enumerator (..), GlobalDecls,
                                                  IntType (..), TagDef (..),
                                                  Type (..), TypeName (..),
                                                  TypeQuals (..), noTypeQuals)
import           Language.C.Analysis.TravMonad   (MonadTrav, Trav, TravT,
                                                  getDefTable, recordError,
                                                  throwTravError)
import           Language.C.Analysis.TypeUtils   (canonicalType, sameType)
import           Language.C.Data.Error           (userErr)
import           Language.C.Data.Ident           (Ident (..))
import           Language.C.Pretty               (pretty)
import           Language.C.Syntax.AST           (CConstant (..), CExpr,
                                                  CExpression (..), annotation)
import           Language.C.Syntax.Constants     (CInteger (..))
import qualified Tokstyle.C.Env                  as Env
import           Tokstyle.C.Env                  (Env)
import           Tokstyle.C.Patterns
import           Tokstyle.C.TraverseAst          (AstActions (..), astActions,
                                                  traverseAst)
import           Tokstyle.C.TravUtils            (getJust)


-- | Compare one enumerator of the cast's target enum with the enumerator at
-- the same position in the operand's enum.
--
-- Each enumerator initialiser is constant-folded with 'constEval' (using
-- 'defaultMD' and an empty substitution map). The folded value is turned
-- into an integer with 'intValue'. If that fails, the "could not determine
-- enumerator values" message is handed to 'getJust'.
--
-- When both values are known but differ, a 'typeMismatch' naming both
-- enumerators and their values is thrown with 'throwTravError'. The
-- mismatch is located at the left (target-side) enumerator expression.
sameEnum :: MonadTrav m => Type -> Type -> (Ident, CExpr) -> (Ident, CExpr) -> m ()
sameEnum leftTy rightTy (leftId, leftExpr) (rightId, rightExpr) = do
    leftVal  <- evalEnumerator leftExpr
    rightVal <- evalEnumerator rightExpr
    if leftVal == rightVal
        then pure ()
        else throwTravError $ typeMismatch
            ("invalid cast: enumerator value for `" <> describe leftId leftVal
                <> "` does not match `" <> describe rightId rightVal <> "`")
            (annotation leftExpr, leftTy)
            (annotation rightExpr, rightTy)
  where
    -- Fold an enumerator's initialiser down to its integer value.
    evalEnumerator expr = do
        folded <- constEval defaultMD Map.empty expr
        getJust "invalid cast: could not determine enumerator values" (intValue folded)

    -- Render "NAME = VALUE" for the error message.
    describe ident val = show (pretty ident) <> " = " <> show val

-- | Check that a cast between two enum types is value-preserving.
--
-- Both types are canonicalised, and their enumerator lists are fetched with
-- 'enumerators'. If the two enums have a different number of enumerators, a
-- user error is raised with 'throwTravError'. This differs from the
-- 'recordError' used in 'checkCast' for ordinary disallowed casts.
-- Otherwise the enumerators are compared pairwise, position by position,
-- with 'sameEnum'.
--
-- The third argument (the operand expression) is currently unused.
checkEnumCast :: MonadTrav m => Type -> Type -> CExpr -> m ()
checkEnumCast castTy exprTy _ = do
    castEnums <- enumerators (canonicalType castTy)
    exprEnums <- enumerators (canonicalType exprTy)
    unless (length castEnums == length exprEnums) $
        throwTravError $ userErr $
            "enum types `" <> show (pretty castTy) <> "` and `"
            <> show (pretty exprTy) <> "` have a different number of enumerators"
    -- 'throwTravError' short-circuits, so the lengths are equal here and
    -- zipWithM_ compares every enumerator.
    zipWithM_ (sameEnum castTy exprTy) castEnums exprEnums

-- | Look up the enumerators of an enum type as (name, initialiser) pairs.
-- They are returned in the order stored in the enum's definition.
--
-- Only a direct enum type reference is accepted. A user error is thrown
-- with 'throwTravError' in two cases:
--
-- * the tag is not found as a full 'EnumDef' in the definition table;
-- * the type is not an enum at all.
enumerators :: MonadTrav m => Type -> m [(Ident, CExpr)]
enumerators ty = case ty of
    DirectType (TyEnum (EnumTypeRef ref _)) _ _ -> do
        defs <- getDefTable
        case lookupTag ref defs of
            Just (Right (EnumDef (EnumType _ members _ _))) ->
                pure [ (ident, value) | Enumerator ident value _ _ <- members ]
            _ ->
                failWith $ "couldn't find enum type `" <> show (pretty ref) <> "`"
    _ ->
        failWith $ "invalid enum type `" <> show (pretty ty) <> "`"
  where
    failWith msg = throwTravError (userErr msg)


-- | Strip type qualifiers (const, volatile, ...) from a type.
--
-- Qualifiers are removed from pointer, direct, and array types. Pointee and
-- element types are stripped recursively; attributes and array sizes are
-- kept. Any other kind of type is returned unchanged.
unqual :: Type -> Type
unqual ty = case ty of
    PtrType inner _ attrs        -> PtrType (unqual inner) noTypeQuals attrs
    DirectType name _ attrs      -> DirectType name noTypeQuals attrs
    ArrayType inner size _ attrs -> ArrayType (unqual inner) size noTypeQuals attrs
    _                            -> ty

-- | Check a single cast expression and record a 'TypeMismatch' error if the
-- cast is not on the allow-list.
--
-- Arguments, in order:
--
-- * the type of the whole cast expression (the target type);
-- * the type of the operand (the source type);
-- * the operand expression itself.
--
-- Pointers to, or arrays of, @char@ / @uint8_t@ may be converted between
-- each other freely. That test uses the types as passed in, before
-- 'canonicalType' is applied.
--
-- Every other cast goes through the local @check@. Its clauses are tried top
-- to bottom: the first matching clause wins, and a clause whose guard fails
-- falls through to the next one. Reordering them changes which casts are
-- accepted.
checkCast :: MonadTrav m => Type -> Type -> CExpr -> m ()
checkCast targetTy sourceTy operand =
    if isByteBuffer targetTy && isByteBuffer sourceTy
        then pure ()
        else check (canonicalType targetTy) (canonicalType sourceTy)
  where
    isByteBuffer :: Type -> Bool
    isByteBuffer TY_char_ptr    = True
    isByteBuffer TY_char_arr    = True
    isByteBuffer TY_uint8_t_ptr = True
    isByteBuffer TY_uint8_t_arr = True
    isByteBuffer _              = False

    -- Identical (canonical) types: nothing to do.
    check to from | sameType to from = pure ()
    -- T* -> const T*: OK. Qualifiers are stripped on both sides, so the
    -- reverse direction (const T* -> T*) is also accepted here; that one is
    -- caught by clang and other compilers.
    check (PtrType toElem _ _) (PtrType fromElem _ _)
        | sameType (unqual toElem) (unqual fromElem) = pure ()
    -- T[] -> const T*: OK (array decay).
    check (PtrType toElem _ _) (ArrayType fromElem _ _ _)
        | sameType (unqual toElem) (unqual fromElem) = pure ()
    -- Cast to void: OK.
    check (DirectType TyVoid _ _) _ = pure ()
    -- void* <-> T*: OK.
    check PtrType{} TY_void_ptr = pure ()
    check TY_void_ptr PtrType{} = pure ()
    -- Literal 0 -> T*: OK.
    check PtrType{} _ | isNullPtr operand = pure ()
    -- sockaddr_storage -> any of the sockaddr_... types: OK.
    check TY_sockaddr_ptr     TY_sockaddr_storage_ptr = pure ()
    check TY_sockaddr_in_ptr  TY_sockaddr_storage_ptr = pure ()
    check TY_sockaddr_in6_ptr TY_sockaddr_storage_ptr = pure ()
    -- Numeric -> numeric: OK.
    check to from | isNumeric to && isNumeric from = pure ()
    -- Enum -> integral: OK.
    check to from | isIntegral to && isEnum from = pure ()
    -- Enum -> enum: allowed only if the enumerators match (see checkEnumCast).
    check to from | isEnum to && isEnum from = checkEnumCast to from operand
    -- Cast to `Messenger**`: NOT OK, but toxav does this.
    -- TODO(iphydf): Fix this.
    check (PtrType (PtrType (TY_typedef "Messenger") _ _) _ _) _ = pure ()
    -- Cast to `void**`: probably not OK, but toxav also does this.
    -- TODO(iphydf): Investigate.
    check (PtrType TY_void_ptr _ _) _ = pure ()
    -- Integral -> enum: actually NOT OK, but it is done a lot, so it is tolerated.
    -- TODO(iphydf): Fix these.
    check to from | isEnum to && isIntegral from = pure ()
    -- Anything else: NOT OK. Reported with the types as passed in (not
    -- canonicalised), located at the operand expression.
    check _ _ =
        let loc = (annotation operand, targetTy)
            msg = "disallowed cast from " <> show (pretty sourceTy)
                <> " to " <> show (pretty targetTy)
        in recordError $ typeMismatch msg loc loc

    -- Only the integer literal 0 counts as a null pointer here.
    isNullPtr expr = case expr of
        CConst (CIntConst (CInteger 0 _ _) _) -> True
        _                                     -> False


-- | Call contexts in which casts are not checked at all, because weird casts
-- such as @int*@ to @char*@ may happen there.
--
-- Each entry is @\"call:\"@ followed by a function name. This matches the
-- context string that 'linter' pushes while traversing a call to that
-- function.
exemptions :: [String]
exemptions = map ("call:" <>) exemptFunctions
  where
    exemptFunctions :: [String]
    exemptFunctions =
        [ "getsockopt"
        , "setsockopt"
        , "bs_list_add"
        , "bs_list_remove"
        , "bs_list_find"
        , "random_bytes"
        , "randombytes"
        ]


-- | AST actions implementing the cast linter.
--
-- * For every cast expression, the types of the whole cast and of its
--   operand are computed with 'tExpr'. They are passed to 'checkCast'
--   unless the innermost context on the 'Env' context stack is listed in
--   'exemptions'.
-- * For every call through a plain function name @f@, the context
--   @"call:f"@ is pushed while the call's children are traversed and is
--   popped afterwards.
linter :: AstActions (TravT Env Identity)
linter = astActions
    { doExpr = \node act -> case node of
        cast@(CCast _ e _) -> do
            castTy <- tExpr [] RValue cast
            exprTy <- tExpr [] RValue e
            ctx <- Env.getCtx
            unless (isExempt ctx) $
                checkCast castTy exprTy e
            act

        CCall (CVar (Ident fname _ _) _) _ _ -> do
            Env.pushCtx $ "call:" <> fname
            act
            Env.popCtx

        _ -> act
    }
  where
    -- Only the innermost context decides exemption. An empty context stack
    -- is treated as "not exempt" rather than crashing on a partial 'head'.
    isExempt :: [String] -> Bool
    isExempt (innermost:_) = innermost `elem` exemptions
    isExempt []            = False


-- | Entry point: run the cast linter ('linter') over all global
-- declarations, reporting problems through the 'Trav' error machinery.
analyse :: GlobalDecls -> Trav Env ()
analyse decls = traverseAst linter decls