未验证 提交 aa5526ef 编写于 作者: S Stalary 提交者: GitHub

[Improvement][Test] Remove Powermock in dolphinscheduler-task-plugin (#12153)

* Remove the usage of powermock in task-plugin module
上级 ad4f3442
...@@ -17,56 +17,24 @@ ...@@ -17,56 +17,24 @@
package org.apache.dolphinscheduler.plugin.task.dvc; package org.apache.dolphinscheduler.plugin.task.dvc;
import java.util.Date;
import java.util.UUID;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager;
import org.apache.dolphinscheduler.spi.utils.JSONUtils;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.dolphinscheduler.spi.utils.PropertyUtils;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.modules.junit4.PowerMockRunner; import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; import org.mockito.junit.MockitoJUnitRunner;
import org.apache.dolphinscheduler.spi.utils.JSONUtils;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest({
JSONUtils.class,
PropertyUtils.class,
})
@PowerMockIgnore({"javax.*"})
@SuppressStaticInitializationFor("org.apache.dolphinscheduler.spi.utils.PropertyUtils")
public class DvcTaskTest { public class DvcTaskTest {
@Before
public void before() throws Exception {
PowerMockito.mockStatic(PropertyUtils.class);
}
public TaskExecutionContext createContext(DvcParameters dvcParameters) { public TaskExecutionContext createContext(DvcParameters dvcParameters) {
String parameters = JSONUtils.toJsonString(dvcParameters); String parameters = JSONUtils.toJsonString(dvcParameters);
TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters);
Mockito.when(taskExecutionContext.getTaskLogName()).thenReturn("DvcTest");
Mockito.when(taskExecutionContext.getExecutePath()).thenReturn("/tmp/dolphinscheduler_dvc_test");
Mockito.when(taskExecutionContext.getTaskAppId()).thenReturn(UUID.randomUUID().toString());
Mockito.when(taskExecutionContext.getStartTime()).thenReturn(new Date());
Mockito.when(taskExecutionContext.getTaskTimeout()).thenReturn(10000);
Mockito.when(taskExecutionContext.getLogPath()).thenReturn("/tmp/dolphinscheduler_dvc_test/log");
Mockito.when(taskExecutionContext.getEnvironmentConfig()).thenReturn("export PATH=$HOME/anaconda3/bin:$PATH");
String userName = System.getenv().get("USER");
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn(userName);
TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext); TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext);
return taskExecutionContext; return taskExecutionContext;
...@@ -82,10 +50,10 @@ public class DvcTaskTest { ...@@ -82,10 +50,10 @@ public class DvcTaskTest {
} }
@Test @Test
public void testDvcUpload() throws Exception{ public void testDvcUpload() throws Exception {
DvcTask dvcTask = initTask(createUploadParameters()); DvcTask dvcTask = initTask(createUploadParameters());
Assert.assertEquals(dvcTask.buildCommand(), Assert.assertEquals(dvcTask.buildCommand(),
"which dvc || { echo \"dvc does not exist\"; exit 1; }; DVC_REPO=git@github.com:<YOUR-NAME-OR-ORG>/dvc-data-repository-example\n" + "which dvc || { echo \"dvc does not exist\"; exit 1; }; DVC_REPO=git@github.com:<YOUR-NAME-OR-ORG>/dvc-data-repository-example\n" +
"DVC_DATA_PATH=/home/<YOUR-NAME-OR-ORG>/test\n" + "DVC_DATA_PATH=/home/<YOUR-NAME-OR-ORG>/test\n" +
"DVC_DATA_LOCATION=test\n" + "DVC_DATA_LOCATION=test\n" +
"DVC_VERSION=iris_v2.3.1\n" + "DVC_VERSION=iris_v2.3.1\n" +
...@@ -101,10 +69,10 @@ public class DvcTaskTest { ...@@ -101,10 +69,10 @@ public class DvcTaskTest {
} }
@Test @Test
public void testDvcDownload() throws Exception{ public void testDvcDownload() throws Exception {
DvcTask dvcTask = initTask(createDownloadParameters()); DvcTask dvcTask = initTask(createDownloadParameters());
Assert.assertEquals(dvcTask.buildCommand(), Assert.assertEquals(dvcTask.buildCommand(),
"which dvc || { echo \"dvc does not exist\"; exit 1; }; DVC_REPO=git@github.com:<YOUR-NAME-OR-ORG>/dvc-data-repository-example\n" + "which dvc || { echo \"dvc does not exist\"; exit 1; }; DVC_REPO=git@github.com:<YOUR-NAME-OR-ORG>/dvc-data-repository-example\n" +
"DVC_DATA_PATH=data\n" + "DVC_DATA_PATH=data\n" +
"DVC_DATA_LOCATION=iris\n" + "DVC_DATA_LOCATION=iris\n" +
"DVC_VERSION=iris_v2.3.1\n" + "DVC_VERSION=iris_v2.3.1\n" +
...@@ -112,10 +80,10 @@ public class DvcTaskTest { ...@@ -112,10 +80,10 @@ public class DvcTaskTest {
} }
@Test @Test
public void testInitDvc() throws Exception{ public void testInitDvc() throws Exception {
DvcTask dvcTask = initTask(createInitDvcParameters()); DvcTask dvcTask = initTask(createInitDvcParameters());
Assert.assertEquals(dvcTask.buildCommand(), Assert.assertEquals(dvcTask.buildCommand(),
"which dvc || { echo \"dvc does not exist\"; exit 1; }; DVC_REPO=git@github.com:<YOUR-NAME-OR-ORG>/dvc-data-repository-example\n" + "which dvc || { echo \"dvc does not exist\"; exit 1; }; DVC_REPO=git@github.com:<YOUR-NAME-OR-ORG>/dvc-data-repository-example\n" +
"git clone $DVC_REPO dvc-repository; cd dvc-repository; pwd\n" + "git clone $DVC_REPO dvc-repository; cd dvc-repository; pwd\n" +
"dvc init || exit 1\n" + "dvc init || exit 1\n" +
"dvc remote add origin ~/.dvc_test -d\n" + "dvc remote add origin ~/.dvc_test -d\n" +
......
...@@ -98,7 +98,7 @@ public abstract class AbstractEmrTask extends AbstractRemoteTask { ...@@ -98,7 +98,7 @@ public abstract class AbstractEmrTask extends AbstractRemoteTask {
* *
* @return AmazonElasticMapReduce * @return AmazonElasticMapReduce
*/ */
private AmazonElasticMapReduce createEmrClient() { protected AmazonElasticMapReduce createEmrClient() {
final String awsAccessKeyId = PropertyUtils.getString(TaskConstants.AWS_ACCESS_KEY_ID); final String awsAccessKeyId = PropertyUtils.getString(TaskConstants.AWS_ACCESS_KEY_ID);
final String awsSecretAccessKey = PropertyUtils.getString(TaskConstants.AWS_SECRET_ACCESS_KEY); final String awsSecretAccessKey = PropertyUtils.getString(TaskConstants.AWS_SECRET_ACCESS_KEY);
......
...@@ -120,7 +120,7 @@ public class EmrAddStepsTask extends AbstractEmrTask { ...@@ -120,7 +120,7 @@ public class EmrAddStepsTask extends AbstractEmrTask {
* *
* @return AddJobFlowStepsRequest * @return AddJobFlowStepsRequest
*/ */
private AddJobFlowStepsRequest createAddJobFlowStepsRequest() { protected AddJobFlowStepsRequest createAddJobFlowStepsRequest() {
final AddJobFlowStepsRequest addJobFlowStepsRequest; final AddJobFlowStepsRequest addJobFlowStepsRequest;
try { try {
......
...@@ -114,7 +114,7 @@ public class EmrJobFlowTask extends AbstractEmrTask { ...@@ -114,7 +114,7 @@ public class EmrJobFlowTask extends AbstractEmrTask {
* *
* @return RunJobFlowRequest * @return RunJobFlowRequest
*/ */
private RunJobFlowRequest createRunJobFlowRequest() { protected RunJobFlowRequest createRunJobFlowRequest() {
final RunJobFlowRequest runJobFlowRequest; final RunJobFlowRequest runJobFlowRequest;
try { try {
......
...@@ -20,13 +20,7 @@ package org.apache.dolphinscheduler.plugin.task.emr; ...@@ -20,13 +20,7 @@ package org.apache.dolphinscheduler.plugin.task.emr;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_KILL; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_KILL;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_SUCCESS; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_SUCCESS;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.powermock.api.mockito.PowerMockito.mock;
import static org.powermock.api.mockito.PowerMockito.mockStatic;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack; import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack;
import org.apache.dolphinscheduler.plugin.task.api.TaskException; import org.apache.dolphinscheduler.plugin.task.api.TaskException;
...@@ -43,13 +37,10 @@ import org.junit.Assert; ...@@ -43,13 +37,10 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.api.mockito.PowerMockito; import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.mockito.junit.MockitoJUnitRunner;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult; import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult;
import com.amazonaws.services.elasticmapreduce.model.AmazonElasticMapReduceException; import com.amazonaws.services.elasticmapreduce.model.AmazonElasticMapReduceException;
import com.amazonaws.services.elasticmapreduce.model.DescribeStepResult; import com.amazonaws.services.elasticmapreduce.model.DescribeStepResult;
...@@ -62,14 +53,7 @@ import com.amazonaws.services.elasticmapreduce.model.StepStatus; ...@@ -62,14 +53,7 @@ import com.amazonaws.services.elasticmapreduce.model.StepStatus;
* *
* @since v3.1.0 * @since v3.1.0
*/ */
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest({
AmazonElasticMapReduceClientBuilder.class,
EmrAddStepsTask.class,
AmazonElasticMapReduce.class,
JSONUtils.class
})
@PowerMockIgnore({"javax.*"})
public class EmrAddStepsTaskTest { public class EmrAddStepsTaskTest {
private final StepStatus pendingState = private final StepStatus pendingState =
...@@ -99,32 +83,31 @@ public class EmrAddStepsTaskTest { ...@@ -99,32 +83,31 @@ public class EmrAddStepsTaskTest {
// mock EmrParameters and EmrAddStepsTask // mock EmrParameters and EmrAddStepsTask
EmrParameters emrParameters = buildEmrTaskParameters(); EmrParameters emrParameters = buildEmrTaskParameters();
String emrParametersString = JSONUtils.toJsonString(emrParameters); String emrParametersString = JSONUtils.toJsonString(emrParameters);
TaskExecutionContext taskExecutionContext = PowerMockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
when(taskExecutionContext.getTaskParams()).thenReturn(emrParametersString); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(emrParametersString);
emrAddStepsTask = spy(new EmrAddStepsTask(taskExecutionContext)); emrAddStepsTask = Mockito.spy(new EmrAddStepsTask(taskExecutionContext));
// mock emrClient and behavior // mock emrClient and behavior
emrClient = mock(AmazonElasticMapReduce.class); emrClient = Mockito.mock(AmazonElasticMapReduce.class);
AddJobFlowStepsResult addJobFlowStepsResult = mock(AddJobFlowStepsResult.class); AddJobFlowStepsResult addJobFlowStepsResult = Mockito.mock(AddJobFlowStepsResult.class);
when(emrClient.addJobFlowSteps(any())).thenReturn(addJobFlowStepsResult); Mockito.when(emrClient.addJobFlowSteps(any())).thenReturn(addJobFlowStepsResult);
when(addJobFlowStepsResult.getStepIds()).thenReturn(Collections.singletonList("step-xx")); Mockito.when(addJobFlowStepsResult.getStepIds()).thenReturn(Collections.singletonList("step-xx"));
doReturn(emrClient).when(emrAddStepsTask, "createEmrClient"); Mockito.doReturn(emrClient).when(emrAddStepsTask).createEmrClient();
DescribeStepResult describeStepResult = mock(DescribeStepResult.class); DescribeStepResult describeStepResult = Mockito.mock(DescribeStepResult.class);
when(emrClient.describeStep(any())).thenReturn(describeStepResult); Mockito.when(emrClient.describeStep(any())).thenReturn(describeStepResult);
// mock step // mock step
step = mock(Step.class); step = Mockito.mock(Step.class);
when(describeStepResult.getStep()).thenReturn(step); Mockito.when(describeStepResult.getStep()).thenReturn(step);
emrAddStepsTask.init(); emrAddStepsTask.init();
} }
@Test(expected = TaskException.class) @Test(expected = TaskException.class)
public void testCanNotParseJson() throws Exception { public void testCanNotParseJson() throws Exception {
mockStatic(JSONUtils.class); Mockito.when(emrAddStepsTask.createAddJobFlowStepsRequest()).thenThrow(new EmrTaskException("can not parse AddJobFlowStepsRequest from json", new Exception("error")));
when(emrAddStepsTask, "createAddJobFlowStepsRequest").thenThrow(new EmrTaskException("can not parse AddJobFlowStepsRequest from json", new Exception("error")));
emrAddStepsTask.handle(taskCallBack); emrAddStepsTask.handle(taskCallBack);
} }
...@@ -133,17 +116,17 @@ public class EmrAddStepsTaskTest { ...@@ -133,17 +116,17 @@ public class EmrAddStepsTaskTest {
// mock EmrParameters and EmrAddStepsTask // mock EmrParameters and EmrAddStepsTask
EmrParameters emrParameters = buildErrorEmrTaskParameters(); EmrParameters emrParameters = buildErrorEmrTaskParameters();
String emrParametersString = JSONUtils.toJsonString(emrParameters); String emrParametersString = JSONUtils.toJsonString(emrParameters);
TaskExecutionContext taskExecutionContext = PowerMockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
when(taskExecutionContext.getTaskParams()).thenReturn(emrParametersString); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(emrParametersString);
emrAddStepsTask = spy(new EmrAddStepsTask(taskExecutionContext)); emrAddStepsTask = Mockito.spy(new EmrAddStepsTask(taskExecutionContext));
doReturn(emrClient).when(emrAddStepsTask, "createEmrClient"); Mockito.doReturn(emrClient).when(emrAddStepsTask).createEmrClient();
emrAddStepsTask.init(); emrAddStepsTask.init();
emrAddStepsTask.handle(taskCallBack); emrAddStepsTask.handle(taskCallBack);
} }
@Test @Test
public void testHandle() throws Exception { public void testHandle() throws Exception {
when(step.getStatus()).thenReturn(pendingState, runningState, completedState); Mockito.when(step.getStatus()).thenReturn(pendingState, runningState, completedState);
emrAddStepsTask.handle(taskCallBack); emrAddStepsTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_SUCCESS, emrAddStepsTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_SUCCESS, emrAddStepsTask.getExitStatusCode());
...@@ -151,7 +134,7 @@ public class EmrAddStepsTaskTest { ...@@ -151,7 +134,7 @@ public class EmrAddStepsTaskTest {
@Test @Test
public void testHandleUserRequestTerminate() throws Exception { public void testHandleUserRequestTerminate() throws Exception {
when(step.getStatus()).thenReturn(pendingState, runningState, cancelledState); Mockito.when(step.getStatus()).thenReturn(pendingState, runningState, cancelledState);
emrAddStepsTask.handle(taskCallBack); emrAddStepsTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_KILL, emrAddStepsTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_KILL, emrAddStepsTask.getExitStatusCode());
...@@ -159,11 +142,11 @@ public class EmrAddStepsTaskTest { ...@@ -159,11 +142,11 @@ public class EmrAddStepsTaskTest {
@Test(expected = TaskException.class) @Test(expected = TaskException.class)
public void testHandleError() throws Exception { public void testHandleError() throws Exception {
when(step.getStatus()).thenReturn(pendingState, runningState, failedState); Mockito.when(step.getStatus()).thenReturn(pendingState, runningState, failedState);
emrAddStepsTask.handle(taskCallBack); emrAddStepsTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_FAILURE, emrAddStepsTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_FAILURE, emrAddStepsTask.getExitStatusCode());
when(emrClient.addJobFlowSteps(any())).thenThrow(new AmazonElasticMapReduceException("error"), new EmrTaskException()); Mockito.when(emrClient.addJobFlowSteps(any())).thenThrow(new AmazonElasticMapReduceException("error"), new EmrTaskException());
emrAddStepsTask.handle(taskCallBack); emrAddStepsTask.handle(taskCallBack);
} }
......
...@@ -20,13 +20,7 @@ package org.apache.dolphinscheduler.plugin.task.emr; ...@@ -20,13 +20,7 @@ package org.apache.dolphinscheduler.plugin.task.emr;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_KILL; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_KILL;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_SUCCESS; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_SUCCESS;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.powermock.api.mockito.PowerMockito.doReturn;
import static org.powermock.api.mockito.PowerMockito.mock;
import static org.powermock.api.mockito.PowerMockito.mockStatic;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack; import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack;
import org.apache.dolphinscheduler.plugin.task.api.TaskException; import org.apache.dolphinscheduler.plugin.task.api.TaskException;
...@@ -42,13 +36,10 @@ import org.junit.Assert; ...@@ -42,13 +36,10 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.api.mockito.PowerMockito; import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.mockito.junit.MockitoJUnitRunner;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
import com.amazonaws.services.elasticmapreduce.model.AmazonElasticMapReduceException; import com.amazonaws.services.elasticmapreduce.model.AmazonElasticMapReduceException;
import com.amazonaws.services.elasticmapreduce.model.Cluster; import com.amazonaws.services.elasticmapreduce.model.Cluster;
import com.amazonaws.services.elasticmapreduce.model.ClusterState; import com.amazonaws.services.elasticmapreduce.model.ClusterState;
...@@ -58,14 +49,7 @@ import com.amazonaws.services.elasticmapreduce.model.ClusterStatus; ...@@ -58,14 +49,7 @@ import com.amazonaws.services.elasticmapreduce.model.ClusterStatus;
import com.amazonaws.services.elasticmapreduce.model.DescribeClusterResult; import com.amazonaws.services.elasticmapreduce.model.DescribeClusterResult;
import com.amazonaws.services.elasticmapreduce.model.RunJobFlowResult; import com.amazonaws.services.elasticmapreduce.model.RunJobFlowResult;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest({
AmazonElasticMapReduceClientBuilder.class,
EmrJobFlowTask.class,
AmazonElasticMapReduce.class,
JSONUtils.class
})
@PowerMockIgnore({"javax.*"})
public class EmrJobFlowTaskTest { public class EmrJobFlowTaskTest {
private final ClusterStatus startingStatus = private final ClusterStatus startingStatus =
...@@ -126,22 +110,22 @@ public class EmrJobFlowTaskTest { ...@@ -126,22 +110,22 @@ public class EmrJobFlowTaskTest {
@Before @Before
public void before() throws Exception { public void before() throws Exception {
String emrParameters = buildEmrTaskParameters(); String emrParameters = buildEmrTaskParameters();
TaskExecutionContext taskExecutionContext = PowerMockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
when(taskExecutionContext.getTaskParams()).thenReturn(emrParameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(emrParameters);
emrJobFlowTask = spy(new EmrJobFlowTask(taskExecutionContext)); emrJobFlowTask = Mockito.spy(new EmrJobFlowTask(taskExecutionContext));
// mock emrClient and behavior // mock emrClient and behavior
emrClient = mock(AmazonElasticMapReduce.class); emrClient = Mockito.mock(AmazonElasticMapReduce.class);
RunJobFlowResult runJobFlowResult = mock(RunJobFlowResult.class); RunJobFlowResult runJobFlowResult = Mockito.mock(RunJobFlowResult.class);
when(emrClient.runJobFlow(any())).thenReturn(runJobFlowResult); Mockito.when(emrClient.runJobFlow(any())).thenReturn(runJobFlowResult);
when(runJobFlowResult.getJobFlowId()).thenReturn("xx"); Mockito.when(runJobFlowResult.getJobFlowId()).thenReturn("xx");
doReturn(emrClient).when(emrJobFlowTask, "createEmrClient"); Mockito.doReturn(emrClient).when(emrJobFlowTask).createEmrClient();
DescribeClusterResult describeClusterResult = mock(DescribeClusterResult.class); DescribeClusterResult describeClusterResult = Mockito.mock(DescribeClusterResult.class);
when(emrClient.describeCluster(any())).thenReturn(describeClusterResult); Mockito.when(emrClient.describeCluster(any())).thenReturn(describeClusterResult);
// mock cluster // mock cluster
cluster = mock(Cluster.class); cluster = Mockito.mock(Cluster.class);
when(describeClusterResult.getCluster()).thenReturn(cluster); Mockito.when(describeClusterResult.getCluster()).thenReturn(cluster);
emrJobFlowTask.init(); emrJobFlowTask.init();
} }
...@@ -149,7 +133,7 @@ public class EmrJobFlowTaskTest { ...@@ -149,7 +133,7 @@ public class EmrJobFlowTaskTest {
@Test @Test
public void testHandle() throws Exception { public void testHandle() throws Exception {
when(cluster.getStatus()).thenReturn(startingStatus, softwareConfigStatus, runningStatus, terminatingStatus); Mockito.when(cluster.getStatus()).thenReturn(startingStatus, softwareConfigStatus, runningStatus, terminatingStatus);
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_SUCCESS, emrJobFlowTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_SUCCESS, emrJobFlowTask.getExitStatusCode());
...@@ -158,7 +142,7 @@ public class EmrJobFlowTaskTest { ...@@ -158,7 +142,7 @@ public class EmrJobFlowTaskTest {
@Test @Test
public void testHandleAliveWhenNoSteps() throws Exception { public void testHandleAliveWhenNoSteps() throws Exception {
when(cluster.getStatus()).thenReturn(startingStatus, softwareConfigStatus, runningStatus, waitingStatus); Mockito.when(cluster.getStatus()).thenReturn(startingStatus, softwareConfigStatus, runningStatus, waitingStatus);
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_SUCCESS, emrJobFlowTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_SUCCESS, emrJobFlowTask.getExitStatusCode());
...@@ -166,7 +150,7 @@ public class EmrJobFlowTaskTest { ...@@ -166,7 +150,7 @@ public class EmrJobFlowTaskTest {
@Test @Test
public void testHandleUserRequestTerminate() throws Exception { public void testHandleUserRequestTerminate() throws Exception {
when(cluster.getStatus()).thenReturn(startingStatus, userRequestTerminateStatus); Mockito.when(cluster.getStatus()).thenReturn(startingStatus, userRequestTerminateStatus);
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_KILL, emrJobFlowTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_KILL, emrJobFlowTask.getExitStatusCode());
...@@ -174,7 +158,7 @@ public class EmrJobFlowTaskTest { ...@@ -174,7 +158,7 @@ public class EmrJobFlowTaskTest {
@Test @Test
public void testHandleTerminatedWithError() throws Exception { public void testHandleTerminatedWithError() throws Exception {
when(cluster.getStatus()).thenReturn(startingStatus, softwareConfigStatus, runningStatus, terminatedWithErrorsStatus); Mockito.when(cluster.getStatus()).thenReturn(startingStatus, softwareConfigStatus, runningStatus, terminatedWithErrorsStatus);
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
Assert.assertEquals(EXIT_CODE_FAILURE, emrJobFlowTask.getExitStatusCode()); Assert.assertEquals(EXIT_CODE_FAILURE, emrJobFlowTask.getExitStatusCode());
...@@ -182,21 +166,20 @@ public class EmrJobFlowTaskTest { ...@@ -182,21 +166,20 @@ public class EmrJobFlowTaskTest {
@Test(expected = TaskException.class) @Test(expected = TaskException.class)
public void testCanNotParseJson() throws Exception { public void testCanNotParseJson() throws Exception {
mockStatic(JSONUtils.class); Mockito.when(emrJobFlowTask.createRunJobFlowRequest()).thenThrow(new EmrTaskException("can not parse RunJobFlowRequest from json", new Exception("error")));
when(emrJobFlowTask, "createRunJobFlowRequest").thenThrow(new EmrTaskException("can not parse RunJobFlowRequest from json", new Exception("error")));
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
} }
@Test(expected = TaskException.class) @Test(expected = TaskException.class)
public void testClusterStatusNull() throws Exception { public void testClusterStatusNull() throws Exception {
when(emrClient.describeCluster(any())).thenReturn(null); Mockito.when(emrClient.describeCluster(any())).thenReturn(null);
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
} }
@Test(expected = TaskException.class) @Test(expected = TaskException.class)
public void testRunJobFlowError() throws Exception { public void testRunJobFlowError() throws Exception {
when(emrClient.runJobFlow(any())).thenThrow(new AmazonElasticMapReduceException("error"), new EmrTaskException()); Mockito.when(emrClient.runJobFlow(any())).thenThrow(new AmazonElasticMapReduceException("error"), new EmrTaskException());
emrJobFlowTask.handle(taskCallBack); emrJobFlowTask.handle(taskCallBack);
} }
......
...@@ -19,7 +19,6 @@ package org.apache.dolphinscheduler.plugin.task.http; ...@@ -19,7 +19,6 @@ package org.apache.dolphinscheduler.plugin.task.http;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_SUCCESS; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_SUCCESS;
import static org.powermock.api.mockito.PowerMockito.when;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.Property; import org.apache.dolphinscheduler.plugin.task.api.model.Property;
...@@ -43,7 +42,9 @@ import okhttp3.mockwebserver.RecordedRequest; ...@@ -43,7 +42,9 @@ import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
...@@ -51,6 +52,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; ...@@ -51,6 +52,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
/** /**
* Test HttpTask * Test HttpTask
*/ */
@RunWith(MockitoJUnitRunner.class)
public class HttpTaskTest { public class HttpTaskTest {
private static final String CONTENT_TYPE = "Content-Type"; private static final String CONTENT_TYPE = "Content-Type";
...@@ -203,7 +205,7 @@ public class HttpTaskTest { ...@@ -203,7 +205,7 @@ public class HttpTaskTest {
private HttpTask generateHttpTaskFromParamData(String paramData, Map<String, String> prepareParamsMap) { private HttpTask generateHttpTaskFromParamData(String paramData, Map<String, String> prepareParamsMap) {
TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
when(taskExecutionContext.getTaskParams()).thenReturn(paramData); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(paramData);
if (prepareParamsMap != null) { if (prepareParamsMap != null) {
Map<String, Property> propertyParamsMap = new HashMap<>(); Map<String, Property> propertyParamsMap = new HashMap<>();
prepareParamsMap.forEach((k, v) -> { prepareParamsMap.forEach((k, v) -> {
...@@ -212,7 +214,7 @@ public class HttpTaskTest { ...@@ -212,7 +214,7 @@ public class HttpTaskTest {
property.setValue(v); property.setValue(v);
propertyParamsMap.put(k, property); propertyParamsMap.put(k, property);
}); });
when(taskExecutionContext.getPrepareParamsMap()).thenReturn(propertyParamsMap); Mockito.when(taskExecutionContext.getPrepareParamsMap()).thenReturn(propertyParamsMap);
} }
HttpTask httpTask = new HttpTask(taskExecutionContext); HttpTask httpTask = new HttpTask(taskExecutionContext);
httpTask.init(); httpTask.init();
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
package org.apache.dolphinler.plugin.task.mlflow; package org.apache.dolphinler.plugin.task.mlflow;
import static org.powermock.api.mockito.PowerMockito.when;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager;
import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowConstants; import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowConstants;
...@@ -27,53 +25,37 @@ import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowTask; ...@@ -27,53 +25,37 @@ import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowTask;
import org.apache.dolphinscheduler.spi.utils.JSONUtils; import org.apache.dolphinscheduler.spi.utils.JSONUtils;
import org.apache.dolphinscheduler.spi.utils.PropertyUtils; import org.apache.dolphinscheduler.spi.utils.PropertyUtils;
import java.util.Date; import org.junit.After;
import java.util.UUID;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.MockedStatic;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito; import org.mockito.junit.MockitoJUnitRunner;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor;
import org.powermock.modules.junit4.PowerMockRunner;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest({
JSONUtils.class,
PropertyUtils.class,
})
@PowerMockIgnore({"javax.*"})
@SuppressStaticInitializationFor("org.apache.dolphinscheduler.spi.utils.PropertyUtils")
public class MlflowTaskTest { public class MlflowTaskTest {
private static final Logger logger = LoggerFactory.getLogger(MlflowTask.class); private static final Logger logger = LoggerFactory.getLogger(MlflowTask.class);
private MockedStatic<PropertyUtils> propertyUtilsMockedStatic;
@Before @Before
public void before() throws Exception { public void init() {
PowerMockito.mockStatic(PropertyUtils.class); propertyUtilsMockedStatic = Mockito.mockStatic(PropertyUtils.class);
propertyUtilsMockedStatic.when(() -> PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_VERSION_KEY)).thenReturn("main");
}
@After
public void clean() {
propertyUtilsMockedStatic.close();
} }
public TaskExecutionContext createContext(MlflowParameters mlflowParameters) { public TaskExecutionContext createContext(MlflowParameters mlflowParameters) {
String parameters = JSONUtils.toJsonString(mlflowParameters); String parameters = JSONUtils.toJsonString(mlflowParameters);
TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters);
Mockito.when(taskExecutionContext.getTaskLogName()).thenReturn("MLflowTest");
Mockito.when(taskExecutionContext.getExecutePath()).thenReturn("/tmp/dolphinscheduler_test");
Mockito.when(taskExecutionContext.getTaskAppId()).thenReturn(UUID.randomUUID().toString());
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn("root");
Mockito.when(taskExecutionContext.getStartTime()).thenReturn(new Date());
Mockito.when(taskExecutionContext.getTaskTimeout()).thenReturn(10000);
Mockito.when(taskExecutionContext.getLogPath()).thenReturn("/tmp/dolphinscheduler_test/log");
Mockito.when(taskExecutionContext.getEnvironmentConfig()).thenReturn("export PATH=$HOME/anaconda3/bin:$PATH");
String userName = System.getenv().get("USER");
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn(userName);
TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext); TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext);
return taskExecutionContext; return taskExecutionContext;
} }
...@@ -85,11 +67,11 @@ public class MlflowTaskTest { ...@@ -85,11 +67,11 @@ public class MlflowTaskTest {
Assert.assertEquals("main", MlflowTask.getPresetRepositoryVersion()); Assert.assertEquals("main", MlflowTask.getPresetRepositoryVersion());
String definedRepository = "https://github.com/<MY-ID>/dolphinscheduler-mlflow"; String definedRepository = "https://github.com/<MY-ID>/dolphinscheduler-mlflow";
when(PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_KEY)).thenAnswer(invocation -> definedRepository); Mockito.when(PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_KEY)).thenAnswer(invocation -> definedRepository);
Assert.assertEquals(definedRepository, MlflowTask.getPresetRepository()); Assert.assertEquals(definedRepository, MlflowTask.getPresetRepository());
String definedRepositoryVersion = "dev"; String definedRepositoryVersion = "dev";
when(PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_VERSION_KEY)).thenAnswer(invocation -> definedRepositoryVersion); Mockito.when(PropertyUtils.getString(MlflowConstants.PRESET_REPOSITORY_VERSION_KEY)).thenAnswer(invocation -> definedRepositoryVersion);
Assert.assertEquals(definedRepositoryVersion, MlflowTask.getPresetRepositoryVersion()); Assert.assertEquals(definedRepositoryVersion, MlflowTask.getPresetRepositoryVersion());
} }
......
...@@ -20,13 +20,14 @@ package org.apache.dolphinscheduler.plugin.task.openmldb; ...@@ -20,13 +20,14 @@ package org.apache.dolphinscheduler.plugin.task.openmldb;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.Property; import org.apache.dolphinscheduler.plugin.task.api.model.Property;
import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters; import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters;
import org.apache.dolphinscheduler.spi.utils.JSONUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.powermock.reflect.Whitebox; import org.mockito.Mockito;
public class OpenmldbTaskTest { public class OpenmldbTaskTest {
static class MockOpenmldbTask extends OpenmldbTask { static class MockOpenmldbTask extends OpenmldbTask {
...@@ -59,15 +60,19 @@ public class OpenmldbTaskTest { ...@@ -59,15 +60,19 @@ public class OpenmldbTaskTest {
@Test @Test
public void buildSQLWithComment() throws Exception { public void buildSQLWithComment() throws Exception {
OpenmldbTask openmldbTask = createOpenmldbTask(); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
OpenmldbParameters openmldbParameters = new OpenmldbParameters(); OpenmldbParameters openmldbParameters = new OpenmldbParameters();
openmldbParameters.setExecuteMode("offline"); openmldbParameters.setExecuteMode("offline");
openmldbParameters.setZk("localhost:2181");
openmldbParameters.setZkPath("dolphinscheduler");
String rawSQLScript = "select * from users\r\n" String rawSQLScript = "select * from users\r\n"
+ "-- some comment\n" + "-- some comment\n"
+ "inner join order on users.order_id = order.id; \n\n;" + "inner join order on users.order_id = order.id; \n\n;"
+ "select * from users;"; + "select * from users;";
openmldbParameters.setSql(rawSQLScript); openmldbParameters.setSql(rawSQLScript);
Whitebox.setInternalState(openmldbTask, "openmldbParameters", openmldbParameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(JSONUtils.toJsonString(openmldbParameters));
OpenmldbTask openmldbTask = new OpenmldbTask(taskExecutionContext);
openmldbTask.init();
OpenmldbParameters internal = (OpenmldbParameters) openmldbTask.getParameters(); OpenmldbParameters internal = (OpenmldbParameters) openmldbTask.getParameters();
Assert.assertNotNull(internal); Assert.assertNotNull(internal);
Assert.assertEquals(internal.getExecuteMode(), "offline"); Assert.assertEquals(internal.getExecuteMode(), "offline");
...@@ -75,7 +80,7 @@ public class OpenmldbTaskTest { ...@@ -75,7 +80,7 @@ public class OpenmldbTaskTest {
String result1 = openmldbTask.buildPythonScriptContent(); String result1 = openmldbTask.buildPythonScriptContent();
Assert.assertEquals("import openmldb\n" Assert.assertEquals("import openmldb\n"
+ "import sqlalchemy as db\n" + "import sqlalchemy as db\n"
+ "engine = db.create_engine('openmldb:///?zk=null&zkPath=null')\n" + "engine = db.create_engine('openmldb:///?zk=localhost:2181&zkPath=dolphinscheduler')\n"
+ "con = engine.connect()\n" + "con = engine.connect()\n"
+ "con.execute(\"set @@execute_mode='offline';\")\n" + "con.execute(\"set @@execute_mode='offline';\")\n"
+ "con.execute(\"set @@sync_job=true\")\n" + "con.execute(\"set @@sync_job=true\")\n"
......
...@@ -19,17 +19,13 @@ package org.apache.dolphinscheduler.plugin.task.python; ...@@ -19,17 +19,13 @@ package org.apache.dolphinscheduler.plugin.task.python;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.powermock.reflect.Whitebox;
public class PythonTaskTest { public class PythonTaskTest {
@Test @Test
public void buildPythonExecuteCommand() throws Exception { public void buildPythonExecuteCommand() throws Exception {
PythonTask pythonTask = createPythonTask(); PythonTask pythonTask = createPythonTask();
String methodName = "buildPythonExecuteCommand"; Assert.assertEquals("${PYTHON_HOME} test.py", pythonTask.buildPythonExecuteCommand("test.py"));
String pythonFile = "test.py";
String result1 = Whitebox.invokeMethod(pythonTask, methodName, pythonFile);
Assert.assertEquals("${PYTHON_HOME} test.py", result1);
} }
private PythonTask createPythonTask() { private PythonTask createPythonTask() {
......
...@@ -45,26 +45,14 @@ import org.junit.Before; ...@@ -45,26 +45,14 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito; import org.mockito.junit.MockitoJUnitRunner;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest; @RunWith(MockitoJUnitRunner.class)
import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor;
import org.powermock.modules.junit4.PowerMockRunner;
@RunWith(PowerMockRunner.class)
@PrepareForTest({JSONUtils.class, PropertyUtils.class,})
@PowerMockIgnore({"javax.*"})
@SuppressStaticInitializationFor("org.apache.dolphinscheduler.spi.utils.PropertyUtils")
public class PytorchTaskTest { public class PytorchTaskTest {
private final String pythonPath = "."; private final String pythonPath = ".";
private final String requirementPath = "requirements.txt"; private final String requirementPath = "requirements.txt";
@Before
public void before() {
PowerMockito.mockStatic(PropertyUtils.class);
}
@Test @Test
public void testPythonEnvManager() { public void testPythonEnvManager() {
PythonEnvManager envManager = new PythonEnvManager(); PythonEnvManager envManager = new PythonEnvManager();
...@@ -207,22 +195,7 @@ public class PytorchTaskTest { ...@@ -207,22 +195,7 @@ public class PytorchTaskTest {
public TaskExecutionContext createContext(PytorchParameters pytorchParameters) { public TaskExecutionContext createContext(PytorchParameters pytorchParameters) {
String parameters = JSONUtils.toJsonString(pytorchParameters); String parameters = JSONUtils.toJsonString(pytorchParameters);
TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
Mockito.when(taskExecutionContext.getTaskLogName()).thenReturn("PytorchTest");
String APP_ID = UUID.randomUUID().toString();
String folder = String.format("/tmp/dolphinscheduler_PytorchTest_%s", APP_ID);
Mockito.when(taskExecutionContext.getExecutePath()).thenReturn(folder);
Mockito.when(taskExecutionContext.getTaskAppId()).thenReturn(APP_ID);
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn("root");
Mockito.when(taskExecutionContext.getStartTime()).thenReturn(new Date());
Mockito.when(taskExecutionContext.getTaskTimeout()).thenReturn(10000);
Mockito.when(taskExecutionContext.getLogPath()).thenReturn(folder + "/log");
Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters);
String envirementConfig = "export PATH=$HOME/anaconda3/bin:$PATH\n" + "export PYTHON_HOME=/bin/python";
Mockito.when(taskExecutionContext.getEnvironmentConfig()).thenReturn(envirementConfig);
String userName = System.getenv().get("USER");
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn(userName);
TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext); TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext);
return taskExecutionContext; return taskExecutionContext;
} }
......
...@@ -18,40 +18,29 @@ ...@@ -18,40 +18,29 @@
package org.apache.dolphinscheduler.plugin.task.sagemaker; package org.apache.dolphinscheduler.plugin.task.sagemaker;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.powermock.api.mockito.PowerMockito.mock;
import static org.powermock.api.mockito.PowerMockito.when;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.spi.utils.JSONUtils; import org.apache.dolphinscheduler.spi.utils.JSONUtils;
import org.apache.dolphinscheduler.spi.utils.PropertyUtils;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import java.io.InputStream; import java.io.InputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.mockito.junit.MockitoJUnitRunner;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import com.amazonaws.services.sagemaker.AmazonSageMaker; import com.amazonaws.services.sagemaker.AmazonSageMaker;
import com.amazonaws.services.sagemaker.model.DescribePipelineExecutionResult; import com.amazonaws.services.sagemaker.model.DescribePipelineExecutionResult;
import com.amazonaws.services.sagemaker.model.ListPipelineExecutionStepsResult;
import com.amazonaws.services.sagemaker.model.PipelineExecutionStep;
import com.amazonaws.services.sagemaker.model.StartPipelineExecutionRequest; import com.amazonaws.services.sagemaker.model.StartPipelineExecutionRequest;
import com.amazonaws.services.sagemaker.model.StartPipelineExecutionResult; import com.amazonaws.services.sagemaker.model.StartPipelineExecutionResult;
import com.amazonaws.services.sagemaker.model.StopPipelineExecutionResult; import com.amazonaws.services.sagemaker.model.StopPipelineExecutionResult;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest({JSONUtils.class, PropertyUtils.class,})
@PowerMockIgnore({"javax.*"})
public class SagemakerTaskTest { public class SagemakerTaskTest {
private final String pipelineExecutionArn = "test-pipeline-arn"; private final String pipelineExecutionArn = "test-pipeline-arn";
...@@ -66,34 +55,22 @@ public class SagemakerTaskTest { ...@@ -66,34 +55,22 @@ public class SagemakerTaskTest {
TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters);
client = mock(AmazonSageMaker.class); client = Mockito.mock(AmazonSageMaker.class);
sagemakerTask = new SagemakerTask(taskExecutionContext); sagemakerTask = new SagemakerTask(taskExecutionContext);
sagemakerTask.init(); sagemakerTask.init();
StartPipelineExecutionResult startPipelineExecutionResult = mock(StartPipelineExecutionResult.class); StartPipelineExecutionResult startPipelineExecutionResult = Mockito.mock(StartPipelineExecutionResult.class);
when(startPipelineExecutionResult.getPipelineExecutionArn()).thenReturn(pipelineExecutionArn); Mockito.when(startPipelineExecutionResult.getPipelineExecutionArn()).thenReturn(pipelineExecutionArn);
StopPipelineExecutionResult stopPipelineExecutionResult = mock(StopPipelineExecutionResult.class); StopPipelineExecutionResult stopPipelineExecutionResult = Mockito.mock(StopPipelineExecutionResult.class);
when(stopPipelineExecutionResult.getPipelineExecutionArn()).thenReturn(pipelineExecutionArn); Mockito.when(stopPipelineExecutionResult.getPipelineExecutionArn()).thenReturn(pipelineExecutionArn);
DescribePipelineExecutionResult describePipelineExecutionResult = mock(DescribePipelineExecutionResult.class); DescribePipelineExecutionResult describePipelineExecutionResult = Mockito.mock(DescribePipelineExecutionResult.class);
when(describePipelineExecutionResult.getPipelineExecutionStatus()).thenReturn("Executing", "Succeeded"); Mockito.when(describePipelineExecutionResult.getPipelineExecutionStatus()).thenReturn("Executing", "Succeeded");
ListPipelineExecutionStepsResult listPipelineExecutionStepsResult =
mock(ListPipelineExecutionStepsResult.class);
PipelineExecutionStep pipelineExecutionStep = mock(PipelineExecutionStep.class);
List<PipelineExecutionStep> pipelineExecutionSteps = new ArrayList<>();
pipelineExecutionSteps.add(pipelineExecutionStep);
pipelineExecutionSteps.add(pipelineExecutionStep);
when(pipelineExecutionStep.toString()).thenReturn("Test Step1", "Test Step2");
when(listPipelineExecutionStepsResult.getPipelineExecutionSteps()).thenReturn(pipelineExecutionSteps);
when(client.startPipelineExecution(any())).thenReturn(startPipelineExecutionResult);
when(client.stopPipelineExecution(any())).thenReturn(stopPipelineExecutionResult);
when(client.describePipelineExecution(any())).thenReturn(describePipelineExecutionResult);
when(client.listPipelineExecutionSteps(any())).thenReturn(listPipelineExecutionStepsResult);
Mockito.when(client.startPipelineExecution(any())).thenReturn(startPipelineExecutionResult);
Mockito.when(client.stopPipelineExecution(any())).thenReturn(stopPipelineExecutionResult);
Mockito.when(client.describePipelineExecution(any())).thenReturn(describePipelineExecutionResult);
} }
@Test @Test
......
...@@ -17,9 +17,6 @@ ...@@ -17,9 +17,6 @@
package org.apache.dolphinscheduler.plugin.task.spark; package org.apache.dolphinscheduler.plugin.task.spark;
import static org.powermock.api.mockito.PowerMockito.spy;
import static org.powermock.api.mockito.PowerMockito.when;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.ResourceInfo; import org.apache.dolphinscheduler.plugin.task.api.model.ResourceInfo;
import org.apache.dolphinscheduler.spi.utils.JSONUtils; import org.apache.dolphinscheduler.spi.utils.JSONUtils;
...@@ -29,28 +26,21 @@ import java.util.Collections; ...@@ -29,28 +26,21 @@ import java.util.Collections;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.api.mockito.PowerMockito; import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.mockito.junit.MockitoJUnitRunner;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
@RunWith(PowerMockRunner.class)
@PrepareForTest({
JSONUtils.class
})
@PowerMockIgnore({"javax.*"})
@RunWith(MockitoJUnitRunner.class)
public class SparkTaskTest { public class SparkTaskTest {
@Test @Test
public void testBuildCommandWithSparkSql() throws Exception { public void testBuildCommandWithSparkSql() throws Exception {
String parameters = buildSparkParametersWithSparkSql(); String parameters = buildSparkParametersWithSparkSql();
TaskExecutionContext taskExecutionContext = PowerMockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
when(taskExecutionContext.getTaskParams()).thenReturn(parameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters);
when(taskExecutionContext.getExecutePath()).thenReturn("/tmp"); Mockito.when(taskExecutionContext.getExecutePath()).thenReturn("/tmp");
when(taskExecutionContext.getTaskAppId()).thenReturn("5536"); Mockito.when(taskExecutionContext.getTaskAppId()).thenReturn("5536");
SparkTask sparkTask = spy(new SparkTask(taskExecutionContext)); SparkTask sparkTask = Mockito.spy(new SparkTask(taskExecutionContext));
sparkTask.init(); sparkTask.init();
Assert.assertEquals(sparkTask.buildCommand(), Assert.assertEquals(sparkTask.buildCommand(),
"${SPARK_HOME}/bin/spark-sql " + "${SPARK_HOME}/bin/spark-sql " +
...@@ -68,11 +58,9 @@ public class SparkTaskTest { ...@@ -68,11 +58,9 @@ public class SparkTaskTest {
@Test @Test
public void testBuildCommandWithSparkSubmit() { public void testBuildCommandWithSparkSubmit() {
String parameters = buildSparkParametersWithSparkSubmit(); String parameters = buildSparkParametersWithSparkSubmit();
TaskExecutionContext taskExecutionContext = PowerMockito.mock(TaskExecutionContext.class); TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class);
when(taskExecutionContext.getTaskParams()).thenReturn(parameters); Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters);
when(taskExecutionContext.getExecutePath()).thenReturn("/tmp"); SparkTask sparkTask = Mockito.spy(new SparkTask(taskExecutionContext));
when(taskExecutionContext.getTaskAppId()).thenReturn("5536");
SparkTask sparkTask = spy(new SparkTask(taskExecutionContext));
sparkTask.init(); sparkTask.init();
Assert.assertEquals(sparkTask.buildCommand(), Assert.assertEquals(sparkTask.buildCommand(),
"${SPARK_HOME}/bin/spark-submit " + "${SPARK_HOME}/bin/spark-submit " +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册