1 package com.reallifedeveloper.tools.test.database.dbunit;
2
3 import java.io.FileNotFoundException;
4 import java.io.IOException;
5 import java.io.InputStream;
6 import java.io.Serializable;
7 import java.lang.reflect.Constructor;
8 import java.lang.reflect.Field;
9 import java.lang.reflect.InvocationTargetException;
10 import java.lang.reflect.Method;
11 import java.lang.reflect.ParameterizedType;
12 import java.math.BigDecimal;
13 import java.math.BigInteger;
14 import java.util.ArrayList;
15 import java.util.Date;
16 import java.util.HashSet;
17 import java.util.List;
18 import java.util.Optional;
19 import java.util.Set;
20
21 import javax.xml.parsers.DocumentBuilder;
22 import javax.xml.parsers.DocumentBuilderFactory;
23 import javax.xml.parsers.ParserConfigurationException;
24
25 import org.checkerframework.checker.nullness.qual.NonNull;
26 import org.checkerframework.checker.nullness.qual.Nullable;
27 import org.slf4j.Logger;
28 import org.slf4j.LoggerFactory;
29 import org.springframework.data.jpa.repository.JpaRepository;
30 import org.w3c.dom.Document;
31 import org.w3c.dom.Element;
32 import org.w3c.dom.NamedNodeMap;
33 import org.w3c.dom.Node;
34 import org.w3c.dom.NodeList;
35 import org.xml.sax.SAXException;
36
37 import jakarta.persistence.Column;
38 import jakarta.persistence.Id;
39 import jakarta.persistence.JoinColumn;
40 import jakarta.persistence.JoinTable;
41 import jakarta.persistence.Table;
42
43 import com.reallifedeveloper.tools.test.TestUtil;
44
45
46
47
48
49
50
51
52
53
54
55
56 @SuppressWarnings("PMD")
57 public final class DbUnitFlatXmlReader {
58
59 private static final Logger LOG = LoggerFactory.getLogger(DbUnitFlatXmlReader.class);
60
61 private final DocumentBuilder documentBuilder;
62 private final Set<Class<?>> classes = new HashSet<>();
63 private final List<Object> entities = new ArrayList<>();
64
65
66
67
68 public DbUnitFlatXmlReader() {
69 try {
70 DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
71 dbf.setValidating(false);
72 dbf.setNamespaceAware(true);
73 dbf.setFeature("http://xml.org/sax/features/namespaces", false);
74 dbf.setFeature("http://xml.org/sax/features/validation", false);
75 dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false);
76 dbf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
77 documentBuilder = dbf.newDocumentBuilder();
78 } catch (ParserConfigurationException e) {
79 throw new IllegalStateException("Unexpected problem creating XML parser", e);
80 }
81 }
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96 public <T, ID extends Serializable> void read(String resourceName, JpaRepository<T, ID> repository, Class<T> entityType,
97 Class<ID> primaryKeyType) throws IOException, SAXException {
98 try (InputStream in = DbUnitFlatXmlReader.class.getResourceAsStream(resourceName)) {
99 if (in == null) {
100 throw new FileNotFoundException(resourceName);
101 }
102 Document doc = documentBuilder.parse(in);
103 Element dataset = doc.getDocumentElement();
104 NodeList tableRows = dataset.getChildNodes();
105
106 LOG.info("Reading from {}", resourceName);
107 for (int i = 0; i < tableRows.getLength(); i++) {
108 Node tableRowNode = (@NonNull Node) tableRows.item(i);
109 if (tableRowNode.getNodeType() == Node.ELEMENT_NODE) {
110 Element tableRow = (Element) tableRowNode;
111 String tableName = tableRow.getNodeName();
112 if (tableName.equalsIgnoreCase(getTableName(entityType))) {
113 handleTableRow(tableRow, entityType, primaryKeyType, repository);
114 } else {
115 handlePotentialJoinTable(tableRowNode, tableName);
116 }
117 }
118 }
119 } catch (ReflectiveOperationException | SecurityException e) {
120 throw new IllegalStateException("Unexpected problem reading XML file from '" + resourceName + "'", e);
121 }
122 }
123
124 private <T, ID extends Serializable> void handleTableRow(Element tableRow, Class<T> entityType, Class<ID> primaryKeyType,
125 JpaRepository<T, ID> repository) throws ReflectiveOperationException {
126 T entity = createEntity(entityType);
127 NamedNodeMap attributes = (@NonNull NamedNodeMap) tableRow.getAttributes();
128 for (int j = 0; j < attributes.getLength(); j++) {
129 Node attribute = (@NonNull Node) attributes.item(j);
130 String fieldName = getFieldName(attribute.getNodeName(), entityType);
131 String attributeValue = attribute.getNodeValue();
132 setField(entity, fieldName, attributeValue, primaryKeyType);
133 }
134 entity = repository.save(entity);
135 entities.add(entity);
136 classes.add(entity.getClass());
137 }
138
139 private void handlePotentialJoinTable(Node tableRow, String tableName)
140 throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
141 joinTableField(tableName).ifPresent(joinTableField -> {
142 joinTableField.setAccessible(true);
143 ParameterizedType parameterizedType = (ParameterizedType) joinTableField.getGenericType();
144 Class<?> targetType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
145 JoinTable joinTable = joinTableField.getAnnotation(JoinTable.class);
146 assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";
147 for (JoinColumn joinColumn : joinTable.joinColumns()) {
148 for (JoinColumn inverseJoinColumn : joinTable.inverseJoinColumns()) {
149 addEntityFromJoinTable(tableRow, joinTableField, targetType, joinColumn, inverseJoinColumn);
150 }
151 }
152 });
153 }
154
155 private Optional<Field> joinTableField(String tableName) {
156 for (Class<?> c : classes) {
157 for (Field field : c.getDeclaredFields()) {
158 JoinTable joinTable = field.getAnnotation(JoinTable.class);
159 if (joinTable != null && tableName.equalsIgnoreCase(joinTable.name())) {
160 return Optional.of(field);
161 }
162 }
163 }
164 return Optional.empty();
165 }
166
167 private void addEntityFromJoinTable(Node tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
168 JoinColumn inverseJoinColumn) {
169 NamedNodeMap attributes = tableRow.getAttributes();
170 String lhsPrimaryKey = null;
171 String rhsPrimaryKey = null;
172 for (int j = 0; j < attributes.getLength(); j++) {
173 Node attribute = attributes.item(j);
174 if (attribute.getNodeName().equalsIgnoreCase(joinColumn.name())) {
175 lhsPrimaryKey = attribute.getNodeValue();
176 } else if (attribute.getNodeName().equalsIgnoreCase(inverseJoinColumn.name())) {
177 rhsPrimaryKey = attribute.getNodeValue();
178 }
179 }
180 Object lhs = findEntity(lhsPrimaryKey, joinTableField.getDeclaringClass());
181 Object rhs = findEntity(rhsPrimaryKey, targetType);
182 try {
183 Method add = joinTableField.getType().getMethod("add", Object.class);
184 add.invoke(joinTableField.get(lhs), rhs);
185 } catch (NoSuchMethodException e) {
186 throw new IllegalStateException("Method 'add' not found -- @JoinTable annotation should be on a Collection", e);
187 } catch (IllegalAccessException | InvocationTargetException e) {
188 throw new IllegalStateException("Unexpected problem", e);
189 }
190 }
191
192 private <T> String getTableName(Class<T> entityType) {
193 Table table = entityType.getAnnotation(Table.class);
194 if (table == null) {
195 return entityType.getSimpleName();
196 } else {
197 return table.name();
198 }
199 }
200
201 private <T> String getFieldName(String attributeName, Class<T> entityType) {
202 for (Field field : entityType.getDeclaredFields()) {
203 if (checkFieldName(attributeName, field)) {
204 return field.getName();
205 }
206 }
207 if (entityType.getSuperclass() == null) {
208 throw new IllegalArgumentException("Cannot find any field matching attribute '" + attributeName + "' for " + entityType);
209 } else {
210 return getFieldName(attributeName, entityType.getSuperclass());
211 }
212 }
213
214 private boolean checkFieldName(String attributeName, Field field) {
215 Column column = field.getAnnotation(Column.class);
216 if (column == null || column.name() == null) {
217 JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
218 if (joinColumn == null || joinColumn.name() == null) {
219 return field.getName().equalsIgnoreCase(attributeName);
220 } else {
221 return joinColumn.name().equalsIgnoreCase(attributeName);
222 }
223 } else {
224 return column.name().equalsIgnoreCase(attributeName);
225 }
226 }
227
228 private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
229 Constructor<T> constructor = entityType.getDeclaredConstructor();
230 constructor.setAccessible(true);
231 return constructor.newInstance();
232 }
233
234 private <T, ID> void setField(T entity, String fieldName, @Nullable String attributeValue, Class<ID> primaryKeyType)
235 throws ReflectiveOperationException {
236 Field field = getField(entity, fieldName);
237 field.setAccessible(true);
238 Object fieldValue = createObjectFromString(attributeValue, field, primaryKeyType);
239 field.set(entity, fieldValue);
240 }
241
242 private Field getField(Object entity, String fieldName) throws NoSuchFieldException {
243 Class<?> entityType = entity.getClass();
244 while (entityType != null) {
245 for (Field field : entityType.getDeclaredFields()) {
246 if (field.getName().equalsIgnoreCase(fieldName)) {
247 return field;
248 }
249 }
250 entityType = entityType.getSuperclass();
251 }
252 throw new NoSuchFieldException(fieldName);
253 }
254
255 private Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
256 Class<?> type;
257 if (field.getAnnotation(Id.class) != null) {
258 type = primaryKeyType;
259 } else {
260 type = field.getType();
261 }
262 return createObjectFromString(s, type);
263 }
264
265 private Object createObjectFromString(String s, Class<?> type) {
266 if (type == Byte.class) {
267 return Byte.parseByte(s);
268 } else if (type == Short.class) {
269 return Short.parseShort(s);
270 } else if (type == Integer.class) {
271 return Integer.parseInt(s);
272 } else if (type == Long.class) {
273 return Long.parseLong(s);
274 } else if (type == Float.class) {
275 return Float.parseFloat(s);
276 } else if (type == Double.class) {
277 return Double.parseDouble(s);
278 } else if (type == Boolean.class) {
279 return Boolean.parseBoolean(s);
280 } else if (type == Character.class) {
281 return s.charAt(0);
282 } else if (type == String.class) {
283 return s;
284 } else if (type == Date.class) {
285 return TestUtil.parseDate(s);
286 } else if (type == BigDecimal.class) {
287 return new BigDecimal(s);
288 } else if (type == BigInteger.class) {
289 return new BigInteger(s);
290 } else {
291 return findEntity(s, type);
292 }
293 }
294
295 @SuppressWarnings({ "rawtypes", "unchecked" })
296 private Object findEntity(String strId, Class<?> entityType) {
297 if (entityType.isEnum()) {
298 Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
299 return Enum.valueOf(enumType, strId);
300 }
301 for (Object entity : entities) {
302 if (entity.getClass().equals(entityType)) {
303 Field idField = getIdField(entity);
304 idField.setAccessible(true);
305 try {
306 Object id = idField.get(entity);
307 if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
308 return entity;
309 }
310 } catch (IllegalAccessException e) {
311 throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
312 e);
313 }
314 }
315 }
316 throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
317 }
318
319 private Field getIdField(Object entity) {
320 Class<?> entityType = entity.getClass();
321 while (entityType != null) {
322 for (Field field : entityType.getDeclaredFields()) {
323 if (field.getDeclaredAnnotation(Id.class) != null) {
324 return field;
325 }
326 }
327 entityType = entityType.getSuperclass();
328 }
329 throw new IllegalStateException("Id field not found for entity " + entity);
330 }
331
332 }