提交 c2ab8d6c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5525 check worker thread exception

Merge pull request !5525 from kisnwang/async-run-graph
......@@ -110,6 +110,12 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) {
worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
// Rethrow (and clear) an exception captured by the worker thread.
//
// WorkerLoop() stores std::current_exception() into exception_ptr_ when a
// task's Run() throws; the caller thread invokes this before queuing a new
// task and after waking from the completion condition variable, so the
// worker's exception propagates to the user instead of being lost.
//
// The pointer is reset before rethrowing: otherwise the first failure would
// poison the executor permanently, making every later Async call rethrow the
// same stale exception even though the worker thread is still healthy.
//
// NOTE(review): exception_ptr_ is written by the worker thread and read here
// without a lock — this assumes the caller only checks after the cond-var
// wait established a happens-before edge; TODO confirm that ordering.
void Executor::CheckException() {
  if (exception_ptr_ != nullptr) {
    auto exception_ptr = exception_ptr_;
    exception_ptr_ = nullptr;
    std::rethrow_exception(exception_ptr);
  }
}
void Executor::WorkerJoin() {
StopWorker();
worker_->join();
......@@ -128,7 +134,11 @@ void Executor::WorkerLoop() {
OnWorkerExit();
return;
}
task->Run();
try {
task->Run();
} catch (const std::exception &e) {
exception_ptr_ = std::current_exception();
}
if (task->type_ == kCompileNodes) {
compile_cond_var_.notify_all();
} else if (task->type_ == kCompileGraph) {
......@@ -183,6 +193,7 @@ bool Executor::IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs) {
GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst,
const AnfNodePtrList &outputs) {
CheckException();
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<CompileNodesTask>();
task->session_ = session;
......@@ -191,10 +202,12 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
ready_tasks_.push(task);
task_cond_var_.notify_all();
compile_cond_var_.wait(lock);
CheckException();
return task->graph_id_;
}
GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
CheckException();
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<CompileGraphTask>();
task->session_ = session;
......@@ -202,10 +215,12 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph
ready_tasks_.push(task);
task_cond_var_.notify_all();
compile_cond_var_.wait(lock);
CheckException();
return task->graph_id_;
}
void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
CheckException();
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<BuildGraphTask>();
task->session_ = session;
......@@ -213,10 +228,12 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
ready_tasks_.push(task);
task_cond_var_.notify_all();
build_cond_var_.wait(lock);
CheckException();
}
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
CheckException();
auto task = std::make_shared<RunGraphTask>();
task->session_ = session;
task->graph_id_ = graph_id;
......@@ -237,10 +254,12 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
task_cond_var_.notify_all();
py::gil_scoped_release release;
run_cond_var_.wait(lock);
CheckException();
}
void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
CheckException();
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<BuildOpTask>();
task->session_ = session;
......@@ -251,10 +270,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
ready_tasks_.push(task);
task_cond_var_.notify_all();
build_op_cond_var_.wait(lock);
CheckException();
}
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
CheckException();
std::unique_lock<std::mutex> lock(task_mutex_);
auto task = std::make_shared<RunOpTask>();
task->session_ = session;
......@@ -264,6 +285,7 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
ready_tasks_.push(task);
task_cond_var_.notify_all();
run_op_cond_var_.wait(lock);
CheckException();
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
......
......@@ -26,6 +26,7 @@
#include <thread>
#include <mutex>
#include <condition_variable>
#include <exception>
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "ir/tensor.h"
......@@ -128,11 +129,12 @@ class Executor {
const std::vector<tensor::TensorPtr> &input_tensors);
void OnRunGraphFinished();
protected:
private:
void UpdateOutputTensors(VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
bool IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs);
void CheckException();
void StopWorker();
void OnWorkerExit();
......@@ -149,6 +151,7 @@ class Executor {
std::queue<std::shared_ptr<Task>> ready_tasks_;
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
std::shared_ptr<std::thread> worker_;
std::exception_ptr exception_ptr_{nullptr};
};
} // namespace session
} // namespace mindspore
......
......@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_
#include <set>
#include <map>
#include <string>
......@@ -42,4 +42,4 @@ class ExecutorManager {
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANGER_H_
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册