1 package com.reallifedeveloper.tools.test.database;
2
3 import java.io.Serializable;
4 import java.lang.reflect.Constructor;
5 import java.lang.reflect.Field;
6 import java.lang.reflect.ParameterizedType;
7 import java.lang.reflect.Type;
8 import java.math.BigDecimal;
9 import java.math.BigInteger;
10 import java.time.LocalDate;
11 import java.time.LocalDateTime;
12 import java.time.ZonedDateTime;
13 import java.util.ArrayList;
14 import java.util.Arrays;
15 import java.util.Collection;
16 import java.util.Date;
17 import java.util.HashSet;
18 import java.util.List;
19 import java.util.Map;
20 import java.util.Optional;
21 import java.util.Set;
22 import java.util.UUID;
23
24 import org.checkerframework.checker.nullness.qual.Nullable;
25 import org.slf4j.Logger;
26 import org.slf4j.LoggerFactory;
27 import org.springframework.data.repository.CrudRepository;
28
29 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
30 import jakarta.persistence.Embeddable;
31 import jakarta.persistence.Entity;
32 import jakarta.persistence.Id;
33 import jakarta.persistence.JoinColumn;
34 import jakarta.persistence.JoinTable;
35 import jakarta.persistence.OneToMany;
36 import jakarta.persistence.OneToOne;
37 import lombok.Getter;
38
39 import com.reallifedeveloper.tools.test.TestUtil;
40
41
42
43
44
45
46
47
48
49
50
51
52 @Getter
53 @SuppressWarnings("PMD")
54 @SuppressFBWarnings(value = { "CRLF_INJECTION_LOGS", "IMPROPER_UNICODE" })
55 public class CrudRepositoryWriter {
56
57 private static final Logger LOG = LoggerFactory.getLogger(CrudRepositoryWriter.class);
58
59 private final Set<Class<?>> classes = new HashSet<>();
60 private final List<Object> entities = new ArrayList<>();
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79 public <T, E, ID extends Serializable> boolean writeEntity(DbTableRow tableRow, Class<T> repositoryEntityType,
80 @Nullable Class<E> entityType, CrudRepository<T, ID> repository, String tableName) throws ReflectiveOperationException {
81 if (entityType == null) {
82 return false;
83 }
84 if (entityType.getAnnotation(Embeddable.class) != null) {
85 return writeEmbeddable(tableRow, repositoryEntityType, entityType, repository, tableName);
86 }
87 if (entityType.getAnnotation(Entity.class) == null || !JpaUtil.getTableName(entityType).equalsIgnoreCase(tableName)) {
88 return false;
89 }
90 E entity = createEntity(entityType);
91 for (DbTableField column : tableRow.columns()) {
92 String fieldName = JpaUtil.getFieldName(column.name(), entityType);
93 setField(entity, fieldName, column.value());
94 }
95 LOG.debug("Saving entity {}", entity);
96 entities.add(entity);
97 classes.add(entity.getClass());
98 if (entityType.equals(repositoryEntityType)) {
99 T entityToSave = repositoryEntityType.cast(entity);
100 repository.save(entityToSave);
101 }
102 return true;
103 }
104
105 @SuppressWarnings("UnusedVariable")
106 private <T, E, ID extends Serializable> boolean writeEmbeddable(DbTableRow tableRow, Class<T> repositoryEntityType,
107 @Nullable Class<E> entityType, CrudRepository<T, ID> repository, String tableName) {
108 LOG.debug("Saving embeddable {}", tableRow);
109 throw new UnsupportedOperationException("writeEmbeddable not yet implemented");
110 }
111
112
113
114
115
116
117
118 public void addEntitiesFromJoinTable(DbTableRow tableRow, String joinTtableName) {
119 joinTableField(joinTtableName).ifPresent(joinTableField -> {
120 joinTableField.setAccessible(true);
121 ParameterizedType parameterizedType = (ParameterizedType) joinTableField.getGenericType();
122 Class<?> targetType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
123 JoinTable joinTable = joinTableField.getAnnotation(JoinTable.class);
124 assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";
125 for (JoinColumn joinColumn : joinTable.joinColumns()) {
126 for (JoinColumn inverseJoinColumn : joinTable.inverseJoinColumns()) {
127 addEntityFromJoinTable(tableRow, joinTableField, targetType, joinColumn, inverseJoinColumn);
128 }
129 }
130 });
131 }
132
133
134
135
136
137
138 public void fillReferencesBetweenEntities() throws ReflectiveOperationException {
139 for (Object entity : entities) {
140 for (Field field : entity.getClass().getDeclaredFields()) {
141 field.setAccessible(true);
142 OneToOne oneToOne = field.getAnnotation(OneToOne.class);
143 if (oneToOne != null) {
144 handleOneToOne(entity, field, oneToOne);
145 }
146 OneToMany oneToMany = field.getAnnotation(OneToMany.class);
147 if (oneToMany != null) {
148 handleOneToMany(entity, field, oneToMany);
149 }
150 }
151 }
152 }
153
154 private void handleOneToOne(Object entity, Field field, OneToOne oneToOne) throws IllegalAccessException, NoSuchFieldException {
155 String mappedBy = oneToOne.mappedBy();
156 if (mappedBy == null || mappedBy.isEmpty()) {
157 JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
158 if (joinColumn != null) {
159 mappedBy = JpaUtil.getFieldName(joinColumn.name(), field.getType());
160 }
161 }
162 if (mappedBy == null || mappedBy.isEmpty()) {
163 throw new IllegalStateException("OneToOne field " + entity.getClass().getName() + "." + field.getName()
164 + " has no mappedBy and no JoinColumn annotation");
165 }
166 Object id = JpaUtil.getIdValue(entity);
167 List<?> entitiesToMap = findEntitiesByClassAndField(field.getType(), mappedBy, id);
168 if (entitiesToMap.size() > 1) {
169 throw new IllegalStateException("Found multiple candidates for OneToOne mapping: entity=" + entity + ", field={}" + field);
170 }
171 Object value = entitiesToMap.isEmpty() ? null : entitiesToMap.get(0);
172 LOG.debug("Setting OneToOne field {} to {}", entity.getClass().getName() + "." + field.getName(), value);
173 field.set(entity, value);
174 }
175
176 private void handleOneToMany(Object entity, Field field, OneToMany oneToMany) throws IllegalAccessException, NoSuchFieldException {
177 Class<?> collectionType = field.getType();
178 if (Collection.class.isAssignableFrom(collectionType)) {
179 saveEntityInCollection(entity, field, oneToMany);
180 } else if (Map.class.isAssignableFrom(collectionType)) {
181 saveEntityInMap(entity, field, oneToMany);
182 }
183 }
184
185 private void saveEntityInCollection(Object entity, Field field, OneToMany oneToMany)
186 throws IllegalAccessException, NoSuchFieldException {
187 LOG.trace("saveEntityInCollection: entity={}, field={}, oneToMany={}", entity, field, oneToMany);
188 LOG.trace("Not yet implemented");
189 }
190
191 private void saveEntityInMap(Object entity, Field field, OneToMany oneToMany) throws IllegalAccessException, NoSuchFieldException {
192 LOG.trace("saveEntityInMap: entity={}, field={}, oneToMany={}", entity, field, oneToMany);
193 ParameterizedType parameterizedType = (ParameterizedType) field.getGenericType();
194 Type[] targetTypes = parameterizedType.getActualTypeArguments();
195 Class<?> targetClass = getClass(targetTypes[1].getTypeName());
196 String mappedBy = oneToMany.mappedBy();
197 if (mappedBy == null || mappedBy.isEmpty()) {
198 JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
199 if (joinColumn != null) {
200 mappedBy = JpaUtil.getFieldName(joinColumn.name(), targetClass);
201 }
202 }
203 if (mappedBy == null || mappedBy.isEmpty()) {
204 throw new IllegalStateException("OneToMany field " + entity.getClass().getName() + "." + field.getName()
205 + " has no mappedBy and no JoinColumn annotation");
206 }
207 Object id = JpaUtil.getIdValue(entity);
208 List<?> entitiesToMap = findEntitiesByClassAndField(targetClass, mappedBy, id);
209 JpaUtil.addEntitiesToMapField(field, entity, entitiesToMap);
210 }
211
212 private Class<?> getClass(String className) {
213 try {
214 return Class.forName(className);
215 } catch (ClassNotFoundException e) {
216 throw new IllegalStateException("Class " + className + " not found", e);
217 }
218 }
219
220 private <T> List<T> findEntitiesByClassAndField(Class<T> entityClass, String fieldName, Object value)
221 throws IllegalAccessException, NoSuchFieldException {
222 LOG.trace("Finding entities by class={}, field={} and value={}", entityClass, fieldName, value);
223 List<T> foundEntities = new ArrayList<>();
224 for (T entity : entitiesOfType(entityClass)) {
225 Field field = entity.getClass().getDeclaredField(fieldName);
226 field.setAccessible(true);
227
228 Object fieldValue = field.get(entity);
229 if (fieldValue == null) {
230 continue;
231 }
232 if (fieldValue.equals(value)) {
233 foundEntities.add(entity);
234 } else if (fieldValue.getClass().getAnnotation(Entity.class) != null && JpaUtil.getIdValue(fieldValue).equals(value)) {
235 foundEntities.add(entity);
236 }
237 }
238 return foundEntities;
239 }
240
241 @SuppressWarnings("unchecked")
242 private <T> List<T> entitiesOfType(Class<T> entityType) {
243
244 return (List<T>) entities.stream().filter(entity -> entity.getClass().equals(entityType)).toList();
245 }
246
247 private Optional<Field> joinTableField(String tableName) {
248 for (Class<?> c : classes) {
249 for (Field field : c.getDeclaredFields()) {
250 JoinTable joinTable = field.getAnnotation(JoinTable.class);
251 if (joinTable != null && tableName.equalsIgnoreCase(joinTable.name())) {
252 return Optional.of(field);
253 }
254 }
255 }
256 return Optional.empty();
257 }
258
259 private void addEntityFromJoinTable(DbTableRow tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
260 JoinColumn inverseJoinColumn) {
261 String lhsPrimaryKey = null;
262 String rhsPrimaryKey = null;
263 for (DbTableField column : tableRow.columns()) {
264 if (column.name().equalsIgnoreCase(joinColumn.name())) {
265 lhsPrimaryKey = column.value();
266 } else if (column.name().equalsIgnoreCase(inverseJoinColumn.name())) {
267 rhsPrimaryKey = column.value();
268 }
269 }
270 if (lhsPrimaryKey == null || rhsPrimaryKey == null) {
271 throw new IllegalStateException("Failed to find join table: missing attribute in DBUnit XML file: '" + joinColumn.name()
272 + "' or '" + inverseJoinColumn.name() + "'");
273 }
274 Object lhs = findEntity(lhsPrimaryKey, joinTableField.getDeclaringClass());
275 Object rhs = findEntity(rhsPrimaryKey, targetType);
276 JpaUtil.addObjectToCollectionField(joinTableField, lhs, rhs);
277 }
278
279 private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
280 Constructor<T> constructor = entityType.getDeclaredConstructor();
281 constructor.setAccessible(true);
282 return constructor.newInstance();
283 }
284
285 private <T> void setField(T entity, String fieldName, String attributeValue) throws ReflectiveOperationException {
286 Field field = JpaUtil.getField(entity, fieldName);
287 field.setAccessible(true);
288 Object fieldValue = createObjectFromString(attributeValue, field, JpaUtil.getPrimaryKeyType(entity.getClass()));
289 LOG.trace("Setting field {} to {}", fieldName, fieldValue);
290 field.set(entity, fieldValue);
291 if (fieldValue != null && fieldValue.getClass().getAnnotation(Entity.class) != null) {
292 potentiallyAddValueToCollection(fieldValue, fieldName, entity);
293 }
294 }
295
296 private <T> void potentiallyAddValueToCollection(Object entity, String fieldName, T value) {
297 for (Field field : entity.getClass().getDeclaredFields()) {
298 field.setAccessible(true);
299 OneToMany oneToMany = field.getAnnotation(OneToMany.class);
300 if (oneToMany == null) {
301 continue;
302 }
303 if (oneToMany.mappedBy().equals(fieldName)) {
304 JpaUtil.addObjectToCollectionField(field, entity, value);
305 }
306 }
307 }
308
309 private @Nullable Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
310 Class<?> type;
311 if (field.getAnnotation(Id.class) != null) {
312 type = primaryKeyType;
313 } else {
314 type = field.getType();
315 }
316 return createObjectFromString(s, type);
317 }
318
319 @SuppressWarnings("checkstyle:noReturnNull")
320 private @Nullable Object createObjectFromString(String s, Class<?> type) {
321 if (s == null || s.isEmpty()) {
322 return null;
323 }
324 if (type == Byte.class) {
325 return Byte.parseByte(s);
326 } else if (type == Short.class) {
327 return Short.parseShort(s);
328 } else if (type == Integer.class) {
329 return Integer.parseInt(s);
330 } else if (type == Long.class) {
331 return Long.parseLong(s);
332 } else if (type == Float.class) {
333 return Float.parseFloat(s);
334 } else if (type == Double.class) {
335 return Double.parseDouble(s);
336 } else if (type == Boolean.class) {
337 return Boolean.parseBoolean(s);
338 } else if (type == Character.class) {
339 return s.charAt(0);
340 } else if (type == String.class) {
341 return s;
342 } else if (type == Date.class) {
343 return TestUtil.parseDate(s);
344 } else if (type == LocalDate.class) {
345 return LocalDate.parse(s);
346 } else if (type == LocalDateTime.class) {
347 return LocalDateTime.parse(s);
348 } else if (type == ZonedDateTime.class) {
349 return ZonedDateTime.parse(s);
350 } else if (type == BigDecimal.class) {
351 return new BigDecimal(s);
352 } else if (type == BigInteger.class) {
353 return new BigInteger(s);
354 } else if (type == UUID.class) {
355 return UUID.fromString(s);
356 } else if (type == List.class) {
357 return Arrays.asList(s.replaceAll("[{}]", "").split(","));
358 } else {
359 return findEntity(s, type);
360 }
361 }
362
363 @SuppressWarnings({ "rawtypes", "unchecked" })
364 private Object findEntity(String strId, Class<?> entityType) {
365 if (entityType.isEnum()) {
366 Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
367 return Enum.valueOf(enumType, strId);
368 }
369 for (Object entity : entities) {
370 if (entity.getClass().equals(entityType)) {
371 Field idField = JpaUtil.getIdField(entity);
372 idField.setAccessible(true);
373 try {
374 Object id = idField.get(entity);
375 if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
376 return entity;
377 }
378 } catch (IllegalAccessException e) {
379 throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
380 e);
381 }
382 }
383 }
384 throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
385 }
386
387
388
389
390
391
392 public interface DbTableRow {
393
394
395
396
397
398 List<DbTableField> columns();
399 }
400
401
402
403
404
405
406
407 public record DbTableField(String name, String value) {
408 }
409 }