1 package com.reallifedeveloper.tools.test.database.inmemory;
2
3 import java.lang.reflect.Constructor;
4 import java.lang.reflect.Field;
5 import java.lang.reflect.InvocationTargetException;
6 import java.lang.reflect.Method;
7 import java.util.ArrayList;
8 import java.util.Collections;
9 import java.util.HashMap;
10 import java.util.List;
11 import java.util.Map;
12 import java.util.Objects;
13 import java.util.Optional;
14
15 import org.checkerframework.checker.nullness.qual.NonNull;
16 import org.checkerframework.checker.nullness.qual.Nullable;
17 import org.springframework.dao.EmptyResultDataAccessException;
18 import org.springframework.data.domain.Example;
19 import org.springframework.data.domain.Page;
20 import org.springframework.data.domain.PageImpl;
21 import org.springframework.data.domain.Pageable;
22 import org.springframework.data.domain.Sort;
23 import org.springframework.data.repository.CrudRepository;
24 import org.springframework.data.repository.PagingAndSortingRepository;
25 import org.springframework.data.repository.query.QueryByExampleExecutor;
26
27 import com.reallifedeveloper.tools.test.TestUtil;
28
29
30
31
32
33
34
35
36
37
38
39
40 @SuppressWarnings({ "PMD", "checkstyle:noReturnNull" })
41 public abstract class AbstractInMemoryCrudRepository<T, ID extends Comparable<ID>>
42 implements CrudRepository<T, ID>, PagingAndSortingRepository<T, ID>, QueryByExampleExecutor<T> {
43
44 private final Map<@NonNull ID, @NonNull T> entities = new HashMap<>();
45
46 private final @Nullable PrimaryKeyGenerator<ID> primaryKeyGenerator;
47
48
49
50
51
52 public AbstractInMemoryCrudRepository() {
53 this.primaryKeyGenerator = null;
54 }
55
56
57
58
59
60
61
62 public AbstractInMemoryCrudRepository(PrimaryKeyGenerator<ID> primaryKeyGenerator) {
63 if (primaryKeyGenerator == null) {
64 throw new IllegalArgumentException("primaryKeyGenerator must not be null");
65 }
66 this.primaryKeyGenerator = primaryKeyGenerator;
67 }
68
69
70
71
72
73
74
75
76
77
78
79 protected <F> List<@NonNull T> findByField(String fieldName, F value) {
80 if (fieldName == null) {
81 throw new IllegalArgumentException("fieldName must not be null");
82 }
83 return entities.values().stream().filter(entity -> Objects.equals(value, TestUtil.getFieldValue(entity, fieldName))).toList();
84 }
85
86
87
88
89
90
91
92
93
94
95
96
97 protected <F> Optional<T> findByUniqueField(String fieldName, F value) {
98 List<T> foundEntities = findByField(fieldName, value);
99 if (foundEntities.isEmpty()) {
100 return Optional.empty();
101 } else if (foundEntities.size() == 1) {
102 return Optional.of(foundEntities.get(0));
103 } else {
104 throw new IllegalArgumentException(
105 "Field " + fieldName + " is not unique, found " + foundEntities.size() + " entities: " + foundEntities);
106 }
107 }
108
109
110
111
112 @Override
113 public long count() {
114 return entities.size();
115 }
116
117
118
119
120 @Override
121 public void deleteById(ID id) {
122 if (id == null) {
123 throw new IllegalArgumentException("id must not be null");
124 }
125 T removedEntity = entities.remove(id);
126 if (removedEntity == null) {
127 throw new EmptyResultDataAccessException("Entity with id " + id + " not found", 1);
128 }
129 }
130
131
132
133
134 @Override
135 public void deleteAll(Iterable<? extends T> entitiesToDelete) {
136 if (entitiesToDelete == null) {
137 throw new IllegalArgumentException("entitiesToDelete must not be null");
138 }
139 for (T entity : entitiesToDelete) {
140 delete(entity);
141 }
142 }
143
144
145
146
147 @Override
148 public void deleteAllById(Iterable<? extends ID> ids) {
149 for (ID id : ids) {
150 deleteById(id);
151 }
152 }
153
154
155
156
157 @Override
158 public void delete(T entity) {
159 if (entity == null) {
160 throw new IllegalArgumentException("entity must not be null");
161 }
162 entities.remove(getId(entity));
163 }
164
165
166
167
168 @Override
169 public void deleteAll() {
170 entities.clear();
171 }
172
173
174
175
176 @Override
177 public boolean existsById(ID id) {
178 if (id == null) {
179 throw new IllegalArgumentException("id must not be null");
180 }
181 return entities.containsKey(id);
182 }
183
184
185
186
187 @Override
188 public List<T> findAll() {
189 return new ArrayList<>(entities.values());
190 }
191
192
193
194
195 @Override
196 public List<T> findAllById(Iterable<ID> ids) {
197 if (ids == null) {
198 throw new IllegalArgumentException("ids must not be null");
199 }
200 List<T> selectedEntities = new ArrayList<T>();
201 for (ID id : ids) {
202 Optional<T> optionalEntity = findById(id);
203 if (optionalEntity.isPresent()) {
204 selectedEntities.add(optionalEntity.get());
205 }
206 }
207 return selectedEntities;
208 }
209
210
211
212
213 @Override
214 public Optional<T> findById(ID id) {
215 if (id == null) {
216 throw new IllegalArgumentException("id must not be null");
217 }
218 T item = entities.get(id);
219 return Optional.ofNullable(item);
220 }
221
222
223
224
225 @Override
226 public <S extends T> S save(S entity) {
227 if (entity == null) {
228 throw new IllegalArgumentException("entity must not be null");
229 }
230 ID id = getId(entity);
231 if (id == null) {
232 if (primaryKeyGenerator != null) {
233 id = primaryKeyGenerator.nextPrimaryKey(maximumPrimaryKey());
234 setId(entity, id);
235 } else {
236 throw new IllegalStateException("Primary key is null and no primary key generator available: entity=" + entity);
237 }
238 }
239 entities.put(id, entity);
240 return entity;
241 }
242
243
244
245
246 @Override
247 public <S extends T> List<S> saveAll(Iterable<S> entitiesToSave) {
248 if (entitiesToSave == null) {
249 throw new IllegalArgumentException("entitiesToSave must not be null");
250 }
251 List<S> savedEntities = new ArrayList<>();
252 for (S entity : entitiesToSave) {
253 savedEntities.add(save(entity));
254 }
255 return savedEntities;
256 }
257
258
259
260
261
262
263
264
265 @Override
266 public List<T> findAll(Sort sort) {
267 return SortUtil.sort(findAll(), sort);
268 }
269
270
271
272
273 @Override
274 public Page<T> findAll(Pageable pageable) {
275 List<T> allEntities = SortUtil.sort(findAll(), pageable.getSort());
276 int start = (int) pageable.getOffset();
277 int end = (start + pageable.getPageSize()) > allEntities.size() ? allEntities.size() : (start + pageable.getPageSize());
278 List<T> pagedEntities = start <= end ? allEntities.subList(start, end) : Collections.emptyList();
279 Page<T> page = new PageImpl<>(pagedEntities, pageable, allEntities.size());
280 return page;
281 }
282
283
284
285
286
287
288
289
290
291
292
293
294 private @Nullable ID maximumPrimaryKey() {
295 ID max = null;
296 for (T entity : findAll()) {
297 ID id = getId(entity);
298 if (max == null || (id != null && id.compareTo(max) > 0)) {
299 max = id;
300 }
301 }
302 return max;
303 }
304
305
306
307
308
309
310
311
312 @SuppressWarnings("unchecked")
313 protected @Nullable ID getId(T entity) {
314 ID id = null;
315 try {
316 if (getCompositeIdClass(entity).isPresent()) {
317
318 List<Field> idFields = getIdFields(entity);
319 id = createIdClassInstance(entity, idFields);
320 } else if (getIdField(entity).isPresent()) {
321 id = (ID) getIdField(entity).get().get(entity);
322 } else if (getIdMethod(entity).isPresent()) {
323 id = (ID) getIdMethod(entity).get().invoke(entity);
324 } else {
325 throw new IllegalArgumentException("Entity has no @Id annotation: " + entity);
326 }
327 } catch (ReflectiveOperationException e) {
328 throw new IllegalStateException(e);
329 }
330 return id;
331 }
332
333
334
335
336
337
338
339 protected void setId(T entity, @Nullable ID id) {
340 try {
341 if (getCompositeIdClass(entity).isPresent()) {
342
343 List<Field> idFields = getIdFields(entity);
344 if (id == null) {
345 for (Field idField : idFields) {
346 idField.set(entity, null);
347 }
348 } else {
349 setIdFieldsFromIdClassInstance(entity, idFields, id);
350 }
351 } else if (getIdField(entity).isPresent()) {
352 getIdField(entity).get().set(entity, id);
353 } else if (getIdMethod(entity).isPresent()) {
354 @SuppressWarnings("unchecked")
355 Class<ID> idClass = (Class<ID>) getIdMethod(entity).get().getReturnType();
356 Method setIdMethod = getSetMethod(entity, idClass);
357 setIdMethod.invoke(entity, id);
358 }
359 } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
360 throw new IllegalStateException(e);
361 }
362 }
363
364 private Optional<Field> getIdField(T entity) {
365 List<Field> idFields = getIdFields(entity);
366 if (idFields.isEmpty()) {
367 return Optional.empty();
368 } else if (idFields.size() > 1) {
369 throw new IllegalStateException("Multiptle ID fields found in entity: " + entity);
370 } else {
371 return Optional.of(idFields.get(0));
372 }
373 }
374
375 private List<Field> getIdFields(T entity) {
376 List<Field> idFields = new ArrayList<>();
377 Class<?> c = entity.getClass();
378 while (c != null) {
379 for (Field field : c.getDeclaredFields()) {
380 if (isIdField(field)) {
381 field.setAccessible(true);
382 idFields.add(field);
383 }
384 }
385 c = c.getSuperclass();
386 }
387 return idFields;
388 }
389
390 private Optional<Method> getIdMethod(T entity) {
391 List<Method> idMethods = getIdMethods(entity);
392 if (idMethods.isEmpty()) {
393 return Optional.empty();
394 } else if (idMethods.size() > 1) {
395 throw new IllegalStateException("Multiptle ID methods found in entity: " + entity);
396 } else {
397 return Optional.of(idMethods.get(0));
398 }
399 }
400
401 private List<Method> getIdMethods(T entity) {
402 List<Method> idMethods = new ArrayList<>();
403 Class<?> c = entity.getClass();
404 while (c != null) {
405 for (Method method : c.getDeclaredMethods()) {
406 if (isIdMethod(method)) {
407 method.setAccessible(true);
408 idMethods.add(method);
409 }
410 }
411 c = c.getSuperclass();
412 }
413 return idMethods;
414 }
415
416 private Method getSetMethod(T entity, Class<ID> idClass) throws NoSuchMethodException {
417 Method getMethod = getIdMethod(entity)
418 .orElseThrow(() -> new NoSuchMethodException("Get method for ID not found: entity=" + entity));
419 String setMethodName = getMethod.getName().replaceFirst("^get", "set");
420 Method setMethod = entity.getClass().getMethod(setMethodName, idClass);
421 setMethod.setAccessible(true);
422 return setMethod;
423 }
424
425 private @Nullable ID createIdClassInstance(T entity, List<Field> idFields) throws ReflectiveOperationException {
426 Class<ID> idClass = getCompositeIdClass(entity)
427 .orElseThrow(() -> new IllegalStateException("Composibte ID class not found: entity=" + entity));
428 Constructor<ID> constructor = idClass.getDeclaredConstructor();
429 constructor.setAccessible(true);
430 ID id = constructor.newInstance();
431 for (Field idField : idFields) {
432 if (idField.get(entity) == null) {
433
434 return null;
435 }
436 TestUtil.injectField(id, idField.getName(), idField.get(entity));
437 }
438 return id;
439 }
440
441 private void setIdFieldsFromIdClassInstance(T entity, List<Field> idFields, ID id) {
442 for (Field idField : idFields) {
443 Object value = TestUtil.getFieldValue(id, idField.getName());
444 TestUtil.injectField(entity, idField.getName(), value);
445 }
446 }
447
448
449
450
451
452
453
454
455 protected abstract boolean isIdField(Field field);
456
457
458
459
460
461
462
463
464 protected abstract boolean isIdMethod(Method method);
465
466
467
468
469
470
471
472
473 protected abstract Optional<Class<ID>> getCompositeIdClass(Object entity);
474
475 @Override
476 public String toString() {
477 return getClass().getSimpleName() + "{entities=" + entities + "}";
478 }
479
480
481
482
483
484
485
486
487
488
489
490
491 @Override
492 public <S extends T> Optional<S> findOne(Example<S> example) {
493 throw new UnsupportedOperationException("Not yet implemented");
494 }
495
496
497
498
499
500
501
502
503 @Override
504 public <S extends T> List<S> findAll(Example<S> example) {
505 throw new UnsupportedOperationException("Not yet implemented");
506 }
507
508
509
510
511
512
513
514
515 @Override
516 public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
517 throw new UnsupportedOperationException("Not yet implemented");
518 }
519
520
521
522
523
524
525
526
527 @Override
528 public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
529 throw new UnsupportedOperationException("Not yet implemented");
530 }
531
532
533
534
535
536
537
538
539 @Override
540 public <S extends T> long count(Example<S> example) {
541 throw new UnsupportedOperationException("Not yet implemented");
542 }
543
544
545
546
547
548
549
550
551 @Override
552 public <S extends T> boolean exists(Example<S> example) {
553 throw new UnsupportedOperationException("Not yet implemented");
554 }
555
556
557
558
559
560
561
562 @Override
563 @SuppressWarnings({ "deprecation", "removal", "Finalize", "checkstyle:NoFinalizer", "PMD.EmptyFinalizer",
564 "PMD.EmptyMethodInAbstractClassShouldBeAbstract" })
565 protected final void finalize() throws Throwable {
566
567 }
568
569 }