// CrudRepositoryWriter.java
package com.reallifedeveloper.tools.test.database;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.repository.CrudRepository;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import jakarta.persistence.Embeddable;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.JoinTable;
import jakarta.persistence.OneToMany;
import jakarta.persistence.OneToOne;
import lombok.Getter;
import com.reallifedeveloper.tools.test.TestUtil;
/**
* A helper class to write data into a {@link CrudRepository} from some data source, e.g., a CSV file, where each entity is represented by a
* {@link DbTableRow}.
* <p>
* This can be useful for inserting test data into a repository, irrespective of whether the repository connects to a real database or not.
* <p>
* TODO: The current implementation only has basic support for "to many" associations (there must be a {@code @JoinTable} annotation on a
* field, with {@code @JoinColumn} annotations), and for enums (an enum must be stored as a string).
*
* @author RealLifeDeveloper
*/
@Getter
@SuppressWarnings("PMD")
@SuppressFBWarnings(value = { "CRLF_INJECTION_LOGS", "IMPROPER_UNICODE" })
public class CrudRepositoryWriter {
// Logger used for the debug and trace output of this class.
private static final Logger LOG = LoggerFactory.getLogger(CrudRepositoryWriter.class);
// Runtime classes of the entities created so far; joinTableField scans their declared fields for @JoinTable annotations.
private final Set<Class<?>> classes = new HashSet<>();
// Every entity created by writeEntity, in creation order, whether or not it was also saved in a repository.
private final List<Object> entities = new ArrayList<>();
/**
 * Builds a new entity from the column values of a {@link DbTableRow} and stores it in the repository when it is of the repository's
 * own entity type.
 * <p>
 * Entities of other types are still created and remembered by this writer, on the assumption that they are related to some entity
 * that is handled by the repository.
 *
 * @param <T> the type of entities in the repository
 * @param <E> the type of entity being created
 * @param <ID> the type of the primary key of the entities in the repository
 * @param tableRow the {@code DbTableRow} with the data to insert into the fields of the newly created entity
 * @param repositoryEntityType the class object representing {@code T}, i.e., the type of repository entities
 * @param entityType the class object representing {@code E}, i.e., the type of entity being created, or {@code null}
 * @param repository the repository in which to insert the newly created entity
 * @param tableName the name of the database table where the entity should be stored
 * @return {@code true} if an entity was created, no matter if it was saved in the repository, {@code false} otherwise
 * @throws ReflectiveOperationException if some reflection operation failed creating the entity or setting its fields
 * @throws UnsupportedOperationException if {@code entityType} is annotated with {@code Embeddable}, which is not yet supported
 */
public <T, E, ID extends Serializable> boolean writeEntity(DbTableRow tableRow, Class<T> repositoryEntityType,
        @Nullable Class<E> entityType, CrudRepository<T, ID> repository, String tableName) throws ReflectiveOperationException {
    if (entityType == null) {
        return false;
    }
    if (entityType.isAnnotationPresent(Embeddable.class)) {
        return writeEmbeddable(tableRow, repositoryEntityType, entityType, repository, tableName);
    }
    // Only classes that are JPA entities mapped to the requested table are of interest.
    if (!entityType.isAnnotationPresent(Entity.class)) {
        return false;
    }
    if (!JpaUtil.getTableName(entityType).equalsIgnoreCase(tableName)) {
        return false;
    }

    E newEntity = createEntity(entityType);
    for (DbTableField tableField : tableRow.columns()) {
        setField(newEntity, JpaUtil.getFieldName(tableField.name(), entityType), tableField.value());
    }

    LOG.debug("Saving entity {}", newEntity);
    // Remember every created entity so that associations can be resolved later on.
    entities.add(newEntity);
    classes.add(newEntity.getClass());

    if (entityType.equals(repositoryEntityType)) {
        repository.save(repositoryEntityType.cast(newEntity));
    }
    return true;
}
/**
 * Placeholder for writing an embeddable type. The current implementation logs the row and then always fails.
 *
 * @param <T> the type of entities in the repository
 * @param <E> the embeddable type
 * @param <ID> the type of the primary key of the entities in the repository
 * @param tableRow the row with the data for the embeddable, only used for logging
 * @param repositoryEntityType currently unused
 * @param entityType currently unused
 * @param repository currently unused
 * @param tableName currently unused
 * @return never returns normally
 * @throws UnsupportedOperationException always, since embeddables are not yet supported
 */
@SuppressWarnings("UnusedVariable")
private <T, E, ID extends Serializable> boolean writeEmbeddable(DbTableRow tableRow, Class<T> repositoryEntityType,
@Nullable Class<E> entityType, CrudRepository<T, ID> repository, String tableName) {
LOG.debug("Saving embeddable {}", tableRow);
throw new UnsupportedOperationException("writeEmbeddable not yet implemented");
}
/**
 * Connects previously created entities with each other using one row of a join table.
 * <p>
 * The join table is matched against the {@code JoinTable} annotations found on the fields of the entity classes created so far. If no
 * field refers to the given join table, this method does nothing.
 *
 * @param tableRow the {@code DbTableRow} with the data for the join table
 * @param joinTtableName the name of the join table to use to connect entities
 */
public void addEntitiesFromJoinTable(DbTableRow tableRow, String joinTtableName) {
    Optional<Field> annotatedField = joinTableField(joinTtableName);
    if (annotatedField.isEmpty()) {
        return;
    }
    Field collectionField = annotatedField.get();
    collectionField.setAccessible(true);

    // The first type argument of the field's generic type is taken to be the type of the associated entities.
    ParameterizedType genericType = (ParameterizedType) collectionField.getGenericType();
    Class<?> elementType = (Class<?>) genericType.getActualTypeArguments()[0];

    JoinTable joinTable = collectionField.getAnnotation(JoinTable.class);
    assert joinTable != null : "JoinTable annotation should be present when the joinTableField method returns a non-empty value";

    // Try every combination of join column and inverse join column.
    for (JoinColumn owningColumn : joinTable.joinColumns()) {
        for (JoinColumn inverseColumn : joinTable.inverseJoinColumns()) {
            addEntityFromJoinTable(tableRow, collectionField, elementType, owningColumn, inverseColumn);
        }
    }
}
/**
 * Walks over all entities created so far and tries to fill in their {@code OneToOne} and {@code OneToMany} associations.
 *
 * @throws ReflectiveOperationException if something went wrong using reflection to analyze the entities
 */
public void fillReferencesBetweenEntities() throws ReflectiveOperationException {
    for (Object savedEntity : entities) {
        Field[] declaredFields = savedEntity.getClass().getDeclaredFields();
        for (Field declaredField : declaredFields) {
            declaredField.setAccessible(true);
            OneToOne toOne = declaredField.getAnnotation(OneToOne.class);
            OneToMany toMany = declaredField.getAnnotation(OneToMany.class);
            if (toOne != null) {
                handleOneToOne(savedEntity, declaredField, toOne);
            }
            if (toMany != null) {
                handleOneToMany(savedEntity, declaredField, toMany);
            }
        }
    }
}
/**
 * Fills in a {@code OneToOne} field of an entity with the matching entity among those created so far.
 * <p>
 * The name of the field to match on in the target type is taken from {@code mappedBy}; if that is empty, it is derived from the name
 * of the field's {@code JoinColumn} annotation. The created entities of the field's type are then searched for one whose field with
 * that name equals the ID of {@code entity}, or refers to an entity with that ID.
 *
 * @param entity the entity whose field should be set
 * @param field the field annotated with {@code OneToOne}; it must already be accessible
 * @param oneToOne the {@code OneToOne} annotation found on {@code field}
 * @throws IllegalAccessException if a field could not be read or written using reflection
 * @throws NoSuchFieldException if the target type has no declared field with the resolved name
 * @throws IllegalStateException if the field has neither {@code mappedBy} nor a {@code JoinColumn}, or if more than one candidate
 *             entity was found
 */
private void handleOneToOne(Object entity, Field field, OneToOne oneToOne) throws IllegalAccessException, NoSuchFieldException {
    String mappedBy = oneToOne.mappedBy();
    if (mappedBy == null || mappedBy.isEmpty()) {
        // No mappedBy given: fall back to the join column to find the field in the target type.
        JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
        if (joinColumn != null) {
            mappedBy = JpaUtil.getFieldName(joinColumn.name(), field.getType());
        }
    }
    if (mappedBy == null || mappedBy.isEmpty()) {
        throw new IllegalStateException("OneToOne field " + entity.getClass().getName() + "." + field.getName()
                + " has no mappedBy and no JoinColumn annotation");
    }
    Object id = JpaUtil.getIdValue(entity);
    List<?> entitiesToMap = findEntitiesByClassAndField(field.getType(), mappedBy, id);
    if (entitiesToMap.size() > 1) {
        // The message used to contain a stray "{}" logging placeholder, which garbled the text.
        throw new IllegalStateException("Found multiple candidates for OneToOne mapping: entity=" + entity + ", field=" + field);
    }
    // No candidate means that the association is left empty.
    Object value = entitiesToMap.isEmpty() ? null : entitiesToMap.get(0);
    LOG.debug("Setting OneToOne field {} to {}", entity.getClass().getName() + "." + field.getName(), value);
    field.set(entity, value);
}
/**
 * Dispatches the handling of a {@code OneToMany} field depending on whether the field is a collection or a map. Fields of any other
 * type are ignored.
 *
 * @param entity the entity owning the field
 * @param field the field annotated with {@code OneToMany}
 * @param oneToMany the {@code OneToMany} annotation found on {@code field}
 * @throws IllegalAccessException if a field could not be accessed using reflection
 * @throws NoSuchFieldException if a field looked up by name does not exist
 */
private void handleOneToMany(Object entity, Field field, OneToMany oneToMany) throws IllegalAccessException, NoSuchFieldException {
    Class<?> declaredType = field.getType();
    if (Collection.class.isAssignableFrom(declaredType)) {
        saveEntityInCollection(entity, field, oneToMany);
        return;
    }
    if (Map.class.isAssignableFrom(declaredType)) {
        saveEntityInMap(entity, field, oneToMany);
    }
}
/**
 * Placeholder for filling a collection-valued {@code OneToMany} field. The current implementation only logs its arguments at trace
 * level and does not modify the entity.
 *
 * @param entity the entity owning the field
 * @param field the collection field annotated with {@code OneToMany}
 * @param oneToMany the {@code OneToMany} annotation found on {@code field}
 * @throws IllegalAccessException declared but not thrown by the current implementation
 * @throws NoSuchFieldException declared but not thrown by the current implementation
 */
private void saveEntityInCollection(Object entity, Field field, OneToMany oneToMany)
throws IllegalAccessException, NoSuchFieldException {
LOG.trace("saveEntityInCollection: entity={}, field={}, oneToMany={}", entity, field, oneToMany);
LOG.trace("Not yet implemented");
}
/**
 * Fills a map-valued {@code OneToMany} field of an entity with the matching entities among those created so far.
 * <p>
 * The target entity type is the second type argument of the map. The name of the field to match on in the target type is taken from
 * {@code mappedBy}; if that is empty, it is derived from the name of the field's {@code JoinColumn} annotation. The created entities
 * of the target type whose field with that name equals the ID of {@code entity}, or refers to an entity with that ID, are added to
 * the map using {@code JpaUtil.addEntitiesToMapField}.
 *
 * @param entity the entity owning the field
 * @param field the map field annotated with {@code OneToMany}
 * @param oneToMany the {@code OneToMany} annotation found on {@code field}
 * @throws IllegalAccessException if a field could not be read using reflection
 * @throws NoSuchFieldException if the target type has no declared field with the resolved name
 * @throws IllegalStateException if the map is declared without type arguments, if the target class cannot be resolved, or if the
 *             field has neither {@code mappedBy} nor a {@code JoinColumn}
 */
private void saveEntityInMap(Object entity, Field field, OneToMany oneToMany) throws IllegalAccessException, NoSuchFieldException {
    LOG.trace("saveEntityInMap: entity={}, field={}, oneToMany={}", entity, field, oneToMany);
    if (!(field.getGenericType() instanceof ParameterizedType mapType)) {
        // A raw Map carries no information about the type of its values; fail with a descriptive message.
        throw new IllegalStateException("OneToMany field " + entity.getClass().getName() + "." + field.getName()
                + " is a Map declared without type arguments");
    }
    Type valueType = mapType.getActualTypeArguments()[1];
    // Use the class object directly when available; looking it up by name is only needed for other kinds of types.
    Class<?> targetClass = valueType instanceof Class<?> valueClass ? valueClass : getClass(valueType.getTypeName());
    String mappedBy = oneToMany.mappedBy();
    if (mappedBy == null || mappedBy.isEmpty()) {
        // No mappedBy given: fall back to the join column to find the field in the target type.
        JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
        if (joinColumn != null) {
            mappedBy = JpaUtil.getFieldName(joinColumn.name(), targetClass);
        }
    }
    if (mappedBy == null || mappedBy.isEmpty()) {
        throw new IllegalStateException("OneToMany field " + entity.getClass().getName() + "." + field.getName()
                + " has no mappedBy and no JoinColumn annotation");
    }
    Object id = JpaUtil.getIdValue(entity);
    List<?> entitiesToMap = findEntitiesByClassAndField(targetClass, mappedBy, id);
    JpaUtil.addEntitiesToMapField(field, entity, entitiesToMap);
}
/**
 * Loads the class with the given name, turning a failed lookup into an unchecked exception.
 *
 * @param className the name of the class to load
 * @return the class object for {@code className}
 * @throws IllegalStateException if the class could not be found
 */
private Class<?> getClass(String className) {
    Class<?> loadedClass;
    try {
        loadedClass = Class.forName(className);
    } catch (ClassNotFoundException notFound) {
        throw new IllegalStateException("Class " + className + " not found", notFound);
    }
    return loadedClass;
}
/**
 * Finds the created entities of a given class that have a certain value in a named field.
 * <p>
 * An entity matches if the field value equals {@code value}, or if the field value is itself an entity whose ID equals
 * {@code value}. Entities where the field is {@code null} never match.
 *
 * @param <T> the type of entities to look for
 * @param entityClass the exact class of the entities to look for
 * @param fieldName the name of the declared field to inspect
 * @param value the value to compare with
 * @return the matching entities, in the order they were created
 * @throws IllegalAccessException if the field could not be read using reflection
 * @throws NoSuchFieldException if the entity class declares no field with the given name
 */
private <T> List<T> findEntitiesByClassAndField(Class<T> entityClass, String fieldName, Object value)
        throws IllegalAccessException, NoSuchFieldException {
    LOG.trace("Finding entities by class={}, field={} and value={}", entityClass, fieldName, value);
    List<T> matches = new ArrayList<>();
    for (T candidate : entitiesOfType(entityClass)) {
        Field candidateField = candidate.getClass().getDeclaredField(fieldName);
        candidateField.setAccessible(true);
        Object current = candidateField.get(candidate);
        // Either the field holds the value itself, or it holds an entity identified by the value.
        if (current != null && (current.equals(value)
                || (current.getClass().getAnnotation(Entity.class) != null && JpaUtil.getIdValue(current).equals(value)))) {
            matches.add(candidate);
        }
    }
    return matches;
}
/**
 * Gives the created entities whose runtime class is exactly the given class; instances of subclasses are not included.
 *
 * @param <T> the type of entities to return
 * @param entityType the exact class of the entities to return
 * @return an unmodifiable list of the matching entities, in the order they were created
 */
private <T> List<T> entitiesOfType(Class<T> entityType) {
    return entities.stream()
            .filter(candidate -> candidate.getClass().equals(entityType))
            // The filter guarantees that the cast succeeds, so no unchecked list cast is needed.
            .map(entityType::cast)
            .toList();
}
/**
 * Looks among the declared fields of the entity classes seen so far for one with a {@code JoinTable} annotation that names the given
 * table, ignoring case.
 *
 * @param tableName the name of the join table
 * @return the first matching field, or an empty {@code Optional} if there is none
 */
private Optional<Field> joinTableField(String tableName) {
    return classes.stream()
            .flatMap(entityClass -> Arrays.stream(entityClass.getDeclaredFields()))
            .filter(candidate -> {
                JoinTable annotation = candidate.getAnnotation(JoinTable.class);
                return annotation != null && tableName.equalsIgnoreCase(annotation.name());
            })
            .findFirst();
}
/**
 * Connects two created entities based on one row of a join table.
 * <p>
 * The row must contain a column for the join column, holding the primary key of the entity declaring {@code joinTableField}, and a
 * column for the inverse join column, holding the primary key of the entity of type {@code targetType}. Column names are compared
 * ignoring case.
 *
 * @param tableRow the row from the join table
 * @param joinTableField the field annotated with {@code JoinTable}
 * @param targetType the type of the entities on the inverse side of the association
 * @param joinColumn the join column referring to the owning entity
 * @param inverseJoinColumn the join column referring to the target entity
 * @throws IllegalStateException if the row lacks a value for either of the two columns
 */
private void addEntityFromJoinTable(DbTableRow tableRow, Field joinTableField, Class<?> targetType, JoinColumn joinColumn,
        JoinColumn inverseJoinColumn) {
    String owningColumnName = joinColumn.name();
    String inverseColumnName = inverseJoinColumn.name();
    String owningKey = null;
    String targetKey = null;
    for (DbTableField tableField : tableRow.columns()) {
        String columnName = tableField.name();
        if (columnName.equalsIgnoreCase(owningColumnName)) {
            owningKey = tableField.value();
        } else if (columnName.equalsIgnoreCase(inverseColumnName)) {
            targetKey = tableField.value();
        }
    }
    if (owningKey == null || targetKey == null) {
        throw new IllegalStateException("Failed to find join table: missing attribute in DBUnit XML file: '" + joinColumn.name()
                + "' or '" + inverseJoinColumn.name() + "'");
    }
    Object owner = findEntity(owningKey, joinTableField.getDeclaringClass());
    Object target = findEntity(targetKey, targetType);
    JpaUtil.addObjectToCollectionField(joinTableField, owner, target);
}
/**
 * Creates a new instance of the given type using its no-argument constructor, which need not be public.
 *
 * @param <T> the type of entity to create
 * @param entityType the class of the entity to create
 * @return the new instance
 * @throws ReflectiveOperationException if there is no no-argument constructor or if invoking it failed
 */
private <T> T createEntity(Class<T> entityType) throws ReflectiveOperationException {
    Constructor<T> noArgConstructor = entityType.getDeclaredConstructor();
    noArgConstructor.setAccessible(true);
    T instance = noArgConstructor.newInstance();
    return instance;
}
/**
 * Sets a field of an entity to the object created from a string value.
 * <p>
 * If the created object is itself an entity, the entity being populated is also offered to the {@code OneToMany} collections of that
 * object that are mapped by the field.
 *
 * @param <T> the type of the entity
 * @param entity the entity whose field should be set
 * @param fieldName the name of the field to set
 * @param attributeValue the string representation of the value
 * @throws ReflectiveOperationException if the field could not be set using reflection
 */
private <T> void setField(T entity, String fieldName, String attributeValue) throws ReflectiveOperationException {
    Field target = JpaUtil.getField(entity, fieldName);
    target.setAccessible(true);
    Object converted = createObjectFromString(attributeValue, target, JpaUtil.getPrimaryKeyType(entity.getClass()));
    LOG.trace("Setting field {} to {}", fieldName, converted);
    target.set(entity, converted);
    if (converted == null) {
        return;
    }
    // A reference to another entity may have an inverse collection that should contain this entity.
    if (converted.getClass().getAnnotation(Entity.class) != null) {
        potentiallyAddValueToCollection(converted, fieldName, entity);
    }
}
/**
 * Adds a value to every {@code OneToMany} field of an entity whose {@code mappedBy} equals the given field name.
 *
 * @param <T> the type of the value
 * @param entity the entity whose declared fields are examined
 * @param fieldName the name to compare with {@code mappedBy}
 * @param value the value to add
 */
private <T> void potentiallyAddValueToCollection(Object entity, String fieldName, T value) {
    for (Field candidate : entity.getClass().getDeclaredFields()) {
        candidate.setAccessible(true);
        OneToMany association = candidate.getAnnotation(OneToMany.class);
        if (association != null && association.mappedBy().equals(fieldName)) {
            JpaUtil.addObjectToCollectionField(candidate, entity, value);
        }
    }
}
/**
 * Creates an object from a string, for storing in the given field.
 * <p>
 * For a field annotated with {@code Id}, the object is created as the given primary key type; for any other field, the declared type
 * of the field is used.
 *
 * @param s the string representation of the value
 * @param field the field the value is meant for
 * @param primaryKeyType the type to use if the field is annotated with {@code Id}
 * @return the created object, or {@code null} if {@code s} is {@code null} or empty
 */
private @Nullable Object createObjectFromString(String s, Field field, Class<?> primaryKeyType) {
    boolean isPrimaryKey = field.getAnnotation(Id.class) != null;
    Class<?> targetType = isPrimaryKey ? primaryKeyType : field.getType();
    return createObjectFromString(s, targetType);
}
/**
 * Creates an object of the given type from its string representation.
 * <p>
 * Primitive types and their wrappers, strings, dates, big numbers and UUIDs are parsed from the string. For {@code List}, curly braces
 * are removed and the string is split on commas, giving a list of strings. Any other type is looked up using {@code findEntity}, i.e.,
 * it is treated as an enum or as a previously created entity with the string as its primary key.
 * <p>
 * Primitive types are handled like their wrappers; earlier, they fell through to the entity lookup and failed.
 *
 * @param s the string representation of the value
 * @param type the type of object to create
 * @return the created object, or {@code null} if {@code s} is {@code null} or empty
 */
@SuppressWarnings("checkstyle:noReturnNull")
private @Nullable Object createObjectFromString(String s, Class<?> type) {
    if (s == null || s.isEmpty()) {
        return null;
    }
    if (type == Byte.class || type == byte.class) {
        return Byte.parseByte(s);
    } else if (type == Short.class || type == short.class) {
        return Short.parseShort(s);
    } else if (type == Integer.class || type == int.class) {
        return Integer.parseInt(s);
    } else if (type == Long.class || type == long.class) {
        return Long.parseLong(s);
    } else if (type == Float.class || type == float.class) {
        return Float.parseFloat(s);
    } else if (type == Double.class || type == double.class) {
        return Double.parseDouble(s);
    } else if (type == Boolean.class || type == boolean.class) {
        return Boolean.parseBoolean(s);
    } else if (type == Character.class || type == char.class) {
        // s is known to be non-empty here, so there is always a first character.
        return s.charAt(0);
    } else if (type == String.class) {
        return s;
    } else if (type == Date.class) {
        return TestUtil.parseDate(s);
    } else if (type == LocalDate.class) {
        return LocalDate.parse(s);
    } else if (type == LocalDateTime.class) {
        return LocalDateTime.parse(s);
    } else if (type == ZonedDateTime.class) {
        return ZonedDateTime.parse(s);
    } else if (type == BigDecimal.class) {
        return new BigDecimal(s);
    } else if (type == BigInteger.class) {
        return new BigInteger(s);
    } else if (type == UUID.class) {
        return UUID.fromString(s);
    } else if (type == List.class) {
        return Arrays.asList(s.replaceAll("[{}]", "").split(","));
    } else {
        // Anything else is an enum constant or a reference to a previously created entity.
        return findEntity(s, type);
    }
}
/**
 * Finds the object of the given type that is identified by a string.
 * <p>
 * For an enum type, the string is the name of the enum constant. For any other type, the created entities of exactly that class are
 * searched for one whose ID equals the string converted to the type of the ID.
 *
 * @param strId the name of the enum constant, or the string representation of the primary key
 * @param entityType the enum type or the exact class of the entity to find
 * @return the enum constant or entity found
 * @throws IllegalArgumentException if no matching entity was found
 * @throws IllegalStateException if the ID field of an entity could not be read
 */
@SuppressWarnings({ "rawtypes", "unchecked" })
private Object findEntity(String strId, Class<?> entityType) {
    if (entityType.isEnum()) {
        Class<? extends Enum> enumType = (Class<? extends Enum>) entityType;
        return Enum.valueOf(enumType, strId);
    }
    for (Object candidate : entities) {
        if (!candidate.getClass().equals(entityType)) {
            continue;
        }
        Field idField = JpaUtil.getIdField(candidate);
        idField.setAccessible(true);
        Object id;
        try {
            id = idField.get(candidate);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unexpected problem looking up entity of " + entityType + " with primary key " + strId,
                    e);
        }
        // Convert the string to the runtime type of the ID before comparing.
        if (id != null && id.equals(createObjectFromString(strId, id.getClass()))) {
            return candidate;
        }
    }
    throw new IllegalArgumentException("Entity of " + entityType + " with primary key " + strId + " not found");
}
/**
 * Represents one row of data from the database, as a list of named column values.
 * <p>
 * Implementations adapt a concrete data source to the form read by {@code CrudRepositoryWriter}.
 *
 * @author RealLifeDeveloper
 */
public interface DbTableRow {
/**
 * Gives the fields of this row, one for each column that has a value in the row.
 *
 * @return the fields
 */
List<DbTableField> columns();
}
/**
 * Represents the value of a single field in the database, i.e., the value of one column in one row.
 * <p>
 * The value is kept as a string; {@code CrudRepositoryWriter} converts it to the type of the entity field it is assigned to.
 *
 * @param name the name of the database column
 * @param value the value of the field, in string form
 */
public record DbTableField(String name, String value) {
}
}