1 package com.reallifedeveloper.tools.test.database.inmemory;
2
3 import java.lang.reflect.Constructor;
4 import java.lang.reflect.Field;
5 import java.lang.reflect.InvocationTargetException;
6 import java.lang.reflect.Method;
7 import java.util.ArrayList;
8 import java.util.Collections;
9 import java.util.HashMap;
10 import java.util.List;
11 import java.util.Map;
12 import java.util.Objects;
13 import java.util.Optional;
14
15 import org.checkerframework.checker.nullness.qual.NonNull;
16 import org.checkerframework.checker.nullness.qual.Nullable;
17 import org.springframework.dao.EmptyResultDataAccessException;
18 import org.springframework.data.domain.Example;
19 import org.springframework.data.domain.Page;
20 import org.springframework.data.domain.PageImpl;
21 import org.springframework.data.domain.Pageable;
22 import org.springframework.data.domain.Sort;
23 import org.springframework.data.repository.CrudRepository;
24 import org.springframework.data.repository.PagingAndSortingRepository;
25 import org.springframework.data.repository.query.QueryByExampleExecutor;
26
27 import com.reallifedeveloper.tools.test.TestUtil;
28
29
30
31
32
33
34
35
36
37
38
39
40 @SuppressWarnings({ "PMD", "checkstyle:noReturnNull" })
41 public abstract class AbstractInMemoryCrudRepository<T, ID extends Comparable<ID>>
42 implements CrudRepository<T, ID>, PagingAndSortingRepository<T, ID>, QueryByExampleExecutor<T> {
43
44 private final Map<@NonNull ID, @NonNull T> entities = new HashMap<>();
45
46 private final @Nullable PrimaryKeyGenerator<ID> primaryKeyGenerator;
47
48
49
50
51
52 public AbstractInMemoryCrudRepository() {
53 this.primaryKeyGenerator = null;
54 }
55
56
57
58
59
60
61
62 public AbstractInMemoryCrudRepository(PrimaryKeyGenerator<ID> primaryKeyGenerator) {
63 if (primaryKeyGenerator == null) {
64 throw new IllegalArgumentException("primaryKeyGenerator must not be null");
65 }
66 this.primaryKeyGenerator = primaryKeyGenerator;
67 }
68
69
70
71
72
73
74
75
76
77
78
79 protected <F> List<@NonNull T> findByField(String fieldName, F value) {
80 if (fieldName == null) {
81 throw new IllegalArgumentException("fieldName must not be null");
82 }
83 return entities.values().stream().filter(entity -> Objects.equals(value, TestUtil.getFieldValue(entity, fieldName))).toList();
84 }
85
86
87
88
89
90
91
92
93
94
95
96
97 protected <F> Optional<T> findByUniqueField(String fieldName, F value) {
98 List<T> foundEntities = findByField(fieldName, value);
99 if (foundEntities.isEmpty()) {
100 return Optional.empty();
101 } else if (foundEntities.size() == 1) {
102 return Optional.of(foundEntities.get(0));
103 } else {
104 throw new IllegalArgumentException(
105 "Field " + fieldName + " is not unique, found " + foundEntities.size() + " entities: " + foundEntities);
106 }
107 }
108
109
110
111
112 @Override
113 public long count() {
114 return entities.size();
115 }
116
117
118
119
120 @Override
121 public void deleteById(ID id) {
122 if (id == null) {
123 throw new IllegalArgumentException("id must not be null");
124 }
125 T removedEntity = entities.remove(id);
126 if (removedEntity == null) {
127 throw new EmptyResultDataAccessException("Entity with id " + id + " not found", 1);
128 }
129 }
130
131
132
133
134 @Override
135 public void deleteAll(Iterable<? extends T> entitiesToDelete) {
136 if (entitiesToDelete == null) {
137 throw new IllegalArgumentException("entitiesToDelete must not be null");
138 }
139 for (T entity : entitiesToDelete) {
140 delete(entity);
141 }
142 }
143
144
145
146
147 @Override
148 public void deleteAllById(Iterable<? extends ID> ids) {
149 for (ID id : ids) {
150 deleteById(id);
151 }
152 }
153
154
155
156
157 @Override
158 public void delete(T entity) {
159 if (entity == null) {
160 throw new IllegalArgumentException("entity must not be null");
161 }
162 entities.remove(getId(entity));
163 }
164
165
166
167
168 @Override
169 public void deleteAll() {
170 entities.clear();
171 }
172
173
174
175
176 @Override
177 public boolean existsById(ID id) {
178 if (id == null) {
179 throw new IllegalArgumentException("id must not be null");
180 }
181 return entities.containsKey(id);
182 }
183
184
185
186
187 @Override
188 public List<T> findAll() {
189 return new ArrayList<>(entities.values());
190 }
191
192
193
194
195 @Override
196 public List<T> findAllById(Iterable<ID> ids) {
197 if (ids == null) {
198 throw new IllegalArgumentException("ids must not be null");
199 }
200 List<T> selectedEntities = new ArrayList<T>();
201 for (ID id : ids) {
202 Optional<T> optionalEntity = findById(id);
203 if (optionalEntity.isPresent()) {
204 selectedEntities.add(optionalEntity.get());
205 }
206 }
207 return selectedEntities;
208 }
209
210
211
212
213 @Override
214 public Optional<T> findById(ID id) {
215 if (id == null) {
216 throw new IllegalArgumentException("id must not be null");
217 }
218 T item = entities.get(id);
219 return Optional.ofNullable(item);
220 }
221
222
223
224
225 @Override
226 public <S extends T> S save(S entity) {
227 if (entity == null) {
228 throw new IllegalArgumentException("entity must not be null");
229 }
230 ID id = getId(entity);
231 if (id == null) {
232 if (primaryKeyGenerator != null) {
233 id = primaryKeyGenerator.nextPrimaryKey(maximumPrimaryKey());
234 setId(entity, id);
235 } else {
236 throw new IllegalStateException("Primary key is null and no primary key generator available: entity=" + entity);
237 }
238 }
239 entities.put(id, entity);
240 return entity;
241 }
242
243
244
245
246 @Override
247 public <S extends T> List<S> saveAll(Iterable<S> entitiesToSave) {
248 if (entitiesToSave == null) {
249 throw new IllegalArgumentException("entitiesToSave must not be null");
250 }
251 List<S> savedEntities = new ArrayList<>();
252 for (S entity : entitiesToSave) {
253 savedEntities.add(save(entity));
254 }
255 return savedEntities;
256 }
257
258
259
260
261
262
263
264
265 @Override
266 public List<T> findAll(Sort sort) {
267 return SortUtil.sort(findAll(), sort);
268 }
269
270
271
272
273 @Override
274 public Page<T> findAll(Pageable pageable) {
275 List<T> allEntities = SortUtil.sort(findAll(), pageable.getSort());
276 int start = (int) pageable.getOffset();
277 int end = (start + pageable.getPageSize()) > allEntities.size() ? allEntities.size() : (start + pageable.getPageSize());
278 List<T> pagedEntities = start <= end ? allEntities.subList(start, end) : Collections.emptyList();
279 Page<T> page = new PageImpl<>(pagedEntities, pageable, allEntities.size());
280 return page;
281 }
282
283
284
285
286
287 private ID maximumPrimaryKey() {
288 ID max = null;
289 for (T entity : findAll()) {
290 ID id = getId(entity);
291 if (max == null || id.compareTo(max) > 0) {
292 max = id;
293 }
294 }
295 return max;
296 }
297
298
299
300
301
302
303
304
305 @SuppressWarnings("unchecked")
306 protected @Nullable ID getId(T entity) {
307 ID id = null;
308 try {
309 if (getIdClass(entity).isPresent()) {
310
311 List<Field> idFields = getIdFields(entity);
312 id = createIdClassInstance(entity, idFields);
313 } else if (getIdField(entity).isPresent()) {
314 id = (ID) getIdField(entity).get().get(entity);
315 } else if (getIdMethod(entity).isPresent()) {
316 id = (ID) getIdMethod(entity).get().invoke(entity);
317 } else {
318 throw new IllegalArgumentException("Entity has no @Id annotation: " + entity);
319 }
320 } catch (ReflectiveOperationException e) {
321 throw new IllegalStateException(e);
322 }
323 return id;
324 }
325
326
327
328
329
330
331
332 protected void setId(T entity, ID id) {
333 try {
334 if (getIdClass(entity).isPresent()) {
335
336 List<Field> idFields = getIdFields(entity);
337 setIdFieldsFromIdClassInstance(entity, idFields, id);
338 } else if (getIdField(entity).isPresent()) {
339 getIdField(entity).get().set(entity, id);
340 } else if (getIdMethod(entity).isPresent()) {
341 Method setMethod = getSetMethod(entity, id);
342 setMethod.invoke(entity, id);
343 }
344 } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
345 throw new IllegalStateException(e);
346 }
347 }
348
349 private Optional<Field> getIdField(T entity) {
350 List<Field> idFields = getIdFields(entity);
351 if (idFields.isEmpty()) {
352 return Optional.empty();
353 } else if (idFields.size() > 1) {
354 throw new IllegalStateException("Multiptle ID fields found in entity: " + entity);
355 } else {
356 return Optional.of(idFields.get(0));
357 }
358 }
359
360 private List<Field> getIdFields(T entity) {
361 List<Field> idFields = new ArrayList<>();
362 Class<?> c = entity.getClass();
363 while (c != null) {
364 for (Field field : c.getDeclaredFields()) {
365 if (isIdField(field)) {
366 field.setAccessible(true);
367 idFields.add(field);
368 }
369 }
370 c = c.getSuperclass();
371 }
372 return idFields;
373 }
374
375 private Optional<Method> getIdMethod(T entity) {
376 List<Method> idMethods = getIdMethods(entity);
377 if (idMethods.isEmpty()) {
378 return Optional.empty();
379 } else if (idMethods.size() > 1) {
380 throw new IllegalStateException("Multiptle ID methods found in entity: " + entity);
381 } else {
382 return Optional.of(idMethods.get(0));
383 }
384 }
385
386 private List<Method> getIdMethods(T entity) {
387 List<Method> idMethods = new ArrayList<>();
388 Class<?> c = entity.getClass();
389 while (c != null) {
390 for (Method method : c.getDeclaredMethods()) {
391 if (isIdMethod(method)) {
392 method.setAccessible(true);
393 idMethods.add(method);
394 }
395 }
396 c = c.getSuperclass();
397 }
398 return idMethods;
399 }
400
401 private Method getSetMethod(T entity, ID id) throws NoSuchMethodException {
402 Method getMethod = getIdMethod(entity)
403 .orElseThrow(() -> new NoSuchMethodException("Get method for ID not found: entity=" + entity));
404 String setMethodName = getMethod.getName().replaceFirst("^get", "set");
405 Method setMethod = entity.getClass().getMethod(setMethodName, id.getClass());
406 setMethod.setAccessible(true);
407 return setMethod;
408 }
409
410 private @Nullable ID createIdClassInstance(T entity, List<Field> idFields) throws ReflectiveOperationException {
411 Class<ID> idClass = getIdClass(entity).orElseThrow(() -> new IllegalStateException("ID class not found: entity=" + entity));
412 Constructor<ID> constructor = idClass.getDeclaredConstructor();
413 constructor.setAccessible(true);
414 ID id = constructor.newInstance();
415 for (Field idField : idFields) {
416 if (idField.get(entity) == null) {
417
418 return null;
419 }
420 TestUtil.injectField(id, idField.getName(), idField.get(entity));
421 }
422 return id;
423 }
424
425 private void setIdFieldsFromIdClassInstance(T entity, List<Field> idFields, ID id) {
426 for (Field idField : idFields) {
427 Object value = TestUtil.getFieldValue(id, idField.getName());
428 TestUtil.injectField(entity, idField.getName(), value);
429 }
430 }
431
432
433
434
435
436
437
438
439 protected abstract boolean isIdField(Field field);
440
441
442
443
444
445
446
447
448 protected abstract boolean isIdMethod(Method method);
449
450
451
452
453
454
455
456
457 protected abstract Optional<Class<ID>> getIdClass(Object entity);
458
459 @Override
460 public String toString() {
461 return getClass().getSimpleName() + "{entities=" + entities + "}";
462 }
463
464
465
466
467
468
469
470
471
472
473
474
475 @Override
476 public <S extends T> Optional<S> findOne(Example<S> example) {
477 throw new UnsupportedOperationException("Not yet implemented");
478 }
479
480
481
482
483
484
485
486
487 @Override
488 public <S extends T> List<S> findAll(Example<S> example) {
489 throw new UnsupportedOperationException("Not yet implemented");
490 }
491
492
493
494
495
496
497
498
499 @Override
500 public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
501 throw new UnsupportedOperationException("Not yet implemented");
502 }
503
504
505
506
507
508
509
510
511 @Override
512 public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
513 throw new UnsupportedOperationException("Not yet implemented");
514 }
515
516
517
518
519
520
521
522
523 @Override
524 public <S extends T> long count(Example<S> example) {
525 throw new UnsupportedOperationException("Not yet implemented");
526 }
527
528
529
530
531
532
533
534
535 @Override
536 public <S extends T> boolean exists(Example<S> example) {
537 throw new UnsupportedOperationException("Not yet implemented");
538 }
539
540
541
542
543
544
545
546 @Override
547 @SuppressWarnings({ "checkstyle:NoFinalizer", "PMD.EmptyFinalizer", "PMD.EmptyMethodInAbstractClassShouldBeAbstract" })
548 protected final void finalize() throws Throwable {
549
550 }
551
552 }