View Javadoc
1   package com.reallifedeveloper.tools.test.database.inmemory;
2   
3   import java.lang.reflect.Constructor;
4   import java.lang.reflect.Field;
5   import java.lang.reflect.InvocationTargetException;
6   import java.lang.reflect.Method;
7   import java.util.ArrayList;
8   import java.util.Collections;
9   import java.util.HashMap;
10  import java.util.List;
11  import java.util.Map;
12  import java.util.Objects;
13  import java.util.Optional;
14  
15  import org.checkerframework.checker.nullness.qual.NonNull;
16  import org.checkerframework.checker.nullness.qual.Nullable;
17  import org.springframework.dao.EmptyResultDataAccessException;
18  import org.springframework.data.domain.Example;
19  import org.springframework.data.domain.Page;
20  import org.springframework.data.domain.PageImpl;
21  import org.springframework.data.domain.Pageable;
22  import org.springframework.data.domain.Sort;
23  import org.springframework.data.repository.CrudRepository;
24  import org.springframework.data.repository.PagingAndSortingRepository;
25  import org.springframework.data.repository.query.QueryByExampleExecutor;
26  
27  import com.reallifedeveloper.tools.test.TestUtil;
28  
29  /**
30   * An abstract helper class that implements the {@link CrudRepository} interface using an in-memory map instead of a database.
31   * <p>
32   * Contains useful methods for sub-classes implementing in-memory versions of repositories, such as {@link #findByField(String, Object)} and
33   * {@link #findByUniqueField(String, Object)}.
34   *
35   * @param <T>  the type of the entities handled by this repository
36   * @param <ID> the type of the entities' primary keys
37   *
38   * @author RealLifeDeveloper
39   */
40  @SuppressWarnings({ "PMD", "checkstyle:noReturnNull" }) // TODO: Consider refactoring this class using the hints from PMD
41  public abstract class AbstractInMemoryCrudRepository<T, ID extends Comparable<ID>>
42          implements CrudRepository<T, ID>, PagingAndSortingRepository<T, ID>, QueryByExampleExecutor<T> {
43  
44      private final Map<@NonNull ID, @NonNull T> entities = new HashMap<>();
45  
46      private final @Nullable PrimaryKeyGenerator<ID> primaryKeyGenerator;
47  
48      /**
49       * Creates a new {@code InMemoryCrudRepository} with no primary key generator. If an entity with a {@code null} primary key is saved, an
50       * exception is thrown.
51       */
52      public AbstractInMemoryCrudRepository() {
53          this.primaryKeyGenerator = null;
54      }
55  
56      /**
57       * Creates a new {@code InMemoryCrudRepository} with the provided primary key generator. If an entity with a {@code null} primary key is
58       * saved, the generator is used to create a new primary key that is stored in the entity before saving.
59       *
60       * @param primaryKeyGenerator the primary key generator to use, must not be {@code null}
61       */
62      public AbstractInMemoryCrudRepository(PrimaryKeyGenerator<ID> primaryKeyGenerator) {
63          if (primaryKeyGenerator == null) {
64              throw new IllegalArgumentException("primaryKeyGenerator must not be null");
65          }
66          this.primaryKeyGenerator = primaryKeyGenerator;
67      }
68  
69      /**
70       * Finds entities with a field matching a value.
71       *
72       * @param fieldName the name of the field to use when searching
73       * @param value     the value to search for
74       * @param <F>       the type of {@code value}
75       * @return a list of entities {@code e} such that {@code value.equals(e.fieldName)}
76       *
77       * @throws IllegalArgumentException if {@code fieldName} is {@code null}
78       */
79      protected <F> List<@NonNull T> findByField(String fieldName, F value) {
80          if (fieldName == null) {
81              throw new IllegalArgumentException("fieldName must not be null");
82          }
83          return entities.values().stream().filter(entity -> Objects.equals(value, TestUtil.getFieldValue(entity, fieldName))).toList();
84      }
85  
86      /**
87       * Finds a unique entity with a field matching a value.
88       *
89       * @param fieldName the name of the field to use when searching
90       * @param value     the value to search for
91       * @param <F>       the type of {@code value}
92       *
93       * @return the unique entity {@code e} such that {@code value.equals(e.fieldName)}, or {@code null} if no such entity is found
94       *
95       * @throws IllegalArgumentException if either argument is {@code null}, or if more than one entity with the given value is found
96       */
97      protected <F> Optional<T> findByUniqueField(String fieldName, F value) {
98          List<T> foundEntities = findByField(fieldName, value);
99          if (foundEntities.isEmpty()) {
100             return Optional.empty();
101         } else if (foundEntities.size() == 1) {
102             return Optional.of(foundEntities.get(0));
103         } else {
104             throw new IllegalArgumentException(
105                     "Field " + fieldName + " is not unique, found " + foundEntities.size() + " entities: " + foundEntities);
106         }
107     }
108 
109     /**
110      * {@inheritDoc}
111      */
112     @Override
113     public long count() {
114         return entities.size();
115     }
116 
117     /**
118      * {@inheritDoc}
119      */
120     @Override
121     public void deleteById(ID id) {
122         if (id == null) {
123             throw new IllegalArgumentException("id must not be null");
124         }
125         T removedEntity = entities.remove(id);
126         if (removedEntity == null) {
127             throw new EmptyResultDataAccessException("Entity with id " + id + " not found", 1);
128         }
129     }
130 
131     /**
132      * {@inheritDoc}
133      */
134     @Override
135     public void deleteAll(Iterable<? extends T> entitiesToDelete) {
136         if (entitiesToDelete == null) {
137             throw new IllegalArgumentException("entitiesToDelete must not be null");
138         }
139         for (T entity : entitiesToDelete) {
140             delete(entity);
141         }
142     }
143 
144     /**
145      * {@inheritDoc}
146      */
147     @Override
148     public void deleteAllById(Iterable<? extends ID> ids) {
149         for (ID id : ids) {
150             deleteById(id);
151         }
152     }
153 
154     /**
155      * {@inheritDoc}
156      */
157     @Override
158     public void delete(T entity) {
159         if (entity == null) {
160             throw new IllegalArgumentException("entity must not be null");
161         }
162         entities.remove(getId(entity));
163     }
164 
165     /**
166      * {@inheritDoc}
167      */
168     @Override
169     public void deleteAll() {
170         entities.clear();
171     }
172 
173     /**
174      * {@inheritDoc}
175      */
176     @Override
177     public boolean existsById(ID id) {
178         if (id == null) {
179             throw new IllegalArgumentException("id must not be null");
180         }
181         return entities.containsKey(id);
182     }
183 
184     /**
185      * {@inheritDoc}
186      */
187     @Override
188     public List<T> findAll() {
189         return new ArrayList<>(entities.values());
190     }
191 
192     /**
193      * {@inheritDoc}
194      */
195     @Override
196     public List<T> findAllById(Iterable<ID> ids) {
197         if (ids == null) {
198             throw new IllegalArgumentException("ids must not be null");
199         }
200         List<T> selectedEntities = new ArrayList<T>();
201         for (ID id : ids) {
202             Optional<T> optionalEntity = findById(id);
203             if (optionalEntity.isPresent()) {
204                 selectedEntities.add(optionalEntity.get());
205             }
206         }
207         return selectedEntities;
208     }
209 
210     /**
211      * {@inheritDoc}
212      */
213     @Override
214     public Optional<T> findById(ID id) {
215         if (id == null) {
216             throw new IllegalArgumentException("id must not be null");
217         }
218         T item = entities.get(id);
219         return Optional.ofNullable(item);
220     }
221 
222     /**
223      * {@inheritDoc}
224      */
225     @Override
226     public <S extends T> S save(S entity) {
227         if (entity == null) {
228             throw new IllegalArgumentException("entity must not be null");
229         }
230         ID id = getId(entity);
231         if (id == null) {
232             if (primaryKeyGenerator != null) {
233                 id = primaryKeyGenerator.nextPrimaryKey(maximumPrimaryKey());
234                 setId(entity, id);
235             } else {
236                 throw new IllegalStateException("Primary key is null and no primary key generator available: entity=" + entity);
237             }
238         }
239         entities.put(id, entity);
240         return entity;
241     }
242 
243     /**
244      * {@inheritDoc}
245      */
246     @Override
247     public <S extends T> List<S> saveAll(Iterable<S> entitiesToSave) {
248         if (entitiesToSave == null) {
249             throw new IllegalArgumentException("entitiesToSave must not be null");
250         }
251         List<S> savedEntities = new ArrayList<>();
252         for (S entity : entitiesToSave) {
253             savedEntities.add(save(entity));
254         }
255         return savedEntities;
256     }
257 
258     //
259     // PagingAndSortingRepository methods
260     //
261 
262     /**
263      * {@inheritDoc}
264      */
265     @Override
266     public List<T> findAll(Sort sort) {
267         return SortUtil.sort(findAll(), sort);
268     }
269 
270     /**
271      * {@inheritDoc}
272      */
273     @Override
274     public Page<T> findAll(Pageable pageable) {
275         List<T> allEntities = SortUtil.sort(findAll(), pageable.getSort());
276         int start = (int) pageable.getOffset();
277         int end = (start + pageable.getPageSize()) > allEntities.size() ? allEntities.size() : (start + pageable.getPageSize());
278         List<T> pagedEntities = start <= end ? allEntities.subList(start, end) : Collections.emptyList();
279         Page<T> page = new PageImpl<>(pagedEntities, pageable, allEntities.size());
280         return page;
281     }
282 
283     //
284     // Helper methods
285     //
286 
287     private ID maximumPrimaryKey() {
288         ID max = null;
289         for (T entity : findAll()) {
290             ID id = getId(entity);
291             if (max == null || id.compareTo(max) > 0) {
292                 max = id;
293             }
294         }
295         return max;
296     }
297 
298     /**
299      * Gives the value of the ID field or method of the given entity.
300      *
301      * @param entity the entity to examine, should not be {@code null}
302      *
303      * @return the value of the ID field or method of {@code entity}, may be {@code null}
304      */
305     @SuppressWarnings("unchecked")
306     protected @Nullable ID getId(T entity) {
307         ID id = null;
308         try {
309             if (getIdClass(entity).isPresent()) {
310                 // TODO: Handle IdClass with method annotations
311                 List<Field> idFields = getIdFields(entity);
312                 id = createIdClassInstance(entity, idFields);
313             } else if (getIdField(entity).isPresent()) {
314                 id = (ID) getIdField(entity).get().get(entity);
315             } else if (getIdMethod(entity).isPresent()) {
316                 id = (ID) getIdMethod(entity).get().invoke(entity);
317             } else {
318                 throw new IllegalArgumentException("Entity has no @Id annotation: " + entity);
319             }
320         } catch (ReflectiveOperationException e) {
321             throw new IllegalStateException(e);
322         }
323         return id;
324     }
325 
326     /**
327      * Sets the value of the ID field, or calls the ID setter method, for the given entity.
328      *
329      * @param entity the entity for which to set the ID
330      * @param id     the new ID value
331      */
332     protected void setId(T entity, ID id) {
333         try {
334             if (getIdClass(entity).isPresent()) {
335                 // TODO: Handle IdClass with method annotations
336                 List<Field> idFields = getIdFields(entity);
337                 setIdFieldsFromIdClassInstance(entity, idFields, id);
338             } else if (getIdField(entity).isPresent()) {
339                 getIdField(entity).get().set(entity, id);
340             } else if (getIdMethod(entity).isPresent()) {
341                 Method setMethod = getSetMethod(entity, id);
342                 setMethod.invoke(entity, id);
343             }
344         } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
345             throw new IllegalStateException(e);
346         }
347     }
348 
349     private Optional<Field> getIdField(T entity) {
350         List<Field> idFields = getIdFields(entity);
351         if (idFields.isEmpty()) {
352             return Optional.empty();
353         } else if (idFields.size() > 1) {
354             throw new IllegalStateException("Multiptle ID fields found in entity: " + entity);
355         } else {
356             return Optional.of(idFields.get(0));
357         }
358     }
359 
360     private List<Field> getIdFields(T entity) {
361         List<Field> idFields = new ArrayList<>();
362         Class<?> c = entity.getClass();
363         while (c != null) {
364             for (Field field : c.getDeclaredFields()) {
365                 if (isIdField(field)) {
366                     field.setAccessible(true);
367                     idFields.add(field);
368                 }
369             }
370             c = c.getSuperclass();
371         }
372         return idFields;
373     }
374 
375     private Optional<Method> getIdMethod(T entity) {
376         List<Method> idMethods = getIdMethods(entity);
377         if (idMethods.isEmpty()) {
378             return Optional.empty();
379         } else if (idMethods.size() > 1) {
380             throw new IllegalStateException("Multiptle ID methods found in entity: " + entity);
381         } else {
382             return Optional.of(idMethods.get(0));
383         }
384     }
385 
386     private List<Method> getIdMethods(T entity) {
387         List<Method> idMethods = new ArrayList<>();
388         Class<?> c = entity.getClass();
389         while (c != null) {
390             for (Method method : c.getDeclaredMethods()) {
391                 if (isIdMethod(method)) {
392                     method.setAccessible(true);
393                     idMethods.add(method);
394                 }
395             }
396             c = c.getSuperclass();
397         }
398         return idMethods;
399     }
400 
401     private Method getSetMethod(T entity, ID id) throws NoSuchMethodException {
402         Method getMethod = getIdMethod(entity)
403                 .orElseThrow(() -> new NoSuchMethodException("Get method for ID not found: entity=" + entity));
404         String setMethodName = getMethod.getName().replaceFirst("^get", "set");
405         Method setMethod = entity.getClass().getMethod(setMethodName, id.getClass());
406         setMethod.setAccessible(true);
407         return setMethod;
408     }
409 
410     private @Nullable ID createIdClassInstance(T entity, List<Field> idFields) throws ReflectiveOperationException {
411         Class<ID> idClass = getIdClass(entity).orElseThrow(() -> new IllegalStateException("ID class not found: entity=" + entity));
412         Constructor<ID> constructor = idClass.getDeclaredConstructor();
413         constructor.setAccessible(true);
414         ID id = constructor.newInstance();
415         for (Field idField : idFields) {
416             if (idField.get(entity) == null) {
417                 // If any of the ID fields is null, we say that the primary key is null.
418                 return null;
419             }
420             TestUtil.injectField(id, idField.getName(), idField.get(entity));
421         }
422         return id;
423     }
424 
425     private void setIdFieldsFromIdClassInstance(T entity, List<Field> idFields, ID id) {
426         for (Field idField : idFields) {
427             Object value = TestUtil.getFieldValue(id, idField.getName());
428             TestUtil.injectField(entity, idField.getName(), value);
429         }
430     }
431 
432     /**
433      * Override this in a concrete subclass to decide if a given field is an ID field of an entity.
434      *
435      * @param field the field to examine
436      *
437      * @return {@code true} if {@code field} is an ID field, {@code false} otherwise
438      */
439     protected abstract boolean isIdField(Field field);
440 
441     /**
442      * Override this in a concrete subclass to decide if a given method is a method giving the ID of an entity.
443      *
444      * @param method the method to examine
445      *
446      * @return {@code true} if {@code method} is an ID method, {@code false} otherwise
447      */
448     protected abstract boolean isIdMethod(Method method);
449 
450     /**
451      * Override this in concrete subclass to give the ID class representing a composite primary key for an entity, if any.
452      *
453      * @param entity the entity to examine
454      *
455      * @return the ID class representing the composite primary key of {@code entity}, or an empty optional if there is no such class
456      */
457     protected abstract Optional<Class<ID>> getIdClass(Object entity);
458 
459     @Override
460     public String toString() {
461         return getClass().getSimpleName() + "{entities=" + entities + "}";
462     }
463 
464     //
465     // QueryByExampleExecutor methods
466     //
467 
468     /**
469      * {@inheritDoc}
470      * <p>
471      * This method is not yet implemented, so it always throws an exception.
472      *
473      * @throws UnsupportedOperationException always
474      */
475     @Override
476     public <S extends T> Optional<S> findOne(Example<S> example) {
477         throw new UnsupportedOperationException("Not yet implemented");
478     }
479 
480     /**
481      * {@inheritDoc}
482      * <p>
483      * This method is not yet implemented, so it always throws an exception.
484      *
485      * @throws UnsupportedOperationException always
486      */
487     @Override
488     public <S extends T> List<S> findAll(Example<S> example) {
489         throw new UnsupportedOperationException("Not yet implemented");
490     }
491 
492     /**
493      * {@inheritDoc}
494      * <p>
495      * This method is not yet implemented, so it always throws an exception.
496      *
497      * @throws UnsupportedOperationException always
498      */
499     @Override
500     public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
501         throw new UnsupportedOperationException("Not yet implemented");
502     }
503 
504     /**
505      * {@inheritDoc}
506      * <p>
507      * This method is not yet implemented, so it always throws an exception.
508      *
509      * @throws UnsupportedOperationException always
510      */
511     @Override
512     public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
513         throw new UnsupportedOperationException("Not yet implemented");
514     }
515 
516     /**
517      * {@inheritDoc}
518      * <p>
519      * This method is not yet implemented, so it always throws an exception.
520      *
521      * @throws UnsupportedOperationException always
522      */
523     @Override
524     public <S extends T> long count(Example<S> example) {
525         throw new UnsupportedOperationException("Not yet implemented");
526     }
527 
528     /**
529      * {@inheritDoc}
530      * <p>
531      * This method is not yet implemented, so it always throws an exception.
532      *
533      * @throws UnsupportedOperationException always
534      */
535     @Override
536     public <S extends T> boolean exists(Example<S> example) {
537         throw new UnsupportedOperationException("Not yet implemented");
538     }
539 
540     /**
541      * Make finalize method final to avoid "Finalizer attacks" and corresponding SpotBugs warning (CT_CONSTRUCTOR_THROW).
542      *
543      * @see <a href="https://wiki.sei.cmu.edu/confluence/display/java/OBJ11-J.+Be+wary+of+letting+constructors+throw+exceptions">
544      *      Explanation of finalizer attack</a>
545      */
546     @Override
547     @SuppressWarnings({ "checkstyle:NoFinalizer", "PMD.EmptyFinalizer", "PMD.EmptyMethodInAbstractClassShouldBeAbstract" })
548     protected final void finalize() throws Throwable {
549         // Do nothing
550     }
551 
552 }