JpaUtil.java
package com.reallifedeveloper.tools.test.database;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import jakarta.persistence.Column;
import jakarta.persistence.EmbeddedId;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.IdClass;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.MapKey;
import jakarta.persistence.Table;
import lombok.experimental.UtilityClass;
/**
* A utility class for working with JPA entities and mappings.
*
* @author RealLifeDeveloper
*/
@UtilityClass
@SuppressWarnings("PMD")
@SuppressFBWarnings(value = { "CRLF_INJECTION_LOGS", "IMPROPER_UNICODE" })
public class JpaUtil {
private static final Logger LOG = LoggerFactory.getLogger(JpaUtil.class);
/**
* Gives the table name associated with an {@link Entity}.
*
* @param <T> the type of entity
* @param entityType the class object representing {@code T}
* @return the table name associated with {@code entityType}
*/
public static <T> String getTableName(Class<T> entityType) {
Table table = entityType.getAnnotation(Table.class);
if (table == null) {
return entityType.getSimpleName();
} else {
return table.name();
}
}
/**
* Gets the {@code Field} object for a field with a given name in a given object, also making the field accessible by calling
* {@code Field.setAccessible(true}.
* <p>
* This method never returns {@code null}; if the field is not found, a {@code NoSuchFieldException} is thrown
*
* @param entity the object to search for the field
* @param fieldName the name of the field for which to search
* @return the {@code Field} object representing the field named {@code fieldName} in {@code entity
* }
* @throws NoSuchFieldException if the field could not be found
*/
public static Field getField(Object entity, String fieldName) throws NoSuchFieldException {
Class<?> entityType = entity.getClass();
while (entityType != null) {
for (Field field : entityType.getDeclaredFields()) {
if (field.getName().equalsIgnoreCase(fieldName)) {
field.setAccessible(true);
return field;
}
}
entityType = entityType.getSuperclass();
}
throw new NoSuchFieldException(fieldName);
}
/**
* Gets thd ID field of an entity, i.e., the field annotated with an {@code ID} annotation.
* <p>
* This method never returns {@code null}; if no ID field is found, an {@code IllegalStateException} is thrown.
*
* @param entity the entity to search
* @return the {@code Field} object representing the ID field of {@code entity}
*/
public static Field getIdField(Object entity) {
Class<?> entityType = entity.getClass();
while (entityType != null) {
for (Field field : entityType.getDeclaredFields()) {
if (field.getDeclaredAnnotation(Id.class) != null) {
return field;
}
}
entityType = entityType.getSuperclass();
}
throw new IllegalStateException("Id field not found for entity " + entity);
}
/**
* Gets the value of the ID field of an entity.
*
* @param entity the entity to search
*
* @return the value of the ID Field
*
* @throws IllegalAccessException if there was a problem getting the value using reflection
*/
public static Object getIdValue(Object entity) throws IllegalAccessException {
Field idField = JpaUtil.getIdField(entity);
idField.setAccessible(true);
Object id = idField.get(entity);
return id;
}
/**
* Gets the name of the field representing a given attribute in a given class or one of its superclasses.
* <p>
* A field is considered to represent an attribute if one of the following is true:
* <ul>
* <li>The field has a {@code Column} annotation with a {@code name} equal to the attribute.</li>
* <li>The field has a {@code JoinColunm} annotation with a {@code name} equal to the attribute.</li>
* <li>The field name is the same as the attribute.</li>
* </ul>
* <p>
* This method never returns {@code null}; if no matching field is found, an {@code IllegalArgumentException} is thrown.
*
* @param <T> the type of the entity to search
* @param attributeName the attribute for which to try to find a matching field
* @param entityType the class in which to search for the field, continuing with superclasses if necessary
* @return the name of the field representing {@code attributeName}
* @throws IllegalArgumentException if no matching field could be found
*/
public static <T> String getFieldName(String attributeName, Class<T> entityType) {
return getFieldName(attributeName, entityType, entityType);
}
private static <T> String getFieldName(String attributeName, Class<T> entityType, Class<?> originalEntityType) {
for (Field field : entityType.getDeclaredFields()) {
if (checkFieldName(attributeName, field)) {
return field.getName();
}
}
if (entityType.getSuperclass() == null) {
throw new IllegalArgumentException("Cannot find any field matching attribute '" + attributeName.toLowerCase(Locale.getDefault())
+ "' for " + originalEntityType);
} else {
return getFieldName(attributeName, entityType.getSuperclass(), originalEntityType);
}
}
private static boolean checkFieldName(String attributeName, Field field) {
Column column = field.getAnnotation(Column.class);
if (column == null || column.name() == null) {
JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
if (joinColumn == null || joinColumn.name() == null) {
return field.getName().equalsIgnoreCase(attributeName);
} else {
return joinColumn.name().equalsIgnoreCase(attributeName);
}
} else {
return column.name().equalsIgnoreCase(attributeName);
}
}
/**
* Gives the primary key class for a given entity class.
*
* @param <ID> the type of the primary key
* @param entityType the class object representing the entity
* @return the primary key class of {@code entityType}
*/
@SuppressWarnings("unchecked")
public static <ID> Class<ID> getPrimaryKeyType(Class<?> entityType) {
if (entityType.getAnnotation(Entity.class) == null) {
throw new IllegalArgumentException("entityType does not have @Entity annotation: entityType=" + entityType);
}
Class<?> entityTypeOrSuperClass = entityType;
Type genericSuperclass = null;
while (entityTypeOrSuperClass.getSuperclass() != null) {
if (entityTypeOrSuperClass.getAnnotation(IdClass.class) != null) {
return (Class<ID>) entityTypeOrSuperClass.getAnnotation(IdClass.class).value();
}
for (Field field : entityTypeOrSuperClass.getDeclaredFields()) {
if (field.getAnnotation(EmbeddedId.class) != null) {
return (Class<ID>) field.getType();
}
if (field.getAnnotation(Id.class) != null) {
return (Class<ID>) getActualIdType(genericSuperclass, entityTypeOrSuperClass.getTypeParameters(),
field.getGenericType()).orElse((Class<Object>) field.getType());
}
}
genericSuperclass = entityTypeOrSuperClass.getGenericSuperclass();
entityTypeOrSuperClass = entityTypeOrSuperClass.getSuperclass();
}
throw new IllegalStateException("entityType without primary key annotation: entityType=" + entityType);
}
@SuppressWarnings("unchecked")
private static <ID> Optional<Class<ID>> getActualIdType(@Nullable Type genericEntityType, TypeVariable<?>[] typeVariables,
Type genericIdType) {
if (genericEntityType instanceof ParameterizedType parameterizedType) {
assert parameterizedType.getActualTypeArguments().length == typeVariables.length
: "Number of actual type arguments (" + parameterizedType.getActualTypeArguments().length
+ ") differs from number of type variables (" + typeVariables.length + ")";
for (int i = 0; i < typeVariables.length; i++) {
TypeVariable<?> typeVariable = typeVariables[i];
if (!typeVariable.getName().equals(genericIdType.getTypeName())) {
continue;
}
Type actualType = parameterizedType.getActualTypeArguments()[i];
try {
return Optional.of((Class<ID>) Class.forName(actualType.getTypeName()));
} catch (ClassNotFoundException e) {
throw new IllegalStateException("Unexpected problem looking up ID class", e);
}
}
}
return Optional.empty();
}
/**
* Calls the {@code add} method on the the {@code java.util.Collection} referenced by the given entity field to add the given value.
* <p>
* This method should only be called when you know that the field actually is a collection.
*
* @param field the field holding the collection
* @param entity the entity where {@code field} lives
* @param value the value to add, may be {@code null}
*/
@SuppressFBWarnings(value = "CRLF_INJECTION_LOGS", justification = "Only entity values being logged")
public static void addObjectToCollectionField(Field field, Object entity, Object value) {
assert Collection.class.isAssignableFrom(field.getType()) : "Expected field to be a Collection: field=" + field;
try {
Method add = field.getType().getMethod("add", Object.class);
LOG.debug("Calling add method on field {} of entity {} to add entity {} to collection", field.getName(), entity, value);
add.invoke(field.get(entity), value);
} catch (NoSuchMethodException e) {
throw new IllegalStateException(
"Method 'add' not found -- field " + fieldNameForLogging(entity, field) + " should be a Collection", e);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new IllegalStateException("Unexpected problem", e);
}
}
/**
* Calls the {@code put} method on the {@code java.util.Map} referenced by the given entity field to add the entities to map.
* <p>
* The key is found using the {@code MapKey} annotation of the field.
* <p>
* This method should only be called when you know that the field actually is a map.
*
* @param field the field holding the map
* @param entity the entity where {@code field} lives
* @param entitiesToMap a list of entities to map
*/
public static void addEntitiesToMapField(Field field, Object entity, List<?> entitiesToMap) {
LOG.trace("addEntitiesToMapField: field={}, entity={}, entitiesToMap={}", field, entity, entitiesToMap);
assert Map.class.isAssignableFrom(field.getType()) : "Expected field to be a Map: field=" + field;
for (Object entityToMap : entitiesToMap) {
MapKey mapKey = field.getAnnotation(MapKey.class);
if (mapKey == null) {
throw new IllegalStateException(
"Field " + fieldNameForLogging(entity, field) + " is a Map but is missing MapKey annotation");
}
try {
Method put = field.getType().getMethod("put", Object.class, Object.class);
Object key = JpaUtil.getField(entityToMap, mapKey.name()).get(entityToMap);
LOG.debug("Adding entity {} to map {} with key {}", entityToMap, fieldNameForLogging(entity, field), key);
put.invoke(field.get(entity), key, entityToMap);
} catch (NoSuchMethodException e) {
throw new IllegalStateException(
"Method 'put' not found -- field " + fieldNameForLogging(entity, field) + " should be a Map", e);
} catch (NoSuchFieldException e) {
throw new IllegalStateException(
"Field " + mapKey.name() + " not found, it is used in MapKey annotation in " + fieldNameForLogging(entity, field),
e);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new IllegalStateException("Unexpected problem", e);
}
}
}
private static String fieldNameForLogging(Object entity, Field field) {
return entity.getClass().getName() + "." + field.getName();
}
}