1 package com.reallifedeveloper.tools.test.database.dbunit;
2
3 import java.io.FileNotFoundException;
4 import java.io.IOException;
5 import java.io.InputStream;
6 import java.io.Serializable;
7 import java.lang.reflect.Constructor;
8 import java.lang.reflect.Field;
9 import java.lang.reflect.InvocationTargetException;
10 import java.lang.reflect.Method;
11 import java.lang.reflect.ParameterizedType;
12 import java.math.BigDecimal;
13 import java.math.BigInteger;
14 import java.util.ArrayList;
15 import java.util.Date;
16 import java.util.HashSet;
17 import java.util.List;
18 import java.util.Optional;
19 import java.util.Set;
20
21 import javax.xml.XMLConstants;
22 import javax.xml.parsers.DocumentBuilder;
23 import javax.xml.parsers.DocumentBuilderFactory;
24 import javax.xml.parsers.ParserConfigurationException;
25
26 import org.checkerframework.checker.nullness.qual.NonNull;
27 import org.slf4j.Logger;
28 import org.slf4j.LoggerFactory;
29 import org.springframework.data.jpa.repository.JpaRepository;
30 import org.w3c.dom.Document;
31 import org.w3c.dom.Element;
32 import org.w3c.dom.NamedNodeMap;
33 import org.w3c.dom.Node;
34 import org.w3c.dom.NodeList;
35 import org.xml.sax.SAXException;
36
37 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
38 import jakarta.persistence.Column;
39 import jakarta.persistence.Id;
40 import jakarta.persistence.JoinColumn;
41 import jakarta.persistence.JoinTable;
42 import jakarta.persistence.Table;
43
44 import com.reallifedeveloper.tools.test.TestUtil;
45
46
47
48
49
50
51
52
53
54
55
56
57 @SuppressWarnings("PMD")
58 @SuppressFBWarnings(value = "IMPROPER_UNICODE", justification = "equalsIgnoreCase only being used to compare table, column or field names")
59 public final class DbUnitFlatXmlReader {
60
61 private static final Logger LOG = LoggerFactory.getLogger(DbUnitFlatXmlReader.class);
62
63 private final DocumentBuilder documentBuilder;
64 private final Set<Class<?>> classes = new HashSet<>();
65 private final List<Object> entities = new ArrayList<>();
66
67
68
69
70 public DbUnitFlatXmlReader() {
71 try {
72 DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
73 dbf.setValidating(false);
74 dbf.setNamespaceAware(true);
75 dbf.setFeature("http://xml.org/sax/features/namespaces", false);
76 dbf.setFeature("http://xml.org/sax/features/validation", false);
77 dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false);
78 dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
79
80
81 dbf.setFeature("http://xml.org/sax/features/external-general-entities", false);
82 dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
83 dbf.setXIncludeAware(false);
84 dbf.setExpandEntityReferences(false);
85 dbf.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
86 documentBuilder = dbf.newDocumentBuilder();
87 } catch (ParserConfigurationException e) {
88 throw new IllegalStateException("Unexpected problem creating XML parser", e);
89 }
90 }
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105 @SuppressFBWarnings(value = "XXE_DOCUMENT", justification = "XML parser hardened as much as possible, see constructor")
106 public <T, ID extends Serializable> void read(String resourceName, JpaRepository<T, ID> repository, Class<T> entityType,
107 Class<ID> primaryKeyType) throws IOException, SAXException {
108 try (InputStream in = DbUnitFlatXmlReader.class.getResourceAsStream(resourceName)) {
109 if (in == null) {
110 throw new FileNotFoundException(resourceName);
111 }
112 Document doc = documentBuilder.parse(in);
113 Element dataset = doc.getDocumentElement();
114 NodeList tableRows = dataset.getChildNodes();
115
116 LOG.info("Reading from {}", resourceName.replaceAll("[\r\n]", ""));
117 for (int i = 0; i < tableRows.getLength(); i++) {
118 Node tableRowNode = (@NonNull Node) tableRows.item(i);
119 if (tableRowNode.getNodeType() == Node.ELEMENT_NODE) {
120 Element tableRow = (Element) tableRowNode;
121 String tableName = tableRow.getNodeName();
122 if (tableName.equalsIgnoreCase(getTableName(entityType))) {
123 handleTableRow(tableRow, entityType, primaryKeyType, repository);
124 } else {
125 handlePotentialJoinTable(tableRowNode, tableName);
126 }
127 }
128 }
129 } catch (ReflectiveOperationException | SecurityException e) {
130 throw new IllegalStateException("Unexpected problem reading XML file from '" + resourceName + "'", e);
131 }
132 }
133
134 private <T, ID extends Serializable> void handleTableRow(Element tableRow, Class<T> entityType, Class<ID> primaryKeyType,
135 JpaRepository<T, ID> repository) throws ReflectiveOperationException {
136 T entity = createEntity(entityType);
137 NamedNodeMap attributes = (@NonNull NamedNodeMap) tableRow.getAttributes();
138 for (int j = 0; j < attributes.getLength(); j++) {
139 Node attribute = (@NonNull Node) attributes.item(j);
140 String fieldName = getFieldName(attribute.getNodeName(), entityType);
141 String attributeValue = attribute.getNodeValue();
142 setField(entity, fieldName, attributeValue, primaryKeyType);
143 }
144 entity = repository.save(entity);
145 entities.add(entity);
146 classes.add(entity.getClass());
147 }
148
149 private void handlePotentialJoinTable(Node tableRow, String tableName)
150 throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
151 joinTableField(tableName).ifPresent(joinTableField -> {
152 joinTableField.setAccessible(true);
153 ParameterizedType parameterizedType = (ParameterizedType) joinTableField.getGenericType();
154 Class<?> targetType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
155 JoinTable joinTable = joinTableField.getAnnotation(JoinTable.class);
156 assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";
157 for (JoinColumn joinColumn : joinTable.joinColumns()) {
158 for (JoinColumn inverseJoinColumn : joinTable.inverseJoinColumns()) {
159 addEntityFromJoinTable(tableRow, joinTableField, targetType, joinColumn, inverseJoinColumn);
160 }
161 }
162 });
163 }
164
165 private Optional<Field> joinTableField(String tableName) {
166 for (Class<?> c : classes) {
167 for (Field field : c.getDeclaredFields()) {
168 JoinTable joinTable = field.getAnnotation(JoinTable.class);
169 if (joinTable != null && tableName.equalsIgnoreCase(joinTable.name())) {
170 return Optional.of(field);
171 }
172 }
173 }
174 return Optional.empty();
175 }
176
177 private void addEntityFromJoinTable(Node tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
178 JoinColumn inverseJoinColumn) {
179 NamedNodeMap attributes = tableRow.getAttributes();
180 String lhsPrimaryKey = null;
181 String rhsPrimaryKey = null;
182 for (int j = 0; j < attributes.getLength(); j++) {
183 Node attribute = attributes.item(j);
184 if (attribute.getNodeName().equalsIgnoreCase(joinColumn.name())) {
185 lhsPrimaryKey = attribute.getNodeValue();
186 } else if (attribute.getNodeName().equalsIgnoreCase(inverseJoinColumn.name())) {
187 rhsPrimaryKey = attribute.getNodeValue();
188 }
189 }
190 if (lhsPrimaryKey == null || rhsPrimaryKey == null) {
191 throw new IllegalStateException(
192 "Failed to find join table: missing attribute in DBUnit XML file: '" + joinColumn.name() + "' or '"
193 + inverseJoinColumn.name() + "'");
194 }
195 Object lhs = findEntity(lhsPrimaryKey, joinTableField.getDeclaringClass());
196 Object rhs = findEntity(rhsPrimaryKey, targetType);
197 try {
198 Method add = joinTableField.getType().getMethod("add", Object.class);
199 add.invoke(joinTableField.get(lhs), rhs);
200 } catch (NoSuchMethodException e) {
201 throw new IllegalStateException("Method 'add' not found -- @JoinTable annotation should be on a Collection", e);
202 } catch (IllegalAccessException | InvocationTargetException e) {
203 throw new IllegalStateException("Unexpected problem", e);
204 }
205 }
206
207 private <T> String getTableName(Class<T> entityType) {
208 Table table = entityType.getAnnotation(Table.class);
209 if (table == null) {
210 return entityType.getSimpleName();
211 } else {
212 return table.name();
213 }
214 }
215
216 private <T> String getFieldName(String attributeName, Class<T> entityType) {
217 for (Field field : entityType.getDeclaredFields()) {
218 if (checkFieldName(attributeName, field)) {
219 return field.getName();
220 }
221 }
222 if (entityType.getSuperclass() == null) {
223 throw new IllegalArgumentException("Cannot find any field matching attribute '" + attributeName + "' for " + entityType);
224 } else {
225 return getFieldName(attributeName, entityType.getSuperclass());
226 }
227 }
228
229 private boolean checkFieldName(String attributeName, Field field) {
230 Column column = field.getAnnotation(Column.class);
231 if (column == null || column.name() == null) {
232 JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
233 if (joinColumn == null || joinColumn.name() == null) {
234 return field.getName().equalsIgnoreCase(attributeName);
235 } else {
236 return joinColumn.name().equalsIgnoreCase(attributeName);
237 }
238 } else {
239 return column.name().equalsIgnoreCase(attributeName);
240 }
241 }
242
243 private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
244 Constructor<T> constructor = entityType.getDeclaredConstructor();
245 constructor.setAccessible(true);
246 return constructor.newInstance();
247 }
248
249 private <T, ID> void setField(T entity, String fieldName, String attributeValue, Class<ID> primaryKeyType)
250 throws ReflectiveOperationException {
251 Field field = getField(entity, fieldName);
252 field.setAccessible(true);
253 Object fieldValue = createObjectFromString(attributeValue, field, primaryKeyType);
254 field.set(entity, fieldValue);
255 }
256
257 private Field getField(Object entity, String fieldName) throws NoSuchFieldException {
258 Class<?> entityType = entity.getClass();
259 while (entityType != null) {
260 for (Field field : entityType.getDeclaredFields()) {
261 if (field.getName().equalsIgnoreCase(fieldName)) {
262 return field;
263 }
264 }
265 entityType = entityType.getSuperclass();
266 }
267 throw new NoSuchFieldException(fieldName);
268 }
269
270 private Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
271 Class<?> type;
272 if (field.getAnnotation(Id.class) != null) {
273 type = primaryKeyType;
274 } else {
275 type = field.getType();
276 }
277 return createObjectFromString(s, type);
278 }
279
280 private Object createObjectFromString(String s, Class<?> type) {
281 if (type == Byte.class) {
282 return Byte.parseByte(s);
283 } else if (type == Short.class) {
284 return Short.parseShort(s);
285 } else if (type == Integer.class) {
286 return Integer.parseInt(s);
287 } else if (type == Long.class) {
288 return Long.parseLong(s);
289 } else if (type == Float.class) {
290 return Float.parseFloat(s);
291 } else if (type == Double.class) {
292 return Double.parseDouble(s);
293 } else if (type == Boolean.class) {
294 return Boolean.parseBoolean(s);
295 } else if (type == Character.class) {
296 return s.charAt(0);
297 } else if (type == String.class) {
298 return s;
299 } else if (type == Date.class) {
300 return TestUtil.parseDate(s);
301 } else if (type == BigDecimal.class) {
302 return new BigDecimal(s);
303 } else if (type == BigInteger.class) {
304 return new BigInteger(s);
305 } else {
306 return findEntity(s, type);
307 }
308 }
309
310 @SuppressWarnings({ "rawtypes", "unchecked" })
311 private Object findEntity(String strId, Class<?> entityType) {
312 if (entityType.isEnum()) {
313 Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
314 return Enum.valueOf(enumType, strId);
315 }
316 for (Object entity : entities) {
317 if (entity.getClass().equals(entityType)) {
318 Field idField = getIdField(entity);
319 idField.setAccessible(true);
320 try {
321 Object id = idField.get(entity);
322 if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
323 return entity;
324 }
325 } catch (IllegalAccessException e) {
326 throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
327 e);
328 }
329 }
330 }
331 throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
332 }
333
334 private Field getIdField(Object entity) {
335 Class<?> entityType = entity.getClass();
336 while (entityType != null) {
337 for (Field field : entityType.getDeclaredFields()) {
338 if (field.getDeclaredAnnotation(Id.class) != null) {
339 return field;
340 }
341 }
342 entityType = entityType.getSuperclass();
343 }
344 throw new IllegalStateException("Id field not found for entity " + entity);
345 }
346
347 }