View Javadoc
1   package com.reallifedeveloper.tools.test.database;
2   
3   import java.io.Serializable;
4   import java.lang.reflect.Constructor;
5   import java.lang.reflect.Field;
6   import java.lang.reflect.ParameterizedType;
7   import java.lang.reflect.Type;
8   import java.math.BigDecimal;
9   import java.math.BigInteger;
10  import java.time.LocalDate;
11  import java.time.LocalDateTime;
12  import java.time.ZonedDateTime;
13  import java.util.ArrayList;
14  import java.util.Arrays;
15  import java.util.Collection;
16  import java.util.Date;
17  import java.util.HashSet;
18  import java.util.List;
19  import java.util.Map;
20  import java.util.Optional;
21  import java.util.Set;
22  import java.util.UUID;
23  
24  import org.checkerframework.checker.nullness.qual.Nullable;
25  import org.slf4j.Logger;
26  import org.slf4j.LoggerFactory;
27  import org.springframework.data.repository.CrudRepository;
28  
29  import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
30  import jakarta.persistence.Embeddable;
31  import jakarta.persistence.Entity;
32  import jakarta.persistence.Id;
33  import jakarta.persistence.JoinColumn;
34  import jakarta.persistence.JoinTable;
35  import jakarta.persistence.OneToMany;
36  import jakarta.persistence.OneToOne;
37  import lombok.Getter;
38  
39  import com.reallifedeveloper.tools.test.TestUtil;
40  
41  /**
42   * A helper class to write data into a {@link CrudRepository} from some data source, e.g., a CSV file, where each entity is represented by a
43   * {@link DbTableRow}.
44   * <p>
45   * This can be useful for inserting test data into a repository, irrespective of whether the repository connects to a real database or not.
46   * <p>
47   * TODO: The current implementation only has basic support for "to many" associations (there must be a &amp;JoinTable annotation on a field,
48   * with &amp;JoinColumn annotations), and for enums (an enum must be stored as a string).
49   *
50   * @author RealLifeDeveloper
51   */
52  @Getter
53  @SuppressWarnings("PMD")
54  @SuppressFBWarnings(value = { "CRLF_INJECTION_LOGS", "IMPROPER_UNICODE" })
55  public class CrudRepositoryWriter {
56  
57      private static final Logger LOG = LoggerFactory.getLogger(CrudRepositoryWriter.class);
58  
59      private final Set<Class<?>> classes = new HashSet<>();
60      private final List<Object> entities = new ArrayList<>();
61  
62      /**
63       * Creates a new entity based on data from a {@link DbTableRow} and writes it into a repository if appropriate.
64       * <p>
65       * This method may create entities that are not directly handled by the repository, in which case they are assumed to be related to some
66       * entity in the repository.
67       *
68       * @param <T>                  the type of entities in the repository
69       * @param <E>                  the type of entity being created
70       * @param <ID>                 the type of the primary key of the entities in the repository
71       * @param tableRow             the {@code TableRow} with the data to insert into the fields of the newly created entity
72       * @param repositoryEntityType the class object representing {@code T}, i.e., the type ofrepository entities
73       * @param entityType           the class object representing {@code E}, i.e., the type of entity being created, or {@code null}
74       * @param repository           the repository in which to insert the newly created entity
75       * @param tableName            the name of the database table where the entity should be stored
76       * @return {@code true} if an entity was created, no matter if it was saved in the repository, {@code false} otherwise
77       * @throws ReflectiveOperationException if some reflection operation failed creating the entity or setting is fields
78       */
79      public <T, E, ID extends Serializable> boolean writeEntity(DbTableRow tableRow, Class<T> repositoryEntityType,
80              @Nullable Class<E> entityType, CrudRepository<T, ID> repository, String tableName) throws ReflectiveOperationException {
81          if (entityType == null) {
82              return false;
83          }
84          if (entityType.getAnnotation(Embeddable.class) != null) {
85              return writeEmbeddable(tableRow, repositoryEntityType, entityType, repository, tableName);
86          }
87          if (entityType.getAnnotation(Entity.class) == null || !JpaUtil.getTableName(entityType).equalsIgnoreCase(tableName)) {
88              return false;
89          }
90          E entity = createEntity(entityType);
91          for (DbTableField column : tableRow.columns()) {
92              String fieldName = JpaUtil.getFieldName(column.name(), entityType);
93              setField(entity, fieldName, column.value());
94          }
95          LOG.debug("Saving entity {}", entity);
96          entities.add(entity);
97          classes.add(entity.getClass());
98          if (entityType.equals(repositoryEntityType)) {
99              T entityToSave = repositoryEntityType.cast(entity);
100             repository.save(entityToSave);
101         }
102         return true;
103     }
104 
105     @SuppressWarnings("UnusedVariable")
106     private <T, E, ID extends Serializable> boolean writeEmbeddable(DbTableRow tableRow, Class<T> repositoryEntityType,
107             @Nullable Class<E> entityType, CrudRepository<T, ID> repository, String tableName) {
108         LOG.debug("Saving embeddable {}", tableRow);
109         throw new UnsupportedOperationException("writeEmbeddable not yet implemented");
110     }
111 
112     /**
113      * Connects entities based on data in a join table.
114      *
115      * @param tableRow       the {@code TableRow} with the data for the join table
116      * @param joinTtableName the name of the join table to use to connect entities
117      */
118     public void addEntitiesFromJoinTable(DbTableRow tableRow, String joinTtableName) {
119         joinTableField(joinTtableName).ifPresent(joinTableField -> {
120             joinTableField.setAccessible(true);
121             ParameterizedType parameterizedType = (ParameterizedType) joinTableField.getGenericType();
122             Class<?> targetType = (Class<?>) parameterizedType.getActualTypeArguments()[0];
123             JoinTable joinTable = joinTableField.getAnnotation(JoinTable.class);
124             assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";
125             for (JoinColumn joinColumn : joinTable.joinColumns()) {
126                 for (JoinColumn inverseJoinColumn : joinTable.inverseJoinColumns()) {
127                     addEntityFromJoinTable(tableRow, joinTableField, targetType, joinColumn, inverseJoinColumn);
128                 }
129             }
130         });
131     }
132 
133     /**
134      * Goes through all entities that have been saved, trying to fix missing associations.
135      *
136      * @throws ReflectiveOperationException if something went wrong using reflection to analyze the entities
137      */
138     public void fillReferencesBetweenEntities() throws ReflectiveOperationException {
139         for (Object entity : entities) {
140             for (Field field : entity.getClass().getDeclaredFields()) {
141                 field.setAccessible(true);
142                 OneToOne oneToOne = field.getAnnotation(OneToOne.class);
143                 if (oneToOne != null) {
144                     handleOneToOne(entity, field, oneToOne);
145                 }
146                 OneToMany oneToMany = field.getAnnotation(OneToMany.class);
147                 if (oneToMany != null) {
148                     handleOneToMany(entity, field, oneToMany);
149                 }
150             }
151         }
152     }
153 
154     private void handleOneToOne(Object entity, Field field, OneToOne oneToOne) throws IllegalAccessException, NoSuchFieldException {
155         String mappedBy = oneToOne.mappedBy();
156         if (mappedBy == null || mappedBy.isEmpty()) {
157             JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
158             if (joinColumn != null) {
159                 mappedBy = JpaUtil.getFieldName(joinColumn.name(), field.getType());
160             }
161         }
162         if (mappedBy == null || mappedBy.isEmpty()) {
163             throw new IllegalStateException("OneToOne field " + entity.getClass().getName() + "." + field.getName()
164                     + " has no mappedBy and no JoinColumn annotation");
165         }
166         Object id = JpaUtil.getIdValue(entity);
167         List<?> entitiesToMap = findEntitiesByClassAndField(field.getType(), mappedBy, id);
168         if (entitiesToMap.size() > 1) {
169             throw new IllegalStateException("Found multiple candidates for OneToOne mapping: entity=" + entity + ", field={}" + field);
170         }
171         Object value = entitiesToMap.isEmpty() ? null : entitiesToMap.get(0);
172         LOG.debug("Setting OneToOne field {} to {}", entity.getClass().getName() + "." + field.getName(), value);
173         field.set(entity, value);
174     }
175 
176     private void handleOneToMany(Object entity, Field field, OneToMany oneToMany) throws IllegalAccessException, NoSuchFieldException {
177         Class<?> collectionType = field.getType();
178         if (Collection.class.isAssignableFrom(collectionType)) {
179             saveEntityInCollection(entity, field, oneToMany);
180         } else if (Map.class.isAssignableFrom(collectionType)) {
181             saveEntityInMap(entity, field, oneToMany);
182         }
183     }
184 
185     private void saveEntityInCollection(Object entity, Field field, OneToMany oneToMany)
186             throws IllegalAccessException, NoSuchFieldException {
187         LOG.trace("saveEntityInCollection: entity={}, field={}, oneToMany={}", entity, field, oneToMany);
188         LOG.trace("Not yet implemented");
189     }
190 
191     private void saveEntityInMap(Object entity, Field field, OneToMany oneToMany) throws IllegalAccessException, NoSuchFieldException {
192         LOG.trace("saveEntityInMap: entity={}, field={}, oneToMany={}", entity, field, oneToMany);
193         ParameterizedType parameterizedType = (ParameterizedType) field.getGenericType();
194         Type[] targetTypes = parameterizedType.getActualTypeArguments();
195         Class<?> targetClass = getClass(targetTypes[1].getTypeName());
196         String mappedBy = oneToMany.mappedBy();
197         if (mappedBy == null || mappedBy.isEmpty()) {
198             JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
199             if (joinColumn != null) {
200                 mappedBy = JpaUtil.getFieldName(joinColumn.name(), targetClass);
201             }
202         }
203         if (mappedBy == null || mappedBy.isEmpty()) {
204             throw new IllegalStateException("OneToMany field " + entity.getClass().getName() + "." + field.getName()
205                     + " has no mappedBy and no JoinColumn annotation");
206         }
207         Object id = JpaUtil.getIdValue(entity);
208         List<?> entitiesToMap = findEntitiesByClassAndField(targetClass, mappedBy, id);
209         JpaUtil.addEntitiesToMapField(field, entity, entitiesToMap);
210     }
211 
212     private Class<?> getClass(String className) {
213         try {
214             return Class.forName(className);
215         } catch (ClassNotFoundException e) {
216             throw new IllegalStateException("Class " + className + " not found", e);
217         }
218     }
219 
220     private <T> List<T> findEntitiesByClassAndField(Class<T> entityClass, String fieldName, Object value)
221             throws IllegalAccessException, NoSuchFieldException {
222         LOG.trace("Finding entities by class={}, field={} and value={}", entityClass, fieldName, value);
223         List<T> foundEntities = new ArrayList<>();
224         for (T entity : entitiesOfType(entityClass)) {
225             Field field = entity.getClass().getDeclaredField(fieldName);
226             field.setAccessible(true);
227             // LOG.debug("{}.{}={}", entityClass.getName(), fieldName, field.get(entity));
228             Object fieldValue = field.get(entity);
229             if (fieldValue == null) {
230                 continue;
231             }
232             if (fieldValue.equals(value)) {
233                 foundEntities.add(entity);
234             } else if (fieldValue.getClass().getAnnotation(Entity.class) != null && JpaUtil.getIdValue(fieldValue).equals(value)) {
235                 foundEntities.add(entity);
236             }
237         }
238         return foundEntities;
239     }
240 
241     @SuppressWarnings("unchecked")
242     private <T> List<T> entitiesOfType(Class<T> entityType) {
243         // LOG.debug("Getting entities of type {}", entityType.getName());
244         return (List<T>) entities.stream().filter(entity -> entity.getClass().equals(entityType)).toList();
245     }
246 
247     private Optional<Field> joinTableField(String tableName) {
248         for (Class<?> c : classes) {
249             for (Field field : c.getDeclaredFields()) {
250                 JoinTable joinTable = field.getAnnotation(JoinTable.class);
251                 if (joinTable != null && tableName.equalsIgnoreCase(joinTable.name())) {
252                     return Optional.of(field);
253                 }
254             }
255         }
256         return Optional.empty();
257     }
258 
259     private void addEntityFromJoinTable(DbTableRow tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
260             JoinColumn inverseJoinColumn) {
261         String lhsPrimaryKey = null;
262         String rhsPrimaryKey = null;
263         for (DbTableField column : tableRow.columns()) {
264             if (column.name().equalsIgnoreCase(joinColumn.name())) {
265                 lhsPrimaryKey = column.value();
266             } else if (column.name().equalsIgnoreCase(inverseJoinColumn.name())) {
267                 rhsPrimaryKey = column.value();
268             }
269         }
270         if (lhsPrimaryKey == null || rhsPrimaryKey == null) {
271             throw new IllegalStateException("Failed to find join table: missing attribute in DBUnit XML file: '" + joinColumn.name()
272                     + "' or '" + inverseJoinColumn.name() + "'");
273         }
274         Object lhs = findEntity(lhsPrimaryKey, joinTableField.getDeclaringClass());
275         Object rhs = findEntity(rhsPrimaryKey, targetType);
276         JpaUtil.addObjectToCollectionField(joinTableField, lhs, rhs);
277     }
278 
279     private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
280         Constructor<T> constructor = entityType.getDeclaredConstructor();
281         constructor.setAccessible(true);
282         return constructor.newInstance();
283     }
284 
285     private <T> void setField(T entity, String fieldName, String attributeValue) throws ReflectiveOperationException {
286         Field field = JpaUtil.getField(entity, fieldName);
287         field.setAccessible(true);
288         Object fieldValue = createObjectFromString(attributeValue, field, JpaUtil.getPrimaryKeyType(entity.getClass()));
289         LOG.trace("Setting field {} to {}", fieldName, fieldValue);
290         field.set(entity, fieldValue);
291         if (fieldValue != null && fieldValue.getClass().getAnnotation(Entity.class) != null) {
292             potentiallyAddValueToCollection(fieldValue, fieldName, entity);
293         }
294     }
295 
296     private <T> void potentiallyAddValueToCollection(Object entity, String fieldName, T value) {
297         for (Field field : entity.getClass().getDeclaredFields()) {
298             field.setAccessible(true);
299             OneToMany oneToMany = field.getAnnotation(OneToMany.class);
300             if (oneToMany == null) {
301                 continue;
302             }
303             if (oneToMany.mappedBy().equals(fieldName)) {
304                 JpaUtil.addObjectToCollectionField(field, entity, value);
305             }
306         }
307     }
308 
309     private @Nullable Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
310         Class<?> type;
311         if (field.getAnnotation(Id.class) != null) {
312             type = primaryKeyType;
313         } else {
314             type = field.getType();
315         }
316         return createObjectFromString(s, type);
317     }
318 
319     @SuppressWarnings("checkstyle:noReturnNull")
320     private @Nullable Object createObjectFromString(String s, Class<?> type) {
321         if (s == null || s.isEmpty()) {
322             return null;
323         }
324         if (type == Byte.class) {
325             return Byte.parseByte(s);
326         } else if (type == Short.class) {
327             return Short.parseShort(s);
328         } else if (type == Integer.class) {
329             return Integer.parseInt(s);
330         } else if (type == Long.class) {
331             return Long.parseLong(s);
332         } else if (type == Float.class) {
333             return Float.parseFloat(s);
334         } else if (type == Double.class) {
335             return Double.parseDouble(s);
336         } else if (type == Boolean.class) {
337             return Boolean.parseBoolean(s);
338         } else if (type == Character.class) {
339             return s.charAt(0);
340         } else if (type == String.class) {
341             return s;
342         } else if (type == Date.class) {
343             return TestUtil.parseDate(s);
344         } else if (type == LocalDate.class) {
345             return LocalDate.parse(s);
346         } else if (type == LocalDateTime.class) {
347             return LocalDateTime.parse(s);
348         } else if (type == ZonedDateTime.class) {
349             return ZonedDateTime.parse(s);
350         } else if (type == BigDecimal.class) {
351             return new BigDecimal(s);
352         } else if (type == BigInteger.class) {
353             return new BigInteger(s);
354         } else if (type == UUID.class) {
355             return UUID.fromString(s);
356         } else if (type == List.class) {
357             return Arrays.asList(s.replaceAll("[{}]", "").split(","));
358         } else {
359             return findEntity(s, type);
360         }
361     }
362 
363     @SuppressWarnings({ "rawtypes", "unchecked" })
364     private Object findEntity(String strId, Class<?> entityType) {
365         if (entityType.isEnum()) {
366             Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
367             return Enum.valueOf(enumType, strId);
368         }
369         for (Object entity : entities) {
370             if (entity.getClass().equals(entityType)) {
371                 Field idField = JpaUtil.getIdField(entity);
372                 idField.setAccessible(true);
373                 try {
374                     Object id = idField.get(entity);
375                     if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
376                         return entity;
377                     }
378                 } catch (IllegalAccessException e) {
379                     throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
380                             e);
381                 }
382             }
383         }
384         throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
385     }
386 
387     /**
388      * Represents one row of data from the database.
389      *
390      * @author RealLifeDeveloper
391      */
392     public interface DbTableRow {
393         /**
394          * Gives the fields of this row.
395          *
396          * @return the fields
397          */
398         List<DbTableField> columns();
399     }
400 
401     /**
402      * Represents the value of a single field in the database.
403      *
404      * @param name  the name of the database column
405      * @param value the value of the field
406      */
407     public record DbTableField(String name, String value) {
408     }
409 }