未验证 提交 0cc0ee77 编写于 作者: C caishunfeng 提交者: GitHub

[Bug][Master] fix master task failover (#10065)

* fix master task failover

* ui
上级 c1642402
...@@ -52,4 +52,13 @@ public class TaskProcessorFactory { ...@@ -52,4 +52,13 @@ public class TaskProcessorFactory {
return iTaskProcessor.getClass().newInstance(); return iTaskProcessor.getClass().newInstance();
} }
/**
* if match master processor, then this task type is processed on the master
* @param type
* @return
*/
public static boolean isMasterTask(String type) {
return PROCESS_MAP.containsKey(type);
}
} }
...@@ -30,6 +30,7 @@ import org.apache.dolphinscheduler.plugin.task.api.enums.ExecutionStatus; ...@@ -30,6 +30,7 @@ import org.apache.dolphinscheduler.plugin.task.api.enums.ExecutionStatus;
import org.apache.dolphinscheduler.server.builder.TaskExecutionContextBuilder; import org.apache.dolphinscheduler.server.builder.TaskExecutionContextBuilder;
import org.apache.dolphinscheduler.server.master.config.MasterConfig; import org.apache.dolphinscheduler.server.master.config.MasterConfig;
import org.apache.dolphinscheduler.server.master.runner.WorkflowExecuteThreadPool; import org.apache.dolphinscheduler.server.master.runner.WorkflowExecuteThreadPool;
import org.apache.dolphinscheduler.server.master.runner.task.TaskProcessorFactory;
import org.apache.dolphinscheduler.server.utils.ProcessUtils; import org.apache.dolphinscheduler.server.utils.ProcessUtils;
import org.apache.dolphinscheduler.service.process.ProcessService; import org.apache.dolphinscheduler.service.process.ProcessService;
import org.apache.dolphinscheduler.service.registry.RegistryClient; import org.apache.dolphinscheduler.service.registry.RegistryClient;
...@@ -127,7 +128,11 @@ public class FailoverService { ...@@ -127,7 +128,11 @@ public class FailoverService {
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
List<ProcessInstance> needFailoverProcessInstanceList = processService.queryNeedFailoverProcessInstances(masterHost); List<ProcessInstance> needFailoverProcessInstanceList = processService.queryNeedFailoverProcessInstances(masterHost);
LOGGER.info("start master[{}] failover, process list size:{}", masterHost, needFailoverProcessInstanceList.size()); LOGGER.info("start master[{}] failover, process list size:{}", masterHost, needFailoverProcessInstanceList.size());
List<Server> workerServers = registryClient.getServerList(NodeType.WORKER);
// servers need to contains master hosts and worker hosts, otherwise the logic task will failover fail.
List<Server> servers = registryClient.getServerList(NodeType.WORKER);
servers.addAll(registryClient.getServerList(NodeType.MASTER));
for (ProcessInstance processInstance : needFailoverProcessInstanceList) { for (ProcessInstance processInstance : needFailoverProcessInstanceList) {
if (Constants.NULL.equals(processInstance.getHost())) { if (Constants.NULL.equals(processInstance.getHost())) {
continue; continue;
...@@ -136,7 +141,7 @@ public class FailoverService { ...@@ -136,7 +141,7 @@ public class FailoverService {
List<TaskInstance> validTaskInstanceList = processService.findValidTaskListByProcessId(processInstance.getId()); List<TaskInstance> validTaskInstanceList = processService.findValidTaskListByProcessId(processInstance.getId());
for (TaskInstance taskInstance : validTaskInstanceList) { for (TaskInstance taskInstance : validTaskInstanceList) {
LOGGER.info("failover task instance id: {}, process instance id: {}", taskInstance.getId(), taskInstance.getProcessInstanceId()); LOGGER.info("failover task instance id: {}, process instance id: {}", taskInstance.getId(), taskInstance.getProcessInstanceId());
failoverTaskInstance(processInstance, taskInstance, workerServers); failoverTaskInstance(processInstance, taskInstance, servers);
} }
if (serverStartupTime != null && processInstance.getRestartTime() != null if (serverStartupTime != null && processInstance.getRestartTime() != null
...@@ -198,29 +203,37 @@ public class FailoverService { ...@@ -198,29 +203,37 @@ public class FailoverService {
/** /**
* failover task instance * failover task instance
* <p> * <p>
* 1. kill yarn job if there are yarn jobs in tasks. * 1. kill yarn job if run on worker and there are yarn jobs in tasks.
* 2. change task state from running to need failover. * 2. change task state from running to need failover.
* 3. try to notify local master * 3. try to notify local master
* @param processInstance
* @param taskInstance
* @param servers if failover master, servers container master servers and worker servers; if failover worker, servers contain worker servers.
*/ */
private void failoverTaskInstance(ProcessInstance processInstance, TaskInstance taskInstance, List<Server> workerServers) { private void failoverTaskInstance(ProcessInstance processInstance, TaskInstance taskInstance, List<Server> servers) {
if (processInstance == null) { if (processInstance == null) {
LOGGER.error("failover task instance error, processInstance {} of taskInstance {} is null", LOGGER.error("failover task instance error, processInstance {} of taskInstance {} is null",
taskInstance.getProcessInstanceId(), taskInstance.getId()); taskInstance.getProcessInstanceId(), taskInstance.getId());
return; return;
} }
if (!checkTaskInstanceNeedFailover(workerServers, taskInstance)) { if (!checkTaskInstanceNeedFailover(servers, taskInstance)) {
return; return;
} }
boolean isMasterTask = TaskProcessorFactory.isMasterTask(taskInstance.getTaskType());
taskInstance.setProcessInstance(processInstance); taskInstance.setProcessInstance(processInstance);
TaskExecutionContext taskExecutionContext = TaskExecutionContextBuilder.get()
.buildTaskInstanceRelatedInfo(taskInstance) if (!isMasterTask) {
.buildProcessInstanceRelatedInfo(processInstance) TaskExecutionContext taskExecutionContext = TaskExecutionContextBuilder.get()
.create(); .buildTaskInstanceRelatedInfo(taskInstance)
.buildProcessInstanceRelatedInfo(processInstance)
if (masterConfig.isKillYarnJobWhenTaskFailover()) { .create();
// only kill yarn job if exists , the local thread has exited
ProcessUtils.killYarnJob(taskExecutionContext); if (masterConfig.isKillYarnJobWhenTaskFailover()) {
// only kill yarn job if exists , the local thread has exited
ProcessUtils.killYarnJob(taskExecutionContext);
}
} }
taskInstance.setState(ExecutionStatus.NEED_FAULT_TOLERANCE); taskInstance.setState(ExecutionStatus.NEED_FAULT_TOLERANCE);
...@@ -256,13 +269,13 @@ public class FailoverService { ...@@ -256,13 +269,13 @@ public class FailoverService {
} }
/** /**
* task needs failover if task start before worker starts * task needs failover if task start before server starts
* *
* @param workerServers worker servers * @param servers servers, can container master servers or worker servers
* @param taskInstance task instance * @param taskInstance task instance
* @return true if task instance need fail over * @return true if task instance need fail over
*/ */
private boolean checkTaskInstanceNeedFailover(List<Server> workerServers, TaskInstance taskInstance) { private boolean checkTaskInstanceNeedFailover(List<Server> servers, TaskInstance taskInstance) {
boolean taskNeedFailover = true; boolean taskNeedFailover = true;
...@@ -279,14 +292,13 @@ public class FailoverService { ...@@ -279,14 +292,13 @@ public class FailoverService {
return false; return false;
} }
//now no host will execute this task instance,so no need to failover the task //now no host will execute this task instance,so no need to failover the task
if (taskInstance.getHost() == null) { if (taskInstance.getHost() == null) {
return false; return false;
} }
//if task start after worker starts, there is no need to failover the task. //if task start after server starts, there is no need to failover the task.
if (checkTaskAfterWorkerStart(workerServers, taskInstance)) { if (checkTaskAfterServerStart(servers, taskInstance)) {
taskNeedFailover = false; taskNeedFailover = false;
} }
...@@ -296,19 +308,20 @@ public class FailoverService { ...@@ -296,19 +308,20 @@ public class FailoverService {
/** /**
* check task start after the worker server starts. * check task start after the worker server starts.
* *
* @param servers servers, can contain master servers or worker servers
* @param taskInstance task instance * @param taskInstance task instance
* @return true if task instance start time after worker server start date * @return true if task instance start time after server start date
*/ */
private boolean checkTaskAfterWorkerStart(List<Server> workerServers, TaskInstance taskInstance) { private boolean checkTaskAfterServerStart(List<Server> servers, TaskInstance taskInstance) {
if (StringUtils.isEmpty(taskInstance.getHost())) { if (StringUtils.isEmpty(taskInstance.getHost())) {
return false; return false;
} }
Date workerServerStartDate = getServerStartupTime(workerServers, taskInstance.getHost()); Date serverStartDate = getServerStartupTime(servers, taskInstance.getHost());
if (workerServerStartDate != null) { if (serverStartDate != null) {
if (taskInstance.getStartTime() == null) { if (taskInstance.getStartTime() == null) {
return taskInstance.getSubmitTime().after(workerServerStartDate); return taskInstance.getSubmitTime().after(serverStartDate);
} else { } else {
return taskInstance.getStartTime().after(workerServerStartDate); return taskInstance.getStartTime().after(serverStartDate);
} }
} }
return false; return false;
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
package org.apache.dolphinscheduler.server.master.service; package org.apache.dolphinscheduler.server.master.service;
import static org.apache.dolphinscheduler.common.Constants.COMMON_TASK_TYPE;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.TASK_TYPE_DEPENDENT;
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.TASK_TYPE_SWITCH;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doNothing;
...@@ -30,9 +33,11 @@ import org.apache.dolphinscheduler.dao.entity.TaskInstance; ...@@ -30,9 +33,11 @@ import org.apache.dolphinscheduler.dao.entity.TaskInstance;
import org.apache.dolphinscheduler.plugin.task.api.enums.ExecutionStatus; import org.apache.dolphinscheduler.plugin.task.api.enums.ExecutionStatus;
import org.apache.dolphinscheduler.server.master.config.MasterConfig; import org.apache.dolphinscheduler.server.master.config.MasterConfig;
import org.apache.dolphinscheduler.server.master.runner.WorkflowExecuteThreadPool; import org.apache.dolphinscheduler.server.master.runner.WorkflowExecuteThreadPool;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.apache.dolphinscheduler.service.process.ProcessService; import org.apache.dolphinscheduler.service.process.ProcessService;
import org.apache.dolphinscheduler.service.registry.RegistryClient; import org.apache.dolphinscheduler.service.registry.RegistryClient;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Date; import java.util.Date;
...@@ -46,6 +51,7 @@ import org.mockito.Mockito; ...@@ -46,6 +51,7 @@ import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.context.ApplicationContext;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
...@@ -72,22 +78,34 @@ public class FailoverServiceTest { ...@@ -72,22 +78,34 @@ public class FailoverServiceTest {
@Mock @Mock
private WorkflowExecuteThreadPool workflowExecuteThreadPool; private WorkflowExecuteThreadPool workflowExecuteThreadPool;
private String testHost; private static int masterPort = 5678;
private static int workerPort = 1234;
private String testMasterHost;
private String testWorkerHost;
private ProcessInstance processInstance; private ProcessInstance processInstance;
private TaskInstance taskInstance; private TaskInstance masterTaskInstance;
private TaskInstance workerTaskInstance;
@Before @Before
public void before() throws Exception { public void before() throws Exception {
given(masterConfig.getListenPort()).willReturn(8080); // init spring context
ApplicationContext applicationContext = Mockito.mock(ApplicationContext.class);
SpringApplicationContext springApplicationContext = new SpringApplicationContext();
springApplicationContext.setApplicationContext(applicationContext);
given(masterConfig.getListenPort()).willReturn(masterPort);
testHost = failoverService.getLocalAddress(); testMasterHost = failoverService.getLocalAddress();
String ip = testHost.split(":")[0]; String ip = testMasterHost.split(":")[0];
int port = Integer.valueOf(testHost.split(":")[1]); int port = Integer.valueOf(testMasterHost.split(":")[1]);
Assert.assertEquals(8080, port); Assert.assertEquals(masterPort, port);
testWorkerHost = ip + ":" + workerPort;
given(registryClient.getLock(Mockito.anyString())).willReturn(true); given(registryClient.getLock(Mockito.anyString())).willReturn(true);
given(registryClient.releaseLock(Mockito.anyString())).willReturn(true); given(registryClient.releaseLock(Mockito.anyString())).willReturn(true);
given(registryClient.getHostByEventDataPath(Mockito.anyString())).willReturn(testHost); given(registryClient.getHostByEventDataPath(Mockito.anyString())).willReturn(testMasterHost);
given(registryClient.getStoppable()).willReturn(cause -> { given(registryClient.getStoppable()).willReturn(cause -> {
}); });
given(registryClient.checkNodeExists(Mockito.anyString(), Mockito.any())).willReturn(true); given(registryClient.checkNodeExists(Mockito.anyString(), Mockito.any())).willReturn(true);
...@@ -95,30 +113,43 @@ public class FailoverServiceTest { ...@@ -95,30 +113,43 @@ public class FailoverServiceTest {
processInstance = new ProcessInstance(); processInstance = new ProcessInstance();
processInstance.setId(1); processInstance.setId(1);
processInstance.setHost(testHost); processInstance.setHost(testMasterHost);
processInstance.setRestartTime(new Date()); processInstance.setRestartTime(new Date());
processInstance.setHistoryCmd("xxx"); processInstance.setHistoryCmd("xxx");
processInstance.setCommandType(CommandType.STOP); processInstance.setCommandType(CommandType.STOP);
taskInstance = new TaskInstance(); masterTaskInstance = new TaskInstance();
taskInstance.setId(1); masterTaskInstance.setId(1);
taskInstance.setStartTime(new Date()); masterTaskInstance.setStartTime(new Date());
taskInstance.setHost(testHost); masterTaskInstance.setHost(testMasterHost);
masterTaskInstance.setTaskType(TASK_TYPE_SWITCH);
workerTaskInstance = new TaskInstance();
workerTaskInstance.setId(2);
workerTaskInstance.setStartTime(new Date());
workerTaskInstance.setHost(testWorkerHost);
workerTaskInstance.setTaskType(COMMON_TASK_TYPE);
given(processService.queryNeedFailoverTaskInstances(Mockito.anyString())).willReturn(Arrays.asList(taskInstance)); given(processService.queryNeedFailoverTaskInstances(Mockito.anyString())).willReturn(Arrays.asList(masterTaskInstance, workerTaskInstance));
given(processService.queryNeedFailoverProcessInstanceHost()).willReturn(Lists.newArrayList(testHost)); given(processService.queryNeedFailoverProcessInstanceHost()).willReturn(Lists.newArrayList(testMasterHost));
given(processService.queryNeedFailoverProcessInstances(Mockito.anyString())).willReturn(Arrays.asList(processInstance)); given(processService.queryNeedFailoverProcessInstances(Mockito.anyString())).willReturn(Arrays.asList(processInstance));
doNothing().when(processService).processNeedFailoverProcessInstances(Mockito.any(ProcessInstance.class)); doNothing().when(processService).processNeedFailoverProcessInstances(Mockito.any(ProcessInstance.class));
given(processService.findValidTaskListByProcessId(Mockito.anyInt())).willReturn(Lists.newArrayList(taskInstance)); given(processService.findValidTaskListByProcessId(Mockito.anyInt())).willReturn(Lists.newArrayList(masterTaskInstance, workerTaskInstance));
given(processService.findProcessInstanceDetailById(Mockito.anyInt())).willReturn(processInstance); given(processService.findProcessInstanceDetailById(Mockito.anyInt())).willReturn(processInstance);
Thread.sleep(1000); Thread.sleep(1000);
Server server = new Server(); Server masterServer = new Server();
server.setHost(ip); masterServer.setHost(ip);
server.setPort(port); masterServer.setPort(masterPort);
server.setCreateTime(new Date()); masterServer.setCreateTime(new Date());
given(registryClient.getServerList(NodeType.WORKER)).willReturn(Arrays.asList(server));
given(registryClient.getServerList(NodeType.MASTER)).willReturn(Arrays.asList(server)); Server workerServer = new Server();
workerServer.setHost(ip);
workerServer.setPort(workerPort);
workerServer.setCreateTime(new Date());
given(registryClient.getServerList(NodeType.WORKER)).willReturn(new ArrayList<>(Arrays.asList(workerServer)));
given(registryClient.getServerList(NodeType.MASTER)).willReturn(new ArrayList<>(Arrays.asList(masterServer)));
ReflectionTestUtils.setField(failoverService, "registryClient", registryClient); ReflectionTestUtils.setField(failoverService, "registryClient", registryClient);
doNothing().when(workflowExecuteThreadPool).submitStateEvent(Mockito.any(StateEvent.class)); doNothing().when(workflowExecuteThreadPool).submitStateEvent(Mockito.any(StateEvent.class));
...@@ -132,26 +163,26 @@ public class FailoverServiceTest { ...@@ -132,26 +163,26 @@ public class FailoverServiceTest {
@Test @Test
public void failoverMasterTest() { public void failoverMasterTest() {
processInstance.setHost(Constants.NULL); processInstance.setHost(Constants.NULL);
taskInstance.setState(ExecutionStatus.RUNNING_EXECUTION); masterTaskInstance.setState(ExecutionStatus.RUNNING_EXECUTION);
failoverService.failoverServerWhenDown(testHost, NodeType.MASTER); failoverService.failoverServerWhenDown(testMasterHost, NodeType.MASTER);
Assert.assertNotEquals(taskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE); Assert.assertNotEquals(masterTaskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE);
processInstance.setHost(testHost); processInstance.setHost(testMasterHost);
taskInstance.setState(ExecutionStatus.SUCCESS); masterTaskInstance.setState(ExecutionStatus.SUCCESS);
failoverService.failoverServerWhenDown(testHost, NodeType.MASTER); failoverService.failoverServerWhenDown(testMasterHost, NodeType.MASTER);
Assert.assertNotEquals(taskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE); Assert.assertNotEquals(masterTaskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE);
Assert.assertEquals(Constants.NULL, processInstance.getHost()); Assert.assertEquals(Constants.NULL, processInstance.getHost());
processInstance.setHost(testHost); processInstance.setHost(testMasterHost);
taskInstance.setState(ExecutionStatus.RUNNING_EXECUTION); masterTaskInstance.setState(ExecutionStatus.RUNNING_EXECUTION);
failoverService.failoverServerWhenDown(testHost, NodeType.MASTER); failoverService.failoverServerWhenDown(testMasterHost, NodeType.MASTER);
Assert.assertEquals(taskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE); Assert.assertEquals(masterTaskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE);
Assert.assertEquals(Constants.NULL, processInstance.getHost()); Assert.assertEquals(Constants.NULL, processInstance.getHost());
} }
@Test @Test
public void failoverWorkTest() { public void failoverWorkTest() {
failoverService.failoverServerWhenDown(testHost, NodeType.WORKER); failoverService.failoverServerWhenDown(testWorkerHost, NodeType.WORKER);
Assert.assertEquals(taskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE); Assert.assertEquals(workerTaskInstance.getState(), ExecutionStatus.NEED_FAULT_TOLERANCE);
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册