View Javadoc
1   package com.reallifedeveloper.tools.test.database.inmemory;
2   
3   import java.lang.reflect.Constructor;
4   import java.lang.reflect.Field;
5   import java.lang.reflect.InvocationTargetException;
6   import java.lang.reflect.Method;
7   import java.util.ArrayList;
8   import java.util.Collections;
9   import java.util.HashMap;
10  import java.util.List;
11  import java.util.Map;
12  import java.util.Objects;
13  import java.util.Optional;
14  
15  import org.checkerframework.checker.nullness.qual.NonNull;
16  import org.checkerframework.checker.nullness.qual.Nullable;
17  import org.springframework.dao.EmptyResultDataAccessException;
18  import org.springframework.data.domain.Example;
19  import org.springframework.data.domain.Page;
20  import org.springframework.data.domain.PageImpl;
21  import org.springframework.data.domain.Pageable;
22  import org.springframework.data.domain.Sort;
23  import org.springframework.data.repository.CrudRepository;
24  import org.springframework.data.repository.PagingAndSortingRepository;
25  import org.springframework.data.repository.query.QueryByExampleExecutor;
26  
27  import com.reallifedeveloper.tools.test.TestUtil;
28  
29  /**
30   * An abstract helper class that implements the {@link CrudRepository} interface using an in-memory map instead of a database.
31   * <p>
32   * Contains useful methods for sub-classes implementing in-memory versions of repositories, such as {@link #findByField(String, Object)} and
33   * {@link #findByUniqueField(String, Object)}.
34   *
35   * @param <T>  the type of the entities handled by this repository
36   * @param <ID> the type of the entities' primary keys
37   *
38   * @author RealLifeDeveloper
39   */
40  @SuppressWarnings({ "PMD", "checkstyle:noReturnNull" }) // TODO: Consider refactoring this class using the hints from PMD
41  public abstract class AbstractInMemoryCrudRepository<T, ID extends Comparable<ID>>
42          implements CrudRepository<T, ID>, PagingAndSortingRepository<T, ID>, QueryByExampleExecutor<T> {
43  
44      private final Map<@NonNull ID, @NonNull T> entities = new HashMap<>();
45  
46      private final @Nullable PrimaryKeyGenerator<ID> primaryKeyGenerator;
47  
48      /**
49       * Creates a new {@code InMemoryCrudRepository} with no primary key generator. If an entity with a {@code null} primary key is saved, an
50       * exception is thrown.
51       */
52      public AbstractInMemoryCrudRepository() {
53          this.primaryKeyGenerator = null;
54      }
55  
56      /**
57       * Creates a new {@code InMemoryCrudRepository} with the provided primary key generator. If an entity with a {@code null} primary key is
58       * saved, the generator is used to create a new primary key that is stored in the entity before saving.
59       *
60       * @param primaryKeyGenerator the primary key generator to use, must not be {@code null}
61       */
62      public AbstractInMemoryCrudRepository(PrimaryKeyGenerator<ID> primaryKeyGenerator) {
63          if (primaryKeyGenerator == null) {
64              throw new IllegalArgumentException("primaryKeyGenerator must not be null");
65          }
66          this.primaryKeyGenerator = primaryKeyGenerator;
67      }
68  
69      /**
70       * Finds entities with a field matching a value.
71       *
72       * @param fieldName the name of the field to use when searching
73       * @param value     the value to search for
74       * @param <F>       the type of {@code value}
75       * @return a list of entities {@code e} such that {@code value.equals(e.fieldName)}
76       *
77       * @throws IllegalArgumentException if {@code fieldName} is {@code null}
78       */
79      protected <F> List<@NonNull T> findByField(String fieldName, F value) {
80          if (fieldName == null) {
81              throw new IllegalArgumentException("fieldName must not be null");
82          }
83          return entities.values().stream().filter(entity -> Objects.equals(value, TestUtil.getFieldValue(entity, fieldName))).toList();
84      }
85  
86      /**
87       * Finds a unique entity with a field matching a value.
88       *
89       * @param fieldName the name of the field to use when searching
90       * @param value     the value to search for
91       * @param <F>       the type of {@code value}
92       *
93       * @return the unique entity {@code e} such that {@code value.equals(e.fieldName)}, or {@code null} if no such entity is found
94       *
95       * @throws IllegalArgumentException if either argument is {@code null}, or if more than one entity with the given value is found
96       */
97      protected <F> Optional<T> findByUniqueField(String fieldName, F value) {
98          List<T> foundEntities = findByField(fieldName, value);
99          if (foundEntities.isEmpty()) {
100             return Optional.empty();
101         } else if (foundEntities.size() == 1) {
102             return Optional.of(foundEntities.get(0));
103         } else {
104             throw new IllegalArgumentException(
105                     "Field " + fieldName + " is not unique, found " + foundEntities.size() + " entities: " + foundEntities);
106         }
107     }
108 
109     /**
110      * {@inheritDoc}
111      */
112     @Override
113     public long count() {
114         return entities.size();
115     }
116 
117     /**
118      * {@inheritDoc}
119      */
120     @Override
121     public void deleteById(ID id) {
122         if (id == null) {
123             throw new IllegalArgumentException("id must not be null");
124         }
125         T removedEntity = entities.remove(id);
126         if (removedEntity == null) {
127             throw new EmptyResultDataAccessException("Entity with id " + id + " not found", 1);
128         }
129     }
130 
131     /**
132      * {@inheritDoc}
133      */
134     @Override
135     public void deleteAll(Iterable<? extends T> entitiesToDelete) {
136         if (entitiesToDelete == null) {
137             throw new IllegalArgumentException("entitiesToDelete must not be null");
138         }
139         for (T entity : entitiesToDelete) {
140             delete(entity);
141         }
142     }
143 
144     /**
145      * {@inheritDoc}
146      */
147     @Override
148     public void deleteAllById(Iterable<? extends ID> ids) {
149         for (ID id : ids) {
150             deleteById(id);
151         }
152     }
153 
154     /**
155      * {@inheritDoc}
156      */
157     @Override
158     public void delete(T entity) {
159         if (entity == null) {
160             throw new IllegalArgumentException("entity must not be null");
161         }
162         entities.remove(getId(entity));
163     }
164 
165     /**
166      * {@inheritDoc}
167      */
168     @Override
169     public void deleteAll() {
170         entities.clear();
171     }
172 
173     /**
174      * {@inheritDoc}
175      */
176     @Override
177     public boolean existsById(ID id) {
178         if (id == null) {
179             throw new IllegalArgumentException("id must not be null");
180         }
181         return entities.containsKey(id);
182     }
183 
184     /**
185      * {@inheritDoc}
186      */
187     @Override
188     public List<T> findAll() {
189         return new ArrayList<>(entities.values());
190     }
191 
192     /**
193      * {@inheritDoc}
194      */
195     @Override
196     public List<T> findAllById(Iterable<ID> ids) {
197         if (ids == null) {
198             throw new IllegalArgumentException("ids must not be null");
199         }
200         List<T> selectedEntities = new ArrayList<T>();
201         for (ID id : ids) {
202             Optional<T> optionalEntity = findById(id);
203             if (optionalEntity.isPresent()) {
204                 selectedEntities.add(optionalEntity.get());
205             }
206         }
207         return selectedEntities;
208     }
209 
210     /**
211      * {@inheritDoc}
212      */
213     @Override
214     public Optional<T> findById(ID id) {
215         if (id == null) {
216             throw new IllegalArgumentException("id must not be null");
217         }
218         T item = entities.get(id);
219         return Optional.ofNullable(item);
220     }
221 
222     /**
223      * {@inheritDoc}
224      */
225     @Override
226     public <S extends T> S save(S entity) {
227         if (entity == null) {
228             throw new IllegalArgumentException("entity must not be null");
229         }
230         ID id = getId(entity);
231         if (id == null) {
232             if (primaryKeyGenerator != null) {
233                 id = primaryKeyGenerator.nextPrimaryKey(maximumPrimaryKey());
234                 setId(entity, id);
235             } else {
236                 throw new IllegalStateException("Primary key is null and no primary key generator available: entity=" + entity);
237             }
238         }
239         entities.put(id, entity);
240         return entity;
241     }
242 
243     /**
244      * {@inheritDoc}
245      */
246     @Override
247     public <S extends T> List<S> saveAll(Iterable<S> entitiesToSave) {
248         if (entitiesToSave == null) {
249             throw new IllegalArgumentException("entitiesToSave must not be null");
250         }
251         List<S> savedEntities = new ArrayList<>();
252         for (S entity : entitiesToSave) {
253             savedEntities.add(save(entity));
254         }
255         return savedEntities;
256     }
257 
258     //
259     // PagingAndSortingRepository methods
260     //
261 
262     /**
263      * {@inheritDoc}
264      */
265     @Override
266     public List<T> findAll(Sort sort) {
267         return SortUtil.sort(findAll(), sort);
268     }
269 
270     /**
271      * {@inheritDoc}
272      */
273     @Override
274     public Page<T> findAll(Pageable pageable) {
275         List<T> allEntities = SortUtil.sort(findAll(), pageable.getSort());
276         int start = (int) pageable.getOffset();
277         int end = (start + pageable.getPageSize()) > allEntities.size() ? allEntities.size() : (start + pageable.getPageSize());
278         List<T> pagedEntities = start <= end ? allEntities.subList(start, end) : Collections.emptyList();
279         Page<T> page = new PageImpl<>(pagedEntities, pageable, allEntities.size());
280         return page;
281     }
282 
283     //
284     // Helper methods
285     //
286 
287     /**
288      * Gives the value of the largest primary key in an entity currently in the repository.
289      * <p>
290      * This method returns {@code null} if there are no entities in the repository. Since the save methods guarantee that
291      *
292      * @return the largest primary key in the repository, or {@code null} if the repository is empty
293      */
294     private @Nullable ID maximumPrimaryKey() {
295         ID max = null;
296         for (T entity : findAll()) {
297             ID id = getId(entity);
298             if (max == null || (id != null && id.compareTo(max) > 0)) {
299                 max = id;
300             }
301         }
302         return max;
303     }
304 
305     /**
306      * Gives the value of the ID field or method of the given entity.
307      *
308      * @param entity the entity to examine, should not be {@code null}
309      *
310      * @return the value of the ID field or method of {@code entity}, may be {@code null}
311      */
312     @SuppressWarnings("unchecked")
313     protected @Nullable ID getId(T entity) {
314         ID id = null;
315         try {
316             if (getCompositeIdClass(entity).isPresent()) {
317                 // TODO: Handle IdClass with method annotations
318                 List<Field> idFields = getIdFields(entity);
319                 id = createIdClassInstance(entity, idFields);
320             } else if (getIdField(entity).isPresent()) {
321                 id = (ID) getIdField(entity).get().get(entity);
322             } else if (getIdMethod(entity).isPresent()) {
323                 id = (ID) getIdMethod(entity).get().invoke(entity);
324             } else {
325                 throw new IllegalArgumentException("Entity has no @Id annotation: " + entity);
326             }
327         } catch (ReflectiveOperationException e) {
328             throw new IllegalStateException(e);
329         }
330         return id;
331     }
332 
333     /**
334      * Sets the value of the ID field, or calls the ID setter method, for the given entity.
335      *
336      * @param entity the entity for which to set the ID
337      * @param id     the new ID value
338      */
339     protected void setId(T entity, @Nullable ID id) {
340         try {
341             if (getCompositeIdClass(entity).isPresent()) {
342                 // TODO: Handle IdClass with method annotations
343                 List<Field> idFields = getIdFields(entity);
344                 if (id == null) {
345                     for (Field idField : idFields) {
346                         idField.set(entity, null);
347                     }
348                 } else {
349                     setIdFieldsFromIdClassInstance(entity, idFields, id);
350                 }
351             } else if (getIdField(entity).isPresent()) {
352                 getIdField(entity).get().set(entity, id);
353             } else if (getIdMethod(entity).isPresent()) {
354                 @SuppressWarnings("unchecked")
355                 Class<ID> idClass = (Class<ID>) getIdMethod(entity).get().getReturnType();
356                 Method setIdMethod = getSetMethod(entity, idClass);
357                 setIdMethod.invoke(entity, id);
358             }
359         } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
360             throw new IllegalStateException(e);
361         }
362     }
363 
364     private Optional<Field> getIdField(T entity) {
365         List<Field> idFields = getIdFields(entity);
366         if (idFields.isEmpty()) {
367             return Optional.empty();
368         } else if (idFields.size() > 1) {
369             throw new IllegalStateException("Multiptle ID fields found in entity: " + entity);
370         } else {
371             return Optional.of(idFields.get(0));
372         }
373     }
374 
375     private List<Field> getIdFields(T entity) {
376         List<Field> idFields = new ArrayList<>();
377         Class<?> c = entity.getClass();
378         while (c != null) {
379             for (Field field : c.getDeclaredFields()) {
380                 if (isIdField(field)) {
381                     field.setAccessible(true);
382                     idFields.add(field);
383                 }
384             }
385             c = c.getSuperclass();
386         }
387         return idFields;
388     }
389 
390     private Optional<Method> getIdMethod(T entity) {
391         List<Method> idMethods = getIdMethods(entity);
392         if (idMethods.isEmpty()) {
393             return Optional.empty();
394         } else if (idMethods.size() > 1) {
395             throw new IllegalStateException("Multiptle ID methods found in entity: " + entity);
396         } else {
397             return Optional.of(idMethods.get(0));
398         }
399     }
400 
401     private List<Method> getIdMethods(T entity) {
402         List<Method> idMethods = new ArrayList<>();
403         Class<?> c = entity.getClass();
404         while (c != null) {
405             for (Method method : c.getDeclaredMethods()) {
406                 if (isIdMethod(method)) {
407                     method.setAccessible(true);
408                     idMethods.add(method);
409                 }
410             }
411             c = c.getSuperclass();
412         }
413         return idMethods;
414     }
415 
416     private Method getSetMethod(T entity, Class<ID> idClass) throws NoSuchMethodException {
417         Method getMethod = getIdMethod(entity)
418                 .orElseThrow(() -> new NoSuchMethodException("Get method for ID not found: entity=" + entity));
419         String setMethodName = getMethod.getName().replaceFirst("^get", "set");
420         Method setMethod = entity.getClass().getMethod(setMethodName, idClass);
421         setMethod.setAccessible(true);
422         return setMethod;
423     }
424 
425     private @Nullable ID createIdClassInstance(T entity, List<Field> idFields) throws ReflectiveOperationException {
426         Class<ID> idClass = getCompositeIdClass(entity)
427                 .orElseThrow(() -> new IllegalStateException("Composibte ID class not found: entity=" + entity));
428         Constructor<ID> constructor = idClass.getDeclaredConstructor();
429         constructor.setAccessible(true);
430         ID id = constructor.newInstance();
431         for (Field idField : idFields) {
432             if (idField.get(entity) == null) {
433                 // If any of the ID fields is null, we say that the primary key is null.
434                 return null;
435             }
436             TestUtil.injectField(id, idField.getName(), idField.get(entity));
437         }
438         return id;
439     }
440 
441     private void setIdFieldsFromIdClassInstance(T entity, List<Field> idFields, ID id) {
442         for (Field idField : idFields) {
443             Object value = TestUtil.getFieldValue(id, idField.getName());
444             TestUtil.injectField(entity, idField.getName(), value);
445         }
446     }
447 
448     /**
449      * Override this in a concrete subclass to decide if a given field is an ID field of an entity.
450      *
451      * @param field the field to examine
452      *
453      * @return {@code true} if {@code field} is an ID field, {@code false} otherwise
454      */
455     protected abstract boolean isIdField(Field field);
456 
457     /**
458      * Override this in a concrete subclass to decide if a given method is a method giving the ID of an entity.
459      *
460      * @param method the method to examine
461      *
462      * @return {@code true} if {@code method} is an ID method, {@code false} otherwise
463      */
464     protected abstract boolean isIdMethod(Method method);
465 
466     /**
467      * Override this in concrete subclass to give the ID class representing a composite primary key for an entity, if any.
468      *
469      * @param entity the entity to examine
470      *
471      * @return the ID class representing the composite primary key of {@code entity}, or an empty optional if there is no such class
472      */
473     protected abstract Optional<Class<ID>> getCompositeIdClass(Object entity);
474 
475     @Override
476     public String toString() {
477         return getClass().getSimpleName() + "{entities=" + entities + "}";
478     }
479 
480     //
481     // QueryByExampleExecutor methods
482     //
483 
484     /**
485      * {@inheritDoc}
486      * <p>
487      * This method is not yet implemented, so it always throws an exception.
488      *
489      * @throws UnsupportedOperationException always
490      */
491     @Override
492     public <S extends T> Optional<S> findOne(Example<S> example) {
493         throw new UnsupportedOperationException("Not yet implemented");
494     }
495 
496     /**
497      * {@inheritDoc}
498      * <p>
499      * This method is not yet implemented, so it always throws an exception.
500      *
501      * @throws UnsupportedOperationException always
502      */
503     @Override
504     public <S extends T> List<S> findAll(Example<S> example) {
505         throw new UnsupportedOperationException("Not yet implemented");
506     }
507 
508     /**
509      * {@inheritDoc}
510      * <p>
511      * This method is not yet implemented, so it always throws an exception.
512      *
513      * @throws UnsupportedOperationException always
514      */
515     @Override
516     public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
517         throw new UnsupportedOperationException("Not yet implemented");
518     }
519 
520     /**
521      * {@inheritDoc}
522      * <p>
523      * This method is not yet implemented, so it always throws an exception.
524      *
525      * @throws UnsupportedOperationException always
526      */
527     @Override
528     public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
529         throw new UnsupportedOperationException("Not yet implemented");
530     }
531 
532     /**
533      * {@inheritDoc}
534      * <p>
535      * This method is not yet implemented, so it always throws an exception.
536      *
537      * @throws UnsupportedOperationException always
538      */
539     @Override
540     public <S extends T> long count(Example<S> example) {
541         throw new UnsupportedOperationException("Not yet implemented");
542     }
543 
544     /**
545      * {@inheritDoc}
546      * <p>
547      * This method is not yet implemented, so it always throws an exception.
548      *
549      * @throws UnsupportedOperationException always
550      */
551     @Override
552     public <S extends T> boolean exists(Example<S> example) {
553         throw new UnsupportedOperationException("Not yet implemented");
554     }
555 
556     /**
557      * Make finalize method final to avoid "Finalizer attacks" and corresponding SpotBugs warning (CT_CONSTRUCTOR_THROW).
558      *
559      * @see <a href="https://wiki.sei.cmu.edu/confluence/display/java/OBJ11-J.+Be+wary+of+letting+constructors+throw+exceptions">
560      *      Explanation of finalizer attack</a>
561      */
562     @Override
563     @SuppressWarnings({ "deprecation", "removal", "Finalize", "checkstyle:NoFinalizer", "PMD.EmptyFinalizer",
564             "PMD.EmptyMethodInAbstractClassShouldBeAbstract" })
565     protected final void finalize() throws Throwable {
566         // Do nothing
567     }
568 
569 }