diff --git a/pmd-java/src/main/java/net/sourceforge/pmd/lang/java/typeresolution/typedefinition/JavaTypeDefinition.java b/pmd-java/src/main/java/net/sourceforge/pmd/lang/java/typeresolution/typedefinition/JavaTypeDefinition.java index ab909332ef..a104ed42c7 100644 --- a/pmd-java/src/main/java/net/sourceforge/pmd/lang/java/typeresolution/typedefinition/JavaTypeDefinition.java +++ b/pmd-java/src/main/java/net/sourceforge/pmd/lang/java/typeresolution/typedefinition/JavaTypeDefinition.java @@ -13,8 +13,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; public class JavaTypeDefinition implements TypeDefinition { // contains TypeDefs where only the clazz field is used @@ -296,4 +298,42 @@ public class JavaTypeDefinition implements TypeDefinition { public int hashCode() { return clazz.hashCode(); } + + public Set getSuperTypeSet() { + return getSuperTypeSet(new HashSet()); + } + + private Set getSuperTypeSet(Set destinationSet) { + destinationSet.add(this); + + if (this.clazz != Object.class) { + + resolveTypeDefinition(clazz.getGenericSuperclass()).getSuperTypeSet(destinationSet); + + for (Type type : clazz.getGenericInterfaces()) { + resolveTypeDefinition(type).getSuperTypeSet(destinationSet); + } + } + + return destinationSet; + } + + public Set> getErasedSuperTypeSet() { + Set> result = new HashSet<>(); + result.add(Object.class); + return getErasedSuperTypeSet(this.clazz, result); + } + + private static Set> getErasedSuperTypeSet(Class clazz, Set> destinationSet) { + if (clazz != null) { + destinationSet.add(clazz); + getErasedSuperTypeSet(clazz.getSuperclass(), destinationSet); + + for(Class interfaceType : clazz.getInterfaces()) { + getErasedSuperTypeSet(interfaceType, destinationSet); + } + } + + return destinationSet; + } } diff --git a/pmd-java/src/test/java/net/sourceforge/pmd/typeresolution/ClassTypeResolverTest.java b/pmd-java/src/test/java/net/sourceforge/pmd/typeresolution/ClassTypeResolverTest.java index 7293ab9f9c..50decbf906 100644 --- a/pmd-java/src/test/java/net/sourceforge/pmd/typeresolution/ClassTypeResolverTest.java +++ b/pmd-java/src/test/java/net/sourceforge/pmd/typeresolution/ClassTypeResolverTest.java @@ -4,6 +4,7 @@ package net.sourceforge.pmd.typeresolution; +import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; @@ -13,8 +14,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.StringReader; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.Comparator; import java.util.List; +import java.util.Set; import java.util.StringTokenizer; import net.sourceforge.pmd.typeresolution.testdata.dummytypes.JavaTypeDefinitionEquals; @@ -1499,9 +1503,36 @@ public class ClassTypeResolverTest { JavaTypeDefinition.forClass(List.class, a)); assertEquals(a, b); assertEquals(b, a); - } + @Test + public void testJavaTypeDefinitionGetSuperTypeSet() { + JavaTypeDefinition originalTypeDef = JavaTypeDefinition.forClass(List.class, + JavaTypeDefinition.forClass(Integer.class)); + Set set = originalTypeDef.getSuperTypeSet(); + + assertEquals(set.size(), 4); + assertTrue(set.contains(JavaTypeDefinition.forClass(Object.class))); + assertTrue(set.contains(originalTypeDef)); + assertTrue(set.contains(JavaTypeDefinition.forClass(Collection.class, + JavaTypeDefinition.forClass(Integer.class)))); + assertTrue(set.contains(JavaTypeDefinition.forClass(Iterable.class, + JavaTypeDefinition.forClass(Integer.class)))); + } + + @Test + public void testJavaTypeDefinitionGetErasedSuperTypeSet() { + JavaTypeDefinition originalTypeDef = JavaTypeDefinition.forClass(List.class, + JavaTypeDefinition.forClass(Integer.class)); + Set> set = originalTypeDef.getErasedSuperTypeSet(); + assertEquals(set.size(), 4); + assertTrue(set.contains(Object.class)); + assertTrue(set.contains(Collection.class)); + assertTrue(set.contains(Iterable.class)); + assertTrue(set.contains(List.class)); + } + + private Class getChildType(Node node, int childIndex) { return ((TypeNode) node.jjtGetChild(childIndex)).getType();