View Javadoc
1   package com.reallifedeveloper.tools.test.database.dbunit;
2   
3   import java.io.FileNotFoundException;
4   import java.io.IOException;
5   import java.io.InputStream;
6   import java.io.Serializable;
7   import java.lang.reflect.Constructor;
8   import java.lang.reflect.Field;
9   import java.lang.reflect.InvocationTargetException;
10  import java.lang.reflect.Method;
11  import java.lang.reflect.ParameterizedType;
12  import java.math.BigDecimal;
13  import java.math.BigInteger;
14  import java.util.ArrayList;
15  import java.util.Date;
16  import java.util.HashSet;
17  import java.util.List;
18  import java.util.Optional;
19  import java.util.Set;
20  
21  import javax.xml.XMLConstants;
22  import javax.xml.parsers.DocumentBuilder;
23  import javax.xml.parsers.DocumentBuilderFactory;
24  import javax.xml.parsers.ParserConfigurationException;
25  
26  import org.checkerframework.checker.nullness.qual.NonNull;
27  import org.slf4j.Logger;
28  import org.slf4j.LoggerFactory;
29  import org.springframework.data.jpa.repository.JpaRepository;
30  import org.w3c.dom.Document;
31  import org.w3c.dom.Element;
32  import org.w3c.dom.NamedNodeMap;
33  import org.w3c.dom.Node;
34  import org.w3c.dom.NodeList;
35  import org.xml.sax.SAXException;
36  
37  import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
38  import jakarta.persistence.Column;
39  import jakarta.persistence.Id;
40  import jakarta.persistence.JoinColumn;
41  import jakarta.persistence.JoinTable;
42  import jakarta.persistence.Table;
43  
44  import com.reallifedeveloper.tools.test.TestUtil;
45  
46  /**
47   * A class to read a DBUnit flat XML dataset file and populate a {@code JpaRepository} using the information in the file.
48   * <p>
49   * This is useful for testing in-memory repositories using the same test cases as for real repository implementations, and also for
50   * populating in-memory repositories for testing services, without having to use a real database.
51   * <p>
52   * TODO: The current implementation only has basic support for "to many" associations (there must be a &amp;JoinTable annotation on a field,
53   * with &amp;JoinColumn annotations), and for enums (an enum must be stored as a string).
54   *
55   * @author RealLifeDeveloper
56   */
57  @SuppressWarnings("PMD")
58  @SuppressFBWarnings(value = "IMPROPER_UNICODE", justification = "equalsIgnoreCase only being used to compare table, column or field names")
59  public final class DbUnitFlatXmlReader {
60  
61      private static final Logger LOG = LoggerFactory.getLogger(DbUnitFlatXmlReader.class);
62  
63      private final DocumentBuilder documentBuilder;
64      private final Set<Class<?>> classes = new HashSet<>();
65      private final List<Object> entities = new ArrayList<>();
66  
67      /**
68       * Creates a new {@code DbUnitFlatXmlReader}.
69       */
70      public DbUnitFlatXmlReader() {
71          try {
72              DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
73              dbf.setValidating(false);
74              dbf.setNamespaceAware(true);
75              dbf.setFeature("http://xml.org/sax/features/namespaces", false);
76              dbf.setFeature("http://xml.org/sax/features/validation", false);
77              dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false);
78              dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
79              // Configuration to make the parser safe from XXE attacks while still allowing DTDs.
80              // See https://community.veracode.com/s/article/Java-Remediation-Guidance-for-XXE
81              dbf.setFeature("http://xml.org/sax/features/external-general-entities", false);
82              dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
83              dbf.setXIncludeAware(false);
84              dbf.setExpandEntityReferences(false);
85              dbf.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
86              documentBuilder = dbf.newDocumentBuilder();
87          } catch (ParserConfigurationException e) {
88              throw new IllegalStateException("Unexpected problem creating XML parser", e);
89          }
90      }
91  
92      /**
93       * Reads a DBUnit flat XML file from the named resource, populating the given repository with entities of the given type.
94       *
95       * @param resourceName   the classpath resource containing a DBUnit flat XML document
96       * @param repository     the repository to populate with the entities from the XML document
97       * @param entityType     the entity class to read
98       * @param primaryKeyType the type of primary key the entities use
99       * @param <T>            the type of entity to read
100      * @param <ID>           the type of the primary key of the entities
101      *
102      * @throws IOException  if reading the file failed
103      * @throws SAXException if parsing the file failed
104      */
105     @SuppressFBWarnings(value = "XXE_DOCUMENT", justification = "XML parser hardened as much as possible, see constructor")
106     public <T, ID extends Serializable> void read(String resourceName, JpaRepository<T, ID> repository, Class<T> entityType,
107             Class<ID> primaryKeyType) throws IOException, SAXException {
108         try (InputStream in = DbUnitFlatXmlReader.class.getResourceAsStream(resourceName)) {
109             if (in == null) {
110                 throw new FileNotFoundException(resourceName);
111             }
112             Document doc = documentBuilder.parse(in);
113             Element dataset = doc.getDocumentElement();
114             NodeList tableRows = dataset.getChildNodes();
115 
116             LOG.info("Reading from {}", resourceName.replaceAll("[\r\n]", ""));
117             for (int i = 0; i < tableRows.getLength(); i++) {
118                 Node tableRowNode = (@NonNull Node) tableRows.item(i);
119                 if (tableRowNode.getNodeType() == Node.ELEMENT_NODE) {
120                     Element tableRow = (Element) tableRowNode;
121                     String tableName = tableRow.getNodeName();
122                     if (tableName.equalsIgnoreCase(getTableName(entityType))) {
123                         handleTableRow(tableRow, entityType, primaryKeyType, repository);
124                     } else {
125                         handlePotentialJoinTable(tableRowNode, tableName);
126                     }
127                 }
128             }
129         } catch (ReflectiveOperationException | SecurityException e) {
130             throw new IllegalStateException("Unexpected problem reading XML file from '" + resourceName + "'", e);
131         }
132     }
133 
134     private <T, ID extends Serializable> void handleTableRow(Element tableRow, Class<T> entityType, Class<ID> primaryKeyType,
135             JpaRepository<T, ID> repository) throws ReflectiveOperationException {
136         T entity = createEntity(entityType);
137         NamedNodeMap attributes = (@NonNull NamedNodeMap) tableRow.getAttributes();
138         for (int j = 0; j < attributes.getLength(); j++) {
139             Node attribute = (@NonNull Node) attributes.item(j);
140             String fieldName = getFieldName(attribute.getNodeName(), entityType);
141             String attributeValue = attribute.getNodeValue();
142             setField(entity, fieldName, attributeValue, primaryKeyType);
143         }
144         entity = repository.save(entity);
145         entities.add(entity);
146         classes.add(entity.getClass());
147     }
148 
149     private void handlePotentialJoinTable(Node tableRow, String tableName)
150             throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
151         joinTableField(tableName).ifPresent(joinTableField -> {
152             joinTableField.setAccessible(true);
153             ParameterizedType parameterizedType = (ParameterizedType) joinTableField.getGenericType();
154             Class<?> targetType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
155             JoinTable joinTable = joinTableField.getAnnotation(JoinTable.class);
156             assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";
157             for (JoinColumn joinColumn : joinTable.joinColumns()) {
158                 for (JoinColumn inverseJoinColumn : joinTable.inverseJoinColumns()) {
159                     addEntityFromJoinTable(tableRow, joinTableField, targetType, joinColumn, inverseJoinColumn);
160                 }
161             }
162         });
163     }
164 
165     private Optional<Field> joinTableField(String tableName) {
166         for (Class<?> c : classes) {
167             for (Field field : c.getDeclaredFields()) {
168                 JoinTable joinTable = field.getAnnotation(JoinTable.class);
169                 if (joinTable != null && tableName.equalsIgnoreCase(joinTable.name())) {
170                     return Optional.of(field);
171                 }
172             }
173         }
174         return Optional.empty();
175     }
176 
177     private void addEntityFromJoinTable(Node tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
178             JoinColumn inverseJoinColumn) {
179         NamedNodeMap attributes = tableRow.getAttributes();
180         String lhsPrimaryKey = null;
181         String rhsPrimaryKey = null;
182         for (int j = 0; j < attributes.getLength(); j++) {
183             Node attribute = attributes.item(j);
184             if (attribute.getNodeName().equalsIgnoreCase(joinColumn.name())) {
185                 lhsPrimaryKey = attribute.getNodeValue();
186             } else if (attribute.getNodeName().equalsIgnoreCase(inverseJoinColumn.name())) {
187                 rhsPrimaryKey = attribute.getNodeValue();
188             }
189         }
190         if (lhsPrimaryKey == null || rhsPrimaryKey == null) {
191             throw new IllegalStateException(
192                     "Failed to find join table: missing attribute in DBUnit XML file: '" + joinColumn.name() + "' or '"
193                             + inverseJoinColumn.name() + "'");
194         }
195         Object lhs = findEntity(lhsPrimaryKey, joinTableField.getDeclaringClass());
196         Object rhs = findEntity(rhsPrimaryKey, targetType);
197         try {
198             Method add = joinTableField.getType().getMethod("add", Object.class);
199             add.invoke(joinTableField.get(lhs), rhs);
200         } catch (NoSuchMethodException e) {
201             throw new IllegalStateException("Method 'add' not found -- @JoinTable annotation should be on a Collection", e);
202         } catch (IllegalAccessException | InvocationTargetException e) {
203             throw new IllegalStateException("Unexpected problem", e);
204         }
205     }
206 
207     private <T> String getTableName(Class<T> entityType) {
208         Table table = entityType.getAnnotation(Table.class);
209         if (table == null) {
210             return entityType.getSimpleName();
211         } else {
212             return table.name();
213         }
214     }
215 
216     private <T> String getFieldName(String attributeName, Class<T> entityType) {
217         for (Field field : entityType.getDeclaredFields()) {
218             if (checkFieldName(attributeName, field)) {
219                 return field.getName();
220             }
221         }
222         if (entityType.getSuperclass() == null) {
223             throw new IllegalArgumentException("Cannot find any field matching attribute '" + attributeName + "' for " + entityType);
224         } else {
225             return getFieldName(attributeName, entityType.getSuperclass());
226         }
227     }
228 
229     private boolean checkFieldName(String attributeName, Field field) {
230         Column column = field.getAnnotation(Column.class);
231         if (column == null || column.name() == null) {
232             JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
233             if (joinColumn == null || joinColumn.name() == null) {
234                 return field.getName().equalsIgnoreCase(attributeName);
235             } else {
236                 return joinColumn.name().equalsIgnoreCase(attributeName);
237             }
238         } else {
239             return column.name().equalsIgnoreCase(attributeName);
240         }
241     }
242 
243     private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
244         Constructor<T> constructor = entityType.getDeclaredConstructor();
245         constructor.setAccessible(true);
246         return constructor.newInstance();
247     }
248 
249     private <T, ID> void setField(T entity, String fieldName, String attributeValue, Class<ID> primaryKeyType)
250             throws ReflectiveOperationException {
251         Field field = getField(entity, fieldName);
252         field.setAccessible(true);
253         Object fieldValue = createObjectFromString(attributeValue, field, primaryKeyType);
254         field.set(entity, fieldValue);
255     }
256 
257     private Field getField(Object entity, String fieldName) throws NoSuchFieldException {
258         Class<?> entityType = entity.getClass();
259         while (entityType != null) {
260             for (Field field : entityType.getDeclaredFields()) {
261                 if (field.getName().equalsIgnoreCase(fieldName)) {
262                     return field;
263                 }
264             }
265             entityType = entityType.getSuperclass();
266         }
267         throw new NoSuchFieldException(fieldName);
268     }
269 
270     private Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
271         Class<?> type;
272         if (field.getAnnotation(Id.class) != null) {
273             type = primaryKeyType;
274         } else {
275             type = field.getType();
276         }
277         return createObjectFromString(s, type);
278     }
279 
280     private Object createObjectFromString(String s, Class<?> type) {
281         if (type == Byte.class) {
282             return Byte.parseByte(s);
283         } else if (type == Short.class) {
284             return Short.parseShort(s);
285         } else if (type == Integer.class) {
286             return Integer.parseInt(s);
287         } else if (type == Long.class) {
288             return Long.parseLong(s);
289         } else if (type == Float.class) {
290             return Float.parseFloat(s);
291         } else if (type == Double.class) {
292             return Double.parseDouble(s);
293         } else if (type == Boolean.class) {
294             return Boolean.parseBoolean(s);
295         } else if (type == Character.class) {
296             return s.charAt(0);
297         } else if (type == String.class) {
298             return s;
299         } else if (type == Date.class) {
300             return TestUtil.parseDate(s);
301         } else if (type == BigDecimal.class) {
302             return new BigDecimal(s);
303         } else if (type == BigInteger.class) {
304             return new BigInteger(s);
305         } else {
306             return findEntity(s, type);
307         }
308     }
309 
310     @SuppressWarnings({ "rawtypes", "unchecked" })
311     private Object findEntity(String strId, Class<?> entityType) {
312         if (entityType.isEnum()) {
313             Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
314             return Enum.valueOf(enumType, strId);
315         }
316         for (Object entity : entities) {
317             if (entity.getClass().equals(entityType)) {
318                 Field idField = getIdField(entity);
319                 idField.setAccessible(true);
320                 try {
321                     Object id = idField.get(entity);
322                     if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
323                         return entity;
324                     }
325                 } catch (IllegalAccessException e) {
326                     throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
327                             e);
328                 }
329             }
330         }
331         throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
332     }
333 
334     private Field getIdField(Object entity) {
335         Class<?> entityType = entity.getClass();
336         while (entityType != null) {
337             for (Field field : entityType.getDeclaredFields()) {
338                 if (field.getDeclaredAnnotation(Id.class) != null) {
339                     return field;
340                 }
341             }
342             entityType = entityType.getSuperclass();
343         }
344         throw new IllegalStateException("Id field not found for entity " + entity);
345     }
346 
347 }