View Javadoc
1   package com.reallifedeveloper.tools.test.database.dbunit;
2   
3   import java.io.FileNotFoundException;
4   import java.io.IOException;
5   import java.io.InputStream;
6   import java.io.Serializable;
7   import java.lang.reflect.Constructor;
8   import java.lang.reflect.Field;
9   import java.lang.reflect.InvocationTargetException;
10  import java.lang.reflect.Method;
11  import java.lang.reflect.ParameterizedType;
12  import java.math.BigDecimal;
13  import java.math.BigInteger;
14  import java.util.ArrayList;
15  import java.util.Date;
16  import java.util.HashSet;
17  import java.util.List;
18  import java.util.Optional;
19  import java.util.Set;
20  
21  import javax.xml.parsers.DocumentBuilder;
22  import javax.xml.parsers.DocumentBuilderFactory;
23  import javax.xml.parsers.ParserConfigurationException;
24  
25  import org.checkerframework.checker.nullness.qual.NonNull;
26  import org.checkerframework.checker.nullness.qual.Nullable;
27  import org.slf4j.Logger;
28  import org.slf4j.LoggerFactory;
29  import org.springframework.data.jpa.repository.JpaRepository;
30  import org.w3c.dom.Document;
31  import org.w3c.dom.Element;
32  import org.w3c.dom.NamedNodeMap;
33  import org.w3c.dom.Node;
34  import org.w3c.dom.NodeList;
35  import org.xml.sax.SAXException;
36  
37  import jakarta.persistence.Column;
38  import jakarta.persistence.Id;
39  import jakarta.persistence.JoinColumn;
40  import jakarta.persistence.JoinTable;
41  import jakarta.persistence.Table;
42  
43  import com.reallifedeveloper.tools.test.TestUtil;
44  
45  /**
46   * A class to read a DBUnit flat XML dataset file and populate a {@code JpaRepository} using the information in the file.
47   * <p>
48   * This is useful for testing in-memory repositories using the same test cases as for real repository implementations, and also for
49   * populating in-memory repositories for testing services, without having to use a real database.
50   * <p>
51   * TODO: The current implementation only has basic support for "to many" associations (there must be a &amp;JoinTable annotation on a field,
52   * with &amp;JoinColumn annotations), and for enums (an enum must be stored as a string).
53   *
54   * @author RealLifeDeveloper
55   */
56  @SuppressWarnings("PMD")
57  public final class DbUnitFlatXmlReader {
58  
59      private static final Logger LOG = LoggerFactory.getLogger(DbUnitFlatXmlReader.class);
60  
61      private final DocumentBuilder documentBuilder;
62      private final Set<Class<?>> classes = new HashSet<>();
63      private final List<Object> entities = new ArrayList<>();
64  
65      /**
66       * Creates a new {@code DbUnitFlatXmlReader}.
67       */
68      public DbUnitFlatXmlReader() {
69          try {
70              DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
71              dbf.setValidating(false);
72              dbf.setNamespaceAware(true);
73              dbf.setFeature("http://xml.org/sax/features/namespaces", false);
74              dbf.setFeature("http://xml.org/sax/features/validation", false);
75              dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false);
76              dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
77              documentBuilder = dbf.newDocumentBuilder();
78          } catch (ParserConfigurationException e) {
79              throw new IllegalStateException("Unexpected problem creating XML parser", e);
80          }
81      }
82  
83      /**
84       * Reads a DBUnit flat XML file from the named resource, populating the given repository with entities of the given type.
85       *
86       * @param resourceName   the classpath resource containing a DBUnit flat XML document
87       * @param repository     the repository to populate with the entities from the XML document
88       * @param entityType     the entity class to read
89       * @param primaryKeyType the type of primary key the entities use
90       * @param <T>            the type of entity to read
91       * @param <ID>           the type of the primary key of the entities
92       *
93       * @throws IOException  if reading the file failed
94       * @throws SAXException if parsing the file failed
95       */
96      public <T, ID extends Serializable> void read(String resourceName, JpaRepository<T, ID> repository, Class<T> entityType,
97              Class<ID> primaryKeyType) throws IOException, SAXException {
98          try (InputStream in = DbUnitFlatXmlReader.class.getResourceAsStream(resourceName)) {
99              if (in == null) {
100                 throw new FileNotFoundException(resourceName);
101             }
102             Document doc = documentBuilder.parse(in);
103             Element dataset = doc.getDocumentElement();
104             NodeList tableRows = dataset.getChildNodes();
105 
106             LOG.info("Reading from {}", resourceName);
107             for (int i = 0; i < tableRows.getLength(); i++) {
108                 Node tableRowNode = (@NonNull Node) tableRows.item(i);
109                 if (tableRowNode.getNodeType() == Node.ELEMENT_NODE) {
110                     Element tableRow = (Element) tableRowNode;
111                     String tableName = tableRow.getNodeName();
112                     if (tableName.equalsIgnoreCase(getTableName(entityType))) {
113                         handleTableRow(tableRow, entityType, primaryKeyType, repository);
114                     } else {
115                         handlePotentialJoinTable(tableRowNode, tableName);
116                     }
117                 }
118             }
119         } catch (ReflectiveOperationException | SecurityException e) {
120             throw new IllegalStateException("Unexpected problem reading XML file from '" + resourceName + "'", e);
121         }
122     }
123 
124     private <T, ID extends Serializable> void handleTableRow(Element tableRow, Class<T> entityType, Class<ID> primaryKeyType,
125             JpaRepository<T, ID> repository) throws ReflectiveOperationException {
126         T entity = createEntity(entityType);
127         NamedNodeMap attributes = (@NonNull NamedNodeMap) tableRow.getAttributes();
128         for (int j = 0; j < attributes.getLength(); j++) {
129             Node attribute = (@NonNull Node) attributes.item(j);
130             String fieldName = getFieldName(attribute.getNodeName(), entityType);
131             String attributeValue = attribute.getNodeValue();
132             setField(entity, fieldName, attributeValue, primaryKeyType);
133         }
134         entity = repository.save(entity);
135         entities.add(entity);
136         classes.add(entity.getClass());
137     }
138 
139     private void handlePotentialJoinTable(Node tableRow, String tableName)
140             throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
141         joinTableField(tableName).ifPresent(joinTableField -> {
142             joinTableField.setAccessible(true);
143             ParameterizedType parameterizedType = (ParameterizedType) joinTableField.getGenericType();
144             Class<?> targetType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
145             JoinTable joinTable = joinTableField.getAnnotation(JoinTable.class);
146             assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";
147             for (JoinColumn joinColumn : joinTable.joinColumns()) {
148                 for (JoinColumn inverseJoinColumn : joinTable.inverseJoinColumns()) {
149                     addEntityFromJoinTable(tableRow, joinTableField, targetType, joinColumn, inverseJoinColumn);
150                 }
151             }
152         });
153     }
154 
155     private Optional<Field> joinTableField(String tableName) {
156         for (Class<?> c : classes) {
157             for (Field field : c.getDeclaredFields()) {
158                 JoinTable joinTable = field.getAnnotation(JoinTable.class);
159                 if (joinTable != null && tableName.equalsIgnoreCase(joinTable.name())) {
160                     return Optional.of(field);
161                 }
162             }
163         }
164         return Optional.empty();
165     }
166 
167     private void addEntityFromJoinTable(Node tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
168             JoinColumn inverseJoinColumn) {
169         NamedNodeMap attributes = tableRow.getAttributes();
170         String lhsPrimaryKey = null;
171         String rhsPrimaryKey = null;
172         for (int j = 0; j < attributes.getLength(); j++) {
173             Node attribute = attributes.item(j);
174             if (attribute.getNodeName().equalsIgnoreCase(joinColumn.name())) {
175                 lhsPrimaryKey = attribute.getNodeValue();
176             } else if (attribute.getNodeName().equalsIgnoreCase(inverseJoinColumn.name())) {
177                 rhsPrimaryKey = attribute.getNodeValue();
178             }
179         }
180         Object lhs = findEntity(lhsPrimaryKey, joinTableField.getDeclaringClass());
181         Object rhs = findEntity(rhsPrimaryKey, targetType);
182         try {
183             Method add = joinTableField.getType().getMethod("add", Object.class);
184             add.invoke(joinTableField.get(lhs), rhs);
185         } catch (NoSuchMethodException e) {
186             throw new IllegalStateException("Method 'add' not found -- @JoinTable annotation should be on a Collection", e);
187         } catch (IllegalAccessException | InvocationTargetException e) {
188             throw new IllegalStateException("Unexpected problem", e);
189         }
190     }
191 
192     private <T> String getTableName(Class<T> entityType) {
193         Table table = entityType.getAnnotation(Table.class);
194         if (table == null) {
195             return entityType.getSimpleName();
196         } else {
197             return table.name();
198         }
199     }
200 
201     private <T> String getFieldName(String attributeName, Class<T> entityType) {
202         for (Field field : entityType.getDeclaredFields()) {
203             if (checkFieldName(attributeName, field)) {
204                 return field.getName();
205             }
206         }
207         if (entityType.getSuperclass() == null) {
208             throw new IllegalArgumentException("Cannot find any field matching attribute '" + attributeName + "' for " + entityType);
209         } else {
210             return getFieldName(attributeName, entityType.getSuperclass());
211         }
212     }
213 
214     private boolean checkFieldName(String attributeName, Field field) {
215         Column column = field.getAnnotation(Column.class);
216         if (column == null || column.name() == null) {
217             JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
218             if (joinColumn == null || joinColumn.name() == null) {
219                 return field.getName().equalsIgnoreCase(attributeName);
220             } else {
221                 return joinColumn.name().equalsIgnoreCase(attributeName);
222             }
223         } else {
224             return column.name().equalsIgnoreCase(attributeName);
225         }
226     }
227 
228     private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
229         Constructor<T> constructor = entityType.getDeclaredConstructor();
230         constructor.setAccessible(true);
231         return constructor.newInstance();
232     }
233 
234     private <T, ID> void setField(T entity, String fieldName, @Nullable String attributeValue, Class<ID> primaryKeyType)
235             throws ReflectiveOperationException {
236         Field field = getField(entity, fieldName);
237         field.setAccessible(true);
238         Object fieldValue = createObjectFromString(attributeValue, field, primaryKeyType);
239         field.set(entity, fieldValue);
240     }
241 
242     private Field getField(Object entity, String fieldName) throws NoSuchFieldException {
243         Class<?> entityType = entity.getClass();
244         while (entityType != null) {
245             for (Field field : entityType.getDeclaredFields()) {
246                 if (field.getName().equalsIgnoreCase(fieldName)) {
247                     return field;
248                 }
249             }
250             entityType = entityType.getSuperclass();
251         }
252         throw new NoSuchFieldException(fieldName);
253     }
254 
255     private Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
256         Class<?> type;
257         if (field.getAnnotation(Id.class) != null) {
258             type = primaryKeyType;
259         } else {
260             type = field.getType();
261         }
262         return createObjectFromString(s, type);
263     }
264 
265     private Object createObjectFromString(String s, Class<?> type) {
266         if (type == Byte.class) {
267             return Byte.parseByte(s);
268         } else if (type == Short.class) {
269             return Short.parseShort(s);
270         } else if (type == Integer.class) {
271             return Integer.parseInt(s);
272         } else if (type == Long.class) {
273             return Long.parseLong(s);
274         } else if (type == Float.class) {
275             return Float.parseFloat(s);
276         } else if (type == Double.class) {
277             return Double.parseDouble(s);
278         } else if (type == Boolean.class) {
279             return Boolean.parseBoolean(s);
280         } else if (type == Character.class) {
281             return s.charAt(0);
282         } else if (type == String.class) {
283             return s;
284         } else if (type == Date.class) {
285             return TestUtil.parseDate(s);
286         } else if (type == BigDecimal.class) {
287             return new BigDecimal(s);
288         } else if (type == BigInteger.class) {
289             return new BigInteger(s);
290         } else {
291             return findEntity(s, type);
292         }
293     }
294 
295     @SuppressWarnings({ "rawtypes", "unchecked" })
296     private Object findEntity(String strId, Class<?> entityType) {
297         if (entityType.isEnum()) {
298             Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
299             return Enum.valueOf(enumType, strId);
300         }
301         for (Object entity : entities) {
302             if (entity.getClass().equals(entityType)) {
303                 Field idField = getIdField(entity);
304                 idField.setAccessible(true);
305                 try {
306                     Object id = idField.get(entity);
307                     if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
308                         return entity;
309                     }
310                 } catch (IllegalAccessException e) {
311                     throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
312                             e);
313                 }
314             }
315         }
316         throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
317     }
318 
319     private Field getIdField(Object entity) {
320         Class<?> entityType = entity.getClass();
321         while (entityType != null) {
322             for (Field field : entityType.getDeclaredFields()) {
323                 if (field.getDeclaredAnnotation(Id.class) != null) {
324                     return field;
325                 }
326             }
327             entityType = entityType.getSuperclass();
328         }
329         throw new IllegalStateException("Id field not found for entity " + entity);
330     }
331 
332 }