提交 d77b715d 编写于 作者: A asympro 提交者: Sam Brannen

Merge class-level and method-level @Sql declarations

See gh-1835
上级 b0939a8a
......@@ -31,7 +31,8 @@ import org.springframework.core.annotation.AliasFor;
* SQL {@link #scripts} and {@link #statements} to be executed against a given
* database during integration tests.
*
* <p>Method-level declarations override class-level declarations.
* <p>Method-level declarations override class-level declarations by default.
* This behaviour can be adjusted via {@link MergeMode}
*
* <p>Script execution is performed by the {@link SqlScriptsTestExecutionListener},
* which is enabled by default.
......@@ -146,6 +147,13 @@ public @interface Sql {
*/
SqlConfig config() default @SqlConfig;
/**
* Indicates whether this annotation should be merged with upper-level annotations
* or override them.
* <p>Defaults to {@link MergeMode#OVERRIDE}.
*/
MergeMode mergeMode() default MergeMode.OVERRIDE;
/**
* Enumeration of <em>phases</em> that dictate when SQL scripts are executed.
......@@ -165,4 +173,23 @@ public @interface Sql {
AFTER_TEST_METHOD
}
/**
* Enumeration of <em>modes</em> that dictate whether or not
* declared SQL {@link #scripts} and {@link #statements} are merged
* with the upper-level annotations.
*/
enum MergeMode {
/**
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
* should override the upper-level (e.g. Class-level) annotations.
*/
OVERRIDE,
/**
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
* should be merged the upper-level (e.g. Class-level) annotations.
*/
MERGE
}
}
......@@ -16,14 +16,17 @@
package org.springframework.test.context.jdbc;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jetbrains.annotations.NotNull;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
......@@ -126,19 +129,35 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
* {@link TestContext} and {@link ExecutionPhase}.
*/
private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
boolean classLevel = false;
Set<Sql> sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
testContext.getTestMethod(), Sql.class, SqlGroup.class);
if (sqlAnnotations.isEmpty()) {
sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
testContext.getTestClass(), Sql.class, SqlGroup.class);
if (!sqlAnnotations.isEmpty()) {
classLevel = true;
}
Set<Sql> methodLevelSqls = getScriptsFromElement(testContext.getTestMethod());
List<Sql> methodLevelOverrides = methodLevelSqls.stream()
.filter(s -> s.executionPhase() == executionPhase)
.filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE)
.collect(Collectors.toList());
if (methodLevelOverrides.isEmpty()) {
executeScripts(getScriptsFromElement(testContext.getTestClass()), testContext, executionPhase, true);
executeScripts(methodLevelSqls, testContext, executionPhase, false);
} else {
executeScripts(methodLevelOverrides, testContext, executionPhase, false);
}
}
/**
* Get SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link AnnotatedElement}.
*/
private Set<Sql> getScriptsFromElement(AnnotatedElement annotatedElement) throws Exception {
return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class);
}
for (Sql sql : sqlAnnotations) {
/**
* Execute given {@link Sql @Sql} scripts.
* {@link AnnotatedElement}.
*/
private void executeScripts(Iterable<Sql> scripts, TestContext testContext, ExecutionPhase executionPhase,
boolean classLevel)
throws Exception {
for (Sql sql : scripts) {
executeSqlScripts(sql, executionPhase, testContext, classLevel);
}
}
......@@ -166,14 +185,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
mergedSqlConfig, executionPhase, testContext));
}
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
final ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig);
String[] scripts = getScripts(sql, testContext, classLevel);
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
......@@ -232,6 +244,19 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
}
}
@NotNull
private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) {
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
return populator;
}
@Nullable
private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
try {
......
......@@ -25,6 +25,7 @@ import org.junit.runners.MethodSorters;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import org.springframework.test.jdbc.JdbcTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -58,6 +59,10 @@ public class RepeatableSqlAnnotationSqlScriptsTests extends AbstractTransactiona
assertNumUsers(2);
}
protected int countRowsInTable(String tableName) {
return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName);
}
protected void assertNumUsers(int expected) {
assertThat(countRowsInTable("user")).as("Number of rows in the 'user' table.").isEqualTo(expected);
}
......
package org.springframework.test.context.jdbc;
import org.junit.Test;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import static org.junit.Assert.assertEquals;
/**
* Test to verify method level merge of @Sql annotations.
*
* @author Dmitry Semukhin
*/
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
@DirtiesContext
public class SqlMethodMergeTest extends AbstractTransactionalJUnit4SpringContextTests {
@Test
@Sql(value = "data-add-dogbert.sql", mergeMode = Sql.MergeMode.MERGE)
public void testMerge() {
assertNumUsers(2);
}
protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}
}
package org.springframework.test.context.jdbc;
import org.junit.Test;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import static org.junit.Assert.assertEquals;
/**
* Test to verify method level override of @Sql annotations.
*
* @author Dmitry Semukhin
*/
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
@DirtiesContext
public class SqlMethodOverrideTest extends AbstractTransactionalJUnit4SpringContextTests {
@Test
@Sql(value = {"schema.sql", "data.sql", "data-add-dogbert.sql", "data-add-catbert.sql"}, mergeMode = Sql.MergeMode.OVERRIDE)
public void testMerge() {
assertNumUsers(3);
}
protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册