diff --git a/pmd-apex/src/main/java/net/sourceforge/pmd/lang/apex/rule/security/ApexSOQLInjectionRule.java b/pmd-apex/src/main/java/net/sourceforge/pmd/lang/apex/rule/security/ApexSOQLInjectionRule.java index df4d0fa59a..2851699a97 100644 --- a/pmd-apex/src/main/java/net/sourceforge/pmd/lang/apex/rule/security/ApexSOQLInjectionRule.java +++ b/pmd-apex/src/main/java/net/sourceforge/pmd/lang/apex/rule/security/ApexSOQLInjectionRule.java @@ -4,6 +4,8 @@ package net.sourceforge.pmd.lang.apex.rule.security; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -35,12 +37,12 @@ import net.sourceforge.pmd.lang.apex.rule.internal.Helper; * */ public class ApexSOQLInjectionRule extends AbstractApexRule { - private static final String DOUBLE = "double"; - private static final String LONG = "long"; - private static final String DECIMAL = "decimal"; - private static final String BOOLEAN = "boolean"; - private static final String ID = "id"; - private static final String INTEGER = "integer"; + private static final Set SAFE_VARIABLE_TYPES = Collections.unmodifiableSet( + new HashSet<>(Arrays.asList( + "double", "long", "decimal", "boolean", "id", "integer", + "sobjecttype", "schema.sobjecttype", "sobjectfield", "schema.sobjectfield" + ))); + private static final String JOIN = "join"; private static final String ESCAPE_SINGLE_QUOTES = "escapeSingleQuotes"; private static final String STRING = "String"; @@ -108,23 +110,16 @@ public class ApexSOQLInjectionRule extends AbstractApexRule { return Helper.isMethodName(m, DATABASE, QUERY) || Helper.isMethodName(m, DATABASE, COUNT_QUERY); } + private boolean isSafeVariableType(String typeName) { + return SAFE_VARIABLE_TYPES.contains(typeName.toLowerCase(Locale.ROOT)); + } + private void findSafeVariablesInSignature(ASTMethod m) { for (ASTParameter p : m.findChildrenOfType(ASTParameter.class)) { - switch (p.getType().toLowerCase(Locale.ROOT)) { - case ID: - case INTEGER: - case BOOLEAN: - case DECIMAL: - case LONG: - case DOUBLE: + if (isSafeVariableType(p.getType())) { safeVariables.add(Helper.getFQVariableName(p)); - break; - default: - break; } - } - } private void findSanitizedVariables(ApexNode node) { @@ -159,17 +154,8 @@ public class ApexSOQLInjectionRule extends AbstractApexRule { } if (node instanceof ASTVariableDeclaration) { - switch (((ASTVariableDeclaration) node).getType().toLowerCase(Locale.ROOT)) { - case INTEGER: - case ID: - case BOOLEAN: - case DECIMAL: - case LONG: - case DOUBLE: + if (isSafeVariableType(((ASTVariableDeclaration) node).getType())) { safeVariables.add(Helper.getFQVariableName(left)); - break; - default: - break; } } }