src/test/java/ru/indvdum/jpa/tests/AbstractJPAEntityTest.groovy
author indvdum (gotoindvdum[at]gmail[dot]com)
Sat, 24 Nov 2012 15:58:22 +0400
changeset 14 22646fd200c5
parent 10 7241826f43f6
child 15 7d8a7e7635d2
permissions -rw-r--r--
Support of auto-generating Enum field types
     1 package ru.indvdum.jpa.tests
     2 ;
     3 
     4 import static org.junit.Assert.*
     5 
     6 import java.lang.reflect.Field
     7 import java.lang.reflect.Method
     8 import java.lang.reflect.Modifier
     9 
    10 import javax.persistence.EmbeddedId
    11 import javax.persistence.Entity
    12 import javax.persistence.GeneratedValue
    13 import javax.persistence.Id
    14 import javax.persistence.IdClass
    15 import javax.persistence.Transient
    16 
    17 import org.junit.Test
    18 
    19 import ru.indvdum.jpa.dao.JPADataAccessObject
    20 import ru.indvdum.jpa.entities.AbstractEntity;
    21 
    22 
    23 /**
    24  * JUnit test case for testing of a creating, listing, updating and removing 
    25  * operations with database of the JPA entities.
    26  * 
    27  * @author 	indvdum (gotoindvdum@gmail.com)
    28  * @since 23.12.2011 22:54:26
    29  *
    30  */
    31 abstract class AbstractJPAEntityTest extends AbstractJPATest {
    32 
    33 	protected def uniqueValue = 1
    34 	protected JPADataAccessObject dao = null
    35 	protected Set toRemove = new HashSet()
    36 	protected Map<Class, Integer> existedEntitiesCount = [:]
    37 
    38 	@Test
    39 	public void testEntity() {
    40 		dao = createDAO()
    41 		try {
    42 			testEntity(getEntityClass())
    43 		} catch (Throwable t) {
    44 			dao.rollback()
    45 			throw t
    46 		} finally {
    47 			dao.close()		
    48 		}
    49 	}
    50 	
    51 	/**
    52 	 * @return Your implementation of JPADataAccessObject
    53 	 */
    54 	protected abstract JPADataAccessObject createDAO();
    55 
    56 	/**
    57 	 * Test creating, listing, updating, and removing of an {@code entityClass} object
    58 	 * 
    59 	 * @param entityClass
    60 	 */
    61 	protected void testEntity(Class entityClass) {
    62 		
    63 		assert AbstractEntity.class.isAssignableFrom(entityClass)
    64 
    65 		Map<String, Object> fieldsValues
    66 		def dbEntity
    67 
    68 		// creating
    69 		def entity = createEntity(entityClass)
    70 		testCreatedEntity(entity)
    71 
    72 		// listing
    73 		toRemove.each {
    74 			assert dao.list(it.class).each { it2 ->
    75 				assert it.class.isAssignableFrom(it2.class)
    76 			}.size() == toRemove.findAll { it2 ->
    77 				it.class.isAssignableFrom(it2.class)
    78 			}.size() + existedEntitiesCount[it.class]
    79 		}
    80 
    81 		// updating
    82 		fieldsValues = updateFields(entity)
    83 		assert dao.persist(entity)
    84 		assert dao.contains(entity)
    85 		dbEntity = dao.find(entityClass, (entity as AbstractEntity).getIdentifierValue())
    86 		assert dbEntity != null
    87 		assert checkEntityFieldValues(dbEntity, fieldsValues)
    88 		testUpdatedEntity(entity)
    89 
    90 		// removing
    91 		assert dao.remove(toRemove)
    92 		assert !dao.contains(toRemove)
    93 		testRemovedEntity(entity)
    94 	}
    95 	
    96 	/**
    97 	 * Create and persist to database an {@code entityClass} object
    98 	 * 
    99 	 * @param entityClass
   100 	 * @return created entity object
   101 	 */
   102 	protected Object createEntity(Class entityClass) {
   103 		assertNotNull entityClass.annotations.find {it instanceof Entity}
   104 		def entity = entityClass.newInstance()
   105 		assert entity.class == entityClass
   106 		toRemove.add(entity)
   107 		if (!existedEntitiesCount.containsKey(entityClass))
   108 			existedEntitiesCount[entityClass] = dao.list(entityClass).size()
   109 		
   110 		Map<String, Object> fieldsValues = updateFields(entity)
   111 		assert dao.persist(entity)
   112 		assert dao.contains(entity)
   113 		def dbEntity = dao.find(entityClass, (entity as AbstractEntity).getIdentifierValue())
   114 		assert dbEntity != null
   115 		assert checkEntityFieldValues(dbEntity, fieldsValues)
   116 		
   117 		return entity
   118 	}
   119 
   120 	/**
   121 	 * Update all fields of an {@code entity} object
   122 	 * @param entity
   123 	 * @return generated field values
   124 	 */
   125 	protected Map<String, Object> updateFields(Object entity) {
   126 		Map<String, Object> fieldsValues = new HashMap<String, Object>()
   127 		getFields(entity).grep {
   128 			it.annotations.find {
   129 				it instanceof Transient || it instanceof GeneratedValue
   130 			} == null
   131 		}.each {
   132 			def newValue = generateFieldValue(entity, it)
   133 			setFieldValueBySetter(entity, it, newValue)
   134 			fieldsValues.put(it.name, newValue)
   135 		}
   136 		return fieldsValues
   137 	}
   138 	
   139 	/**
   140 	 * Collections and arrays will not be processed
   141 	 * 
   142 	 * @param entity
   143 	 * @param field
   144 	 * @return generated field value
   145 	 */
   146 	protected Object generateFieldValue(Object entity, Field field) {
   147 		def type = field.getType()
   148 		def newValue
   149 		if(type.toString() == 'boolean' || type == Boolean.class) {
   150 			newValue = (boolean) (uniqueValue++ % 2i == 0i)
   151 		} else if(type.toString() == 'byte' || type == Byte) {
   152 			newValue = (byte) (uniqueValue++ % Byte.MAX_VALUE + 1i)
   153 		} else if(type.toString() == 'char' || type == Character) {
   154 			newValue = (char) (uniqueValue++ % (int) Character.MAX_VALUE + 1i)
   155 		} else if(type.toString() == 'short' || type == Short) {
   156 			newValue = (short) (uniqueValue++ % Short.MAX_VALUE + 1i)
   157 		} else if(type.toString() == 'int' || type == Integer) {
   158 			newValue = (int) (uniqueValue++ % Integer.MAX_VALUE + 1i)
   159 		} else if(type.toString() == 'long' || type == Long) {
   160 			newValue = (long) uniqueValue++ % Long.MAX_VALUE + 1L
   161 		} else if(type.toString() == 'float' || type == Float) {
   162 			newValue = (float) uniqueValue++ % Float.MAX_VALUE + 1f
   163 		} else if(type.toString() == 'double' || type == Double) {
   164 			newValue = (double) uniqueValue++ % Double.MAX_VALUE + 1d
   165 		} else if(type == String.class) {
   166 			newValue = (String) "test${uniqueValue++}"
   167 		} else if(type instanceof Class && (type as Class).annotations.find {it instanceof Entity} != null) { // modifying of a primary keys is deprecated
   168 			// an attempt to use already created entities
   169 			def currentValue = getFieldValue(entity, field)
   170 			newValue = toRemove.find {it.class == type && it != currentValue}
   171 			if(newValue == null)
   172 				newValue = createEntity(type as Class)
   173 		} else if(Enum.class.isAssignableFrom(type)) {
   174 			def values = type.values();
   175 			newValue = values[uniqueValue++ % values.size()];
   176 		} else if(
   177 				type instanceof Class 
   178 				&& !(
   179 					field.clazz.annotations.find {it instanceof IdClass} != null 
   180 					&& field.declaredAnnotations.find {it instanceof Id} != null
   181 				)
   182 				&& field.declaredAnnotations.find {it instanceof EmbeddedId} == null
   183 				&& !Collection.class.isAssignableFrom(type)
   184 				&& !(type as Class).isArray()
   185 			) { // modifying of a primary keys is deprecated
   186 			newValue = (type as Class).newInstance()
   187 		} else {
   188 			newValue = getFieldValue(entity, field)
   189 		}
   190 		return newValue
   191 	}
   192 
   193 	/**
   194 	 * @param entity
   195 	 * @param rightValues
   196 	 * @return {@code true}, if all entity fields is equals to {@code rightValues}
   197 	 */
   198 	protected boolean checkEntityFieldValues(Object entity, Map<String, Object> rightValues) {
   199 		def result = true
   200 		def fields = getFields(entity)
   201 		rightValues.each {  key, value ->
   202 			fields.find { it.name == key }.each {
   203 				if(!getFieldValue(entity, it).equals(value)) {
   204 					result = false
   205 				}
   206 			}
   207 		}
   208 		return result
   209 	}
   210 
   211 	/**
   212 	 * Set a {@code field} to a {@code value} of an entity {@code object}
   213 	 * 
   214 	 * @param object
   215 	 * @param field
   216 	 * @param value
   217 	 */
   218 	protected void setFieldValue(Object object, Field field, Object value) {
   219 		boolean isAccessible = field.accessible
   220 		field.accessible = true
   221 		field.set(object, value)
   222 		field.accessible = isAccessible
   223 	}
   224 	
   225 	/**
   226 	 * Trying set a {@code field} to a {@code value} of an entity {@code object}
   227 	 * by using a setter method. If setter not found, field will be set directly
   228 	 * throw {@code setFieldValue} method.
   229 	 * 
   230 	 * @param object
   231 	 * @param field
   232 	 * @param value
   233 	 */
   234 	protected void setFieldValueBySetter(Object object, Field field, Object value) {
   235 		boolean isSetted = false
   236 		getMethods(object).grep {
   237 			(
   238 				Modifier.isPublic(it.modifiers) 
   239 				&& it.parameterTypes.size() == 1 
   240 				&& (
   241 					value == null 
   242 					|| it.parameterTypes[0].isAssignableFrom(value.class)
   243 					)
   244 				&& it.name =~ /^set(?i:${field.name.charAt(0)})${field.name.replaceFirst("^.{1}", "")}/
   245 			)
   246 		}.each {
   247 			object."${it.name}"(value)
   248 			isSetted = true
   249 		}
   250 		if(isSetted)
   251 			return
   252 		setFieldValue(object, field, value)
   253 	}
   254 
   255 	/**
   256 	 * @param object
   257 	 * @param field
   258 	 * @return the field value of an object
   259 	 */
   260 	protected Object getFieldValue(Object object, Field field) {
   261 		boolean isAccessible = field.accessible
   262 		field.accessible = true
   263 		def value = field.get(object)
   264 		field.accessible = isAccessible
   265 		return value
   266 	}
   267 
   268 	/**
   269 	 * @param obj
   270 	 * @return all object fields, including inherited
   271 	 */
   272 	protected Collection<Field> getFields(obj) {
   273 		Class clazz = obj.getClass()
   274 		Collection<Field> fields = clazz.declaredFields
   275 		while(clazz.superclass != null) {
   276 			clazz = clazz.superclass
   277 			fields.addAll(clazz.declaredFields)
   278 		}
   279 		fields.grep {
   280 			!it.synthetic &&
   281 					!Modifier.isStatic(it.modifiers) &&
   282 					!Modifier.isTransient(it.modifiers)
   283 		}
   284 	}
   285 
   286 	/**
   287 	 * @param obj
   288 	 * @return all object methods, including inherited
   289 	 */
   290 	protected Collection<Method> getMethods(obj) {
   291 		Class clazz = obj.getClass()
   292 		Collection<Method> methods = clazz.declaredMethods
   293 		while(clazz.superclass != null) {
   294 			clazz = clazz.superclass
   295 			methods.addAll(clazz.declaredMethods)
   296 		}
   297 		methods.grep {
   298 			!it.synthetic &&
   299 					!Modifier.isStatic(it.modifiers) &&
   300 					!Modifier.isTransient(it.modifiers)
   301 		}
   302 	}
   303 	
   304 	/**
   305 	 * For implement in successors
   306 	 *
   307 	 * @param entity
   308 	 */
   309 	protected void testCreatedEntity(Object entity) {
   310 		
   311 	}
   312 	
   313 	/**
   314 	 * For implement in successors
   315 	 *
   316 	 * @param entity
   317 	 */
   318 	protected void testUpdatedEntity(Object entity) {
   319 		
   320 	}
   321 	
   322 	/**
   323 	 * For implement in successors
   324 	 *
   325 	 * @param entity
   326 	 */
   327 	protected void testRemovedEntity(Object entity) {
   328 		
   329 	}
   330 
   331 	/**
   332 	 * @return an {@link AbstractEntity} successor
   333 	 */
   334 	abstract protected Class getEntityClass()
   335 }