// AbstractInMemoryCrudRepository.java

package com.reallifedeveloper.tools.test.database.inmemory;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.data.repository.query.QueryByExampleExecutor;

import com.reallifedeveloper.tools.test.TestUtil;

/**
 * An abstract helper class that implements the {@link CrudRepository} interface using an in-memory map instead of a database.
 * <p>
 * Contains useful methods for sub-classes implementing in-memory versions of repositories, such as {@link #findByField(String, Object)} and
 * {@link #findByUniqueField(String, Object)}.
 *
 * @param <T>  the type of the entities handled by this repository
 * @param <ID> the type of the entities' primary keys
 *
 * @author RealLifeDeveloper
 */
@SuppressWarnings({ "PMD", "checkstyle:noReturnNull" }) // TODO: Consider refactoring this class using the hints from PMD
public abstract class AbstractInMemoryCrudRepository<T, ID extends Comparable<ID>>
        implements CrudRepository<T, ID>, PagingAndSortingRepository<T, ID>, QueryByExampleExecutor<T> {

    /**
     * The in-memory store backing this repository: each saved entity is kept under its primary key.
     */
    private final Map<@NonNull ID, @NonNull T> entities = new HashMap<>();

    /**
     * Generator used when saving an entity whose primary key is {@code null}; {@code null} if this repository was created without a
     * generator.
     */
    private final @Nullable PrimaryKeyGenerator<ID> primaryKeyGenerator;

    /**
     * Creates a new {@code InMemoryCrudRepository} with no primary key generator.
     * <p>
     * Since there is no generator to fall back on, saving an entity whose primary key is {@code null} fails with an
     * {@code IllegalStateException}.
     */
    public AbstractInMemoryCrudRepository() {
        this.primaryKeyGenerator = null;
    }

    /**
     * Creates a new {@code InMemoryCrudRepository} with the provided primary key generator. If an entity with a {@code null} primary key is
     * saved, the generator is used to create a new primary key that is stored in the entity before saving.
     *
     * @param primaryKeyGenerator the primary key generator to use, must not be {@code null}
     *
     * @throws IllegalArgumentException if {@code primaryKeyGenerator} is {@code null}
     */
    public AbstractInMemoryCrudRepository(PrimaryKeyGenerator<ID> primaryKeyGenerator) {
        if (primaryKeyGenerator == null) {
            throw new IllegalArgumentException("primaryKeyGenerator must not be null");
        }
        this.primaryKeyGenerator = primaryKeyGenerator;
    }

    /**
     * Finds all stored entities whose named field equals the given value.
     * <p>
     * The field of each entity is read with {@code TestUtil.getFieldValue} and compared to {@code value} using
     * {@link Objects#equals(Object, Object)}, so {@code value} may be {@code null}.
     *
     * @param fieldName the name of the field to use when searching
     * @param value     the value to search for
     * @param <F>       the type of {@code value}
     * @return an unmodifiable list of the entities {@code e} for which {@code Objects.equals(value, e.fieldName)} holds, empty if there
     *         are none
     *
     * @throws IllegalArgumentException if {@code fieldName} is {@code null}
     */
    protected <F> List<@NonNull T> findByField(String fieldName, F value) {
        if (fieldName == null) {
            throw new IllegalArgumentException("fieldName must not be null");
        }
        List<@NonNull T> matches = new ArrayList<>();
        for (T candidate : entities.values()) {
            Object fieldValue = TestUtil.getFieldValue(candidate, fieldName);
            if (Objects.equals(value, fieldValue)) {
                matches.add(candidate);
            }
        }
        // Callers get a read-only result, just as with Stream.toList().
        return Collections.unmodifiableList(matches);
    }

    /**
     * Finds the single entity whose named field equals the given value.
     * <p>
     * This delegates to {@link #findByField(String, Object)} and then checks how many entities matched.
     *
     * @param fieldName the name of the field to use when searching
     * @param value     the value to search for, may be {@code null}
     * @param <F>       the type of {@code value}
     *
     * @return an optional holding the unique matching entity, or an empty optional if no entity matches
     *
     * @throws IllegalArgumentException if {@code fieldName} is {@code null}, or if more than one entity with the given value is found
     */
    protected <F> Optional<T> findByUniqueField(String fieldName, F value) {
        List<T> foundEntities = findByField(fieldName, value);
        return switch (foundEntities.size()) {
            case 0 -> Optional.empty();
            case 1 -> Optional.of(foundEntities.get(0));
            default -> throw new IllegalArgumentException(
                    "Field " + fieldName + " is not unique, found " + foundEntities.size() + " entities: " + foundEntities);
        };
    }

    /**
     * {@inheritDoc}
     * <p>
     * The count is the number of entries currently held in the in-memory map.
     */
    @Override
    public long count() {
        return entities.size();
    }

    /**
     * {@inheritDoc}
     * <p>
     * This implementation treats a missing entity as an error rather than ignoring it.
     *
     * @throws IllegalArgumentException       if {@code id} is {@code null}
     * @throws EmptyResultDataAccessException if no entity is stored under {@code id}
     */
    @Override
    public void deleteById(ID id) {
        if (id == null) {
            throw new IllegalArgumentException("id must not be null");
        }
        boolean wasPresent = entities.remove(id) != null;
        if (!wasPresent) {
            throw new EmptyResultDataAccessException("Entity with id " + id + " not found", 1);
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * Each element is removed by calling {@link #delete(Object)} on it in iteration order.
     *
     * @throws IllegalArgumentException if {@code entitiesToDelete}, or any of its elements, is {@code null}
     */
    @Override
    public void deleteAll(Iterable<? extends T> entitiesToDelete) {
        if (entitiesToDelete == null) {
            throw new IllegalArgumentException("entitiesToDelete must not be null");
        }
        entitiesToDelete.forEach(this::delete);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Each ID is removed by calling {@link #deleteById(Object)} on it in iteration order, so processing stops at the first ID that is
     * {@code null} or not found.
     *
     * @throws IllegalArgumentException       if {@code ids}, or any of its elements, is {@code null}
     * @throws EmptyResultDataAccessException if no entity is stored under one of the given IDs
     */
    @Override
    public void deleteAllById(Iterable<? extends ID> ids) {
        // Fail with the same exception type as the other bulk operations instead of a bare NullPointerException.
        if (ids == null) {
            throw new IllegalArgumentException("ids must not be null");
        }
        for (ID id : ids) {
            deleteById(id);
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * The entity is located by its primary key, as given by {@link #getId(Object)}; nothing happens if no entity is stored under that
     * key.
     *
     * @throws IllegalArgumentException if {@code entity} is {@code null}
     */
    @Override
    public void delete(T entity) {
        if (entity == null) {
            throw new IllegalArgumentException("entity must not be null");
        }
        ID primaryKey = getId(entity);
        entities.remove(primaryKey);
    }

    /**
     * {@inheritDoc}
     * <p>
     * This empties the in-memory map.
     */
    @Override
    public void deleteAll() {
        entities.clear();
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if {@code id} is {@code null}
     */
    @Override
    public boolean existsById(ID id) {
        if (id == null) {
            throw new IllegalArgumentException("id must not be null");
        }
        return entities.containsKey(id);
    }

    /**
     * {@inheritDoc}
     * <p>
     * The returned list is a new copy, so modifying it does not affect the repository. The entities are listed in the iteration order of
     * the backing {@code HashMap}, which is unspecified.
     */
    @Override
    public List<T> findAll() {
        return new ArrayList<>(entities.values());
    }

    /**
     * {@inheritDoc}
     * <p>
     * IDs are looked up one at a time with {@link #findById(Object)}; IDs without a stored entity are skipped, and the result follows the
     * iteration order of {@code ids}.
     *
     * @throws IllegalArgumentException if {@code ids}, or any of its elements, is {@code null}
     */
    @Override
    public List<T> findAllById(Iterable<ID> ids) {
        if (ids == null) {
            throw new IllegalArgumentException("ids must not be null");
        }
        List<T> selectedEntities = new ArrayList<>();
        for (ID id : ids) {
            findById(id).ifPresent(selectedEntities::add);
        }
        return selectedEntities;
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if {@code id} is {@code null}
     */
    @Override
    public Optional<T> findById(ID id) {
        if (id == null) {
            throw new IllegalArgumentException("id must not be null");
        }
        // A missing key gives null from the map, which becomes an empty optional.
        return Optional.ofNullable(entities.get(id));
    }

    /**
     * {@inheritDoc}
     * <p>
     * If the entity's primary key is {@code null}, a new key is requested from the primary key generator, passing it the largest key
     * currently stored, and written into the entity before it is stored. An entity saved under an existing key replaces the old one.
     *
     * @throws IllegalArgumentException if {@code entity} is {@code null}
     * @throws IllegalStateException    if the primary key is {@code null} and this repository has no primary key generator
     */
    @Override
    public <S extends T> S save(S entity) {
        if (entity == null) {
            throw new IllegalArgumentException("entity must not be null");
        }
        ID id = getId(entity);
        if (id == null) {
            if (primaryKeyGenerator == null) {
                throw new IllegalStateException("Primary key is null and no primary key generator available: entity=" + entity);
            }
            id = primaryKeyGenerator.nextPrimaryKey(maximumPrimaryKey());
            setId(entity, id);
        }
        entities.put(id, entity);
        return entity;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Each element is stored by calling {@code save} on it in iteration order; the returned list holds the results in that same order.
     *
     * @throws IllegalArgumentException if {@code entitiesToSave}, or any of its elements, is {@code null}
     */
    @Override
    public <S extends T> List<S> saveAll(Iterable<S> entitiesToSave) {
        if (entitiesToSave == null) {
            throw new IllegalArgumentException("entitiesToSave must not be null");
        }
        List<S> savedEntities = new ArrayList<>();
        entitiesToSave.forEach(candidate -> savedEntities.add(save(candidate)));
        return savedEntities;
    }

    //
    // PagingAndSortingRepository methods
    //

    /**
     * {@inheritDoc}
     * <p>
     * All stored entities are fetched with {@link #findAll()} and then ordered by {@code SortUtil.sort}.
     */
    @Override
    public List<T> findAll(Sort sort) {
        List<T> unsortedEntities = findAll();
        return SortUtil.sort(unsortedEntities, sort);
    }

    /**
     * {@inheritDoc}
     * <p>
     * All stored entities are first ordered by {@code SortUtil.sort} using the sort of {@code pageable}. For an unpaged request the whole
     * sorted list is returned as a single page. Otherwise the requested slice is cut out, clamped to the number of stored entities, so a
     * page beyond the end of the data is empty.
     *
     * @throws IllegalArgumentException if {@code pageable} is {@code null}
     */
    @Override
    public Page<T> findAll(Pageable pageable) {
        if (pageable == null) {
            throw new IllegalArgumentException("pageable must not be null");
        }
        List<T> allEntities = SortUtil.sort(findAll(), pageable.getSort());
        int total = allEntities.size();
        if (pageable.isUnpaged()) {
            // The offset and page size of an unpaged request cannot be queried, so return everything.
            return new PageImpl<>(allEntities, pageable, total);
        }
        // Do the arithmetic in long: the offset is a long, and offset + page size can exceed the int range.
        long start = Math.min(pageable.getOffset(), total);
        long end = Math.min(start + pageable.getPageSize(), total);
        // Both bounds are now at most total, so the casts to int are safe.
        List<T> pagedEntities = allEntities.subList((int) start, (int) end);
        return new PageImpl<>(pagedEntities, pageable, total);
    }

    //
    // Helper methods
    //

    /**
     * Gives the largest primary key among the stored entities, according to the natural ordering of the keys.
     *
     * @return the largest primary key, or {@code null} if the repository is empty
     */
    private ID maximumPrimaryKey() {
        ID largest = null;
        for (T candidate : entities.values()) {
            ID candidateId = getId(candidate);
            if (largest != null && candidateId.compareTo(largest) <= 0) {
                continue;
            }
            largest = candidateId;
        }
        return largest;
    }

    /**
     * Gives the value of the ID field or method of the given entity.
     * <p>
     * The lookup order is: a composite key built from the entity's ID class, if {@link #getIdClass(Object)} gives one; otherwise the
     * single ID field; otherwise the result of invoking the single ID method.
     *
     * @param entity the entity to examine, should not be {@code null}
     *
     * @return the value of the ID field or method of {@code entity}, may be {@code null}
     *
     * @throws IllegalArgumentException if {@code entity} has neither an ID class, an ID field nor an ID method
     * @throws IllegalStateException    if reading the ID by reflection fails, or if the entity has more than one ID field or ID method
     */
    @SuppressWarnings("unchecked")
    protected @Nullable ID getId(T entity) {
        try {
            if (getIdClass(entity).isPresent()) {
                // TODO: Handle IdClass with method annotations
                return createIdClassInstance(entity, getIdFields(entity));
            }
            // Each lookup scans the whole class hierarchy by reflection, so do it once and reuse the result.
            Optional<Field> idField = getIdField(entity);
            if (idField.isPresent()) {
                return (ID) idField.get().get(entity);
            }
            Optional<Method> idMethod = getIdMethod(entity);
            if (idMethod.isPresent()) {
                return (ID) idMethod.get().invoke(entity);
            }
            throw new IllegalArgumentException("Entity has no @Id annotation: " + entity);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Sets the value of the ID field, or calls the ID setter method, for the given entity.
     * <p>
     * The lookup order mirrors {@link #getId(Object)}: the fields of a composite key if {@link #getIdClass(Object)} gives an ID class;
     * otherwise the single ID field; otherwise the setter matching the single ID method.
     *
     * @param entity the entity for which to set the ID
     * @param id     the new ID value
     *
     * @throws IllegalArgumentException if {@code entity} has neither an ID class, an ID field nor an ID method
     * @throws IllegalStateException    if writing the ID by reflection fails, or if the entity has more than one ID field or ID method
     */
    protected void setId(T entity, ID id) {
        try {
            if (getIdClass(entity).isPresent()) {
                // TODO: Handle IdClass with method annotations
                setIdFieldsFromIdClassInstance(entity, getIdFields(entity), id);
                return;
            }
            // The lookup scans the whole class hierarchy by reflection, so do it once and reuse the result.
            Optional<Field> idField = getIdField(entity);
            if (idField.isPresent()) {
                idField.get().set(entity, id);
                return;
            }
            if (getIdMethod(entity).isPresent()) {
                getSetMethod(entity, id).invoke(entity, id);
                return;
            }
            // Report a missing ID the same way getId does instead of silently leaving the entity unchanged.
            throw new IllegalArgumentException("Entity has no @Id annotation: " + entity);
        } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Gives the single ID field of the given entity, if it has one.
     *
     * @param entity the entity to examine
     *
     * @return the ID field of {@code entity}, or an empty optional if it has no ID field
     *
     * @throws IllegalStateException if {@code entity} has more than one ID field
     */
    private Optional<Field> getIdField(T entity) {
        List<Field> idFields = getIdFields(entity);
        if (idFields.size() > 1) {
            throw new IllegalStateException("Multiple ID fields found in entity: " + entity);
        }
        // At this point the list holds zero or one field.
        return idFields.stream().findFirst();
    }

    /**
     * Collects every field accepted by {@link #isIdField(Field)} that is declared by the class of the given entity or by any of its
     * superclasses. The fields are made accessible before being returned.
     *
     * @param entity the entity to examine
     *
     * @return the ID fields of {@code entity}, most specific class first; empty if there are none
     */
    private List<Field> getIdFields(T entity) {
        List<Field> result = new ArrayList<>();
        for (Class<?> type = entity.getClass(); type != null; type = type.getSuperclass()) {
            for (Field candidate : type.getDeclaredFields()) {
                if (!isIdField(candidate)) {
                    continue;
                }
                candidate.setAccessible(true);
                result.add(candidate);
            }
        }
        return result;
    }

    /**
     * Gives the single ID method of the given entity, if it has one.
     *
     * @param entity the entity to examine
     *
     * @return the ID method of {@code entity}, or an empty optional if it has no ID method
     *
     * @throws IllegalStateException if {@code entity} has more than one ID method
     */
    private Optional<Method> getIdMethod(T entity) {
        List<Method> idMethods = getIdMethods(entity);
        if (idMethods.size() > 1) {
            throw new IllegalStateException("Multiple ID methods found in entity: " + entity);
        }
        // At this point the list holds zero or one method.
        return idMethods.stream().findFirst();
    }

    /**
     * Collects every method accepted by {@link #isIdMethod(Method)} that is declared by the class of the given entity or by any of its
     * superclasses. The methods are made accessible before being returned.
     *
     * @param entity the entity to examine
     *
     * @return the ID methods of {@code entity}, most specific class first; empty if there are none
     */
    private List<Method> getIdMethods(T entity) {
        List<Method> result = new ArrayList<>();
        for (Class<?> type = entity.getClass(); type != null; type = type.getSuperclass()) {
            for (Method candidate : type.getDeclaredMethods()) {
                if (!isIdMethod(candidate)) {
                    continue;
                }
                candidate.setAccessible(true);
                result.add(candidate);
            }
        }
        return result;
    }

    /**
     * Finds the setter that corresponds to the ID getter of the given entity.
     * <p>
     * The setter name is the getter name with a leading {@code get} replaced by {@code set}, and it must take a single parameter of the
     * runtime class of {@code id}. A public method is preferred; if there is none, methods of any visibility declared in the entity's
     * class hierarchy are considered, matching how the ID getter itself is discovered.
     *
     * @param entity the entity whose ID setter to find
     * @param id     the ID value that will be passed to the setter, must not be {@code null}
     *
     * @return the setter method, made accessible
     *
     * @throws NoSuchMethodException if the entity has no ID getter, or no matching setter
     */
    private Method getSetMethod(T entity, ID id) throws NoSuchMethodException {
        Method getMethod = getIdMethod(entity)
                .orElseThrow(() -> new NoSuchMethodException("Get method for ID not found: entity=" + entity));
        String setMethodName = getMethod.getName().replaceFirst("^get", "set");
        Method setMethod;
        try {
            setMethod = entity.getClass().getMethod(setMethodName, id.getClass());
        } catch (NoSuchMethodException e) {
            // Class.getMethod only sees public methods; fall back to non-public setters before giving up.
            setMethod = findDeclaredSetter(entity.getClass(), setMethodName, id.getClass()).orElseThrow(() -> e);
        }
        setMethod.setAccessible(true);
        return setMethod;
    }

    /**
     * Searches a class and its superclasses for a declared method with the given name and exactly one parameter of the given type.
     */
    private static Optional<Method> findDeclaredSetter(Class<?> type, String name, Class<?> parameterType) {
        for (Class<?> c = type; c != null; c = c.getSuperclass()) {
            for (Method candidate : c.getDeclaredMethods()) {
                Class<?>[] parameterTypes = candidate.getParameterTypes();
                if (candidate.getName().equals(name) && parameterTypes.length == 1 && parameterTypes[0] == parameterType) {
                    return Optional.of(candidate);
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Builds a composite primary key for the given entity by instantiating its ID class through the no-argument constructor and copying
     * the value of each ID field into the field of the same name in the new instance.
     *
     * @param entity   the entity to read the ID field values from
     * @param idFields the ID fields of {@code entity}
     *
     * @return the populated ID class instance, or {@code null} if any of the ID fields is {@code null}
     *
     * @throws IllegalStateException        if no ID class is available for {@code entity}
     * @throws ReflectiveOperationException if the ID class cannot be instantiated or a field cannot be read
     */
    private @Nullable ID createIdClassInstance(T entity, List<Field> idFields) throws ReflectiveOperationException {
        Class<ID> idClass = getIdClass(entity).orElseThrow(() -> new IllegalStateException("ID class not found: entity=" + entity));
        Constructor<ID> noArgConstructor = idClass.getDeclaredConstructor();
        noArgConstructor.setAccessible(true);
        ID compositeKey = noArgConstructor.newInstance();
        for (Field idField : idFields) {
            Object component = idField.get(entity);
            if (component == null) {
                // If any of the ID fields is null, we say that the primary key is null.
                return null;
            }
            TestUtil.injectField(compositeKey, idField.getName(), component);
        }
        return compositeKey;
    }

    /**
     * Copies a composite primary key into an entity: for each ID field, the value of the field with the same name in {@code id} is
     * written into {@code entity}.
     *
     * @param entity   the entity to update
     * @param idFields the ID fields of {@code entity}
     * @param id       the ID class instance to copy the values from
     */
    private void setIdFieldsFromIdClassInstance(T entity, List<Field> idFields, ID id) {
        idFields.forEach(idField -> {
            String name = idField.getName();
            TestUtil.injectField(entity, name, TestUtil.getFieldValue(id, name));
        });
    }

    /**
     * Override this in a concrete subclass to decide if a given field is an ID field of an entity.
     * <p>
     * This is called for every field declared in the class hierarchy of an entity whenever the entity's ID fields are looked up.
     *
     * @param field the field to examine
     *
     * @return {@code true} if {@code field} is an ID field, {@code false} otherwise
     */
    protected abstract boolean isIdField(Field field);

    /**
     * Override this in a concrete subclass to decide if a given method is a method giving the ID of an entity.
     * <p>
     * This is called for every method declared in the class hierarchy of an entity whenever the entity's ID methods are looked up.
     *
     * @param method the method to examine
     *
     * @return {@code true} if {@code method} is an ID method, {@code false} otherwise
     */
    protected abstract boolean isIdMethod(Method method);

    /**
     * Override this in a concrete subclass to give the ID class representing a composite primary key for an entity, if any.
     * <p>
     * When an ID class is given, the primary key of the entity is an instance of that class, created through its no-argument
     * constructor and populated from the entity's ID fields.
     *
     * @param entity the entity to examine
     *
     * @return the ID class representing the composite primary key of {@code entity}, or an empty optional if there is no such class
     */
    protected abstract Optional<Class<ID>> getIdClass(Object entity);

    /**
     * Gives a string with the simple name of the concrete repository class followed by the contents of the in-memory map.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder(getClass().getSimpleName());
        text.append("{entities=").append(entities).append("}");
        return text.toString();
    }

    //
    // QueryByExampleExecutor methods
    //

    /**
     * {@inheritDoc}
     * <p>
     * Query by example is not yet implemented for the in-memory repository, so this method always throws an exception without looking at
     * {@code example}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public <S extends T> Optional<S> findOne(Example<S> example) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * {@inheritDoc}
     * <p>
     * Query by example is not yet implemented for the in-memory repository, so this method always throws an exception without looking at
     * {@code example}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public <S extends T> List<S> findAll(Example<S> example) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * {@inheritDoc}
     * <p>
     * Query by example is not yet implemented for the in-memory repository, so this method always throws an exception without looking at
     * its arguments.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * {@inheritDoc}
     * <p>
     * Query by example is not yet implemented for the in-memory repository, so this method always throws an exception without looking at
     * its arguments.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * {@inheritDoc}
     * <p>
     * Query by example is not yet implemented for the in-memory repository, so this method always throws an exception without looking at
     * {@code example}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public <S extends T> long count(Example<S> example) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * {@inheritDoc}
     * <p>
     * Query by example is not yet implemented for the in-memory repository, so this method always throws an exception without looking at
     * {@code example}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public <S extends T> boolean exists(Example<S> example) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Make finalize method final to avoid "Finalizer attacks" and corresponding SpotBugs warning (CT_CONSTRUCTOR_THROW).
     * <p>
     * One of the constructors of this class throws on invalid input; declaring an empty, final finalizer prevents a subclass from
     * overriding {@code finalize} to get hold of an instance whose construction failed.
     *
     * @see <a href="https://wiki.sei.cmu.edu/confluence/display/java/OBJ11-J.+Be+wary+of+letting+constructors+throw+exceptions">
     *      Explanation of finalizer attack</a>
     */
    @Override
    @SuppressWarnings({ "checkstyle:NoFinalizer", "PMD.EmptyFinalizer", "PMD.EmptyMethodInAbstractClassShouldBeAbstract" })
    protected final void finalize() throws Throwable {
        // Intentionally empty: the method exists only so that it can be declared final.
    }

}