From 1835ff33c4585f3abedc6dadd952d4f480297bbf Mon Sep 17 00:00:00 2001 From: lvliang Date: Fri, 24 Jul 2020 16:11:03 +0800 Subject: [PATCH] optimize-the-time-of-producting-cacha-key-in-pynative --- mindspore/ccsrc/pipeline/pynative/base.h | 3 +- .../pipeline/pynative/pynative_execute.cc | 29 ++++++++++--------- tests/st/pynative/test_pynative_resnet50.py | 2 +- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index bccb7bd9c..2fa238f2b 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -49,8 +49,9 @@ enum PynativeStatusCode { enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; struct OpExecInfo { - PrimitivePyPtr py_primitive; std::string op_name; + std::string prim_id; + PrimitivePyPtr py_primitive; AbstractBasePtr abstract; ValuePtr value = nullptr; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 7328b3b78..a7bff65af 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -144,6 +144,7 @@ static std::string GetId(const py::object &obj) { static std::string GetOpId(const OpExecInfoPtr &op_exec_info) { auto id = GetId(op_exec_info->py_primitive->GetPyObj()); + op_exec_info->prim_id = id; return id; } @@ -306,6 +307,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { auto inst = PynativeExecutor::GetInstance(); if (inst->grad_flag()) { op_exec_info->value = inst->GetForwardValue(op_exec_info); + } else { + (void)GetOpId(op_exec_info); } op_exec_info->op_inputs = args[PY_INPUTS]; ConvertInputs(prim, args[PY_INPUTS], op_exec_info); @@ -317,23 +320,21 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, MS_EXCEPTION_IF_NULL(op_exec_info); std::string graph_info; // get input tensor info - size_t input_num = op_exec_info->op_inputs.size(); - for (size_t index = 0; index < input_num; ++index) { - auto input = op_exec_info->op_inputs[index]; - if (py::isinstance(input)) { - auto tensor_ptr = py::cast(input); - (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); - } + for (const auto &tensor : input_tensors) { + MS_EXCEPTION_IF_NULL(tensor); + auto tensor_shape = tensor->shape(); + (void)std::for_each(tensor_shape.begin(), tensor_shape.end(), + [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); }); + (void)graph_info.append(std::to_string(tensor->data_type()) + "_"); } // get prim and abstract info - MS_EXCEPTION_IF_NULL(op_exec_info->abstract); - (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + - op_exec_info->abstract->ToString()); + (void)graph_info.append(op_exec_info->prim_id + "_"); // get attr info - auto attr_map = op_exec_info->py_primitive->evaluate_added_attrs(); - for (const auto &element : attr_map) { - (void)graph_info.append(element.second->ToString() + " "); - } + const auto &op_prim = op_exec_info->py_primitive; + MS_EXCEPTION_IF_NULL(op_prim); + const auto &attr_map = op_prim->evaluate_added_attrs(); + (void)std::for_each(attr_map.begin(), attr_map.end(), + [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); }); return graph_info; } diff --git a/tests/st/pynative/test_pynative_resnet50.py b/tests/st/pynative/test_pynative_resnet50.py index 23b86c724..720dad3ec 100644 --- a/tests/st/pynative/test_pynative_resnet50.py +++ b/tests/st/pynative/test_pynative_resnet50.py @@ -428,7 +428,7 @@ def test_pynative_resnet50(): end_time = time.time() cost_time = end_time - start_time print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) - if step > 1 and cost_time > 0.32: + if step > 1 and cost_time > 0.21: exceed_num = exceed_num + 1 assert exceed_num < 10 \ No newline at end of file -- GitLab