提交 afe04847 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!919 fix pytest failed when one case compile error

Merge pull request !919 from jjfeing/master
......@@ -161,5 +161,12 @@ class CompilerPool:
ret = task_id, "Exception: Not support return type:" + str(ret_type)
return ret
def reset_task_info(self):
"""
reset task info when task compile error
"""
if self.__running_tasks:
self.__running_tasks.clear()
compile_pool = CompilerPool()
......@@ -40,6 +40,7 @@ constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe
constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler";
constexpr auto kStartCompileOp = "start_compile_op";
constexpr auto kWaitOne = "wait_one";
constexpr auto kResetTaskInfo = "reset_task_info";
bool TbeOpParallelBuild(std::vector<AnfNodePtr> anf_nodes) {
auto build_manger = std::make_shared<ParallelBuildManager>();
......@@ -96,6 +97,8 @@ bool TbeOpParallelBuild(std::vector<AnfNodePtr> anf_nodes) {
ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); }
ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); }
int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const {
PyObject *pRes = nullptr;
PyObject *pArgs = PyTuple_New(1);
......@@ -234,5 +237,16 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces);
return kernel_mod_ptr;
}
void ParallelBuildManager::ResetTaskInfo() {
if (task_map_.empty()) {
MS_LOG(INFO) << "All tasks are compiled success.";
return;
}
task_map_.clear();
same_op_list_.clear();
PyObject *pArg = Py_BuildValue("()");
(void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg);
}
} // namespace kernel
} // namespace mindspore
......@@ -40,7 +40,7 @@ struct KernelBuildTaskInfo {
class ParallelBuildManager {
public:
ParallelBuildManager();
~ParallelBuildManager() = default;
~ParallelBuildManager();
int32_t StartCompileOp(const nlohmann::json &kernel_json) const;
void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name,
const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
......@@ -58,6 +58,7 @@ class ParallelBuildManager {
KernelModPtr GenKernelMod(const string &json_name, const string &processor,
const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
const KernelPackPtr &kernel_pack) const;
void ResetTaskInfo();
private:
PyObject *tbe_parallel_compiler_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册