提交 1835ff33 编写于 作者: L lvliang

optimize-the-time-of-producting-cacha-key-in-pynative

上级 0b407dfe
...@@ -49,8 +49,9 @@ enum PynativeStatusCode { ...@@ -49,8 +49,9 @@ enum PynativeStatusCode {
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo { struct OpExecInfo {
PrimitivePyPtr py_primitive;
std::string op_name; std::string op_name;
std::string prim_id;
PrimitivePyPtr py_primitive;
AbstractBasePtr abstract; AbstractBasePtr abstract;
ValuePtr value = nullptr; ValuePtr value = nullptr;
......
...@@ -144,6 +144,7 @@ static std::string GetId(const py::object &obj) { ...@@ -144,6 +144,7 @@ static std::string GetId(const py::object &obj) {
static std::string GetOpId(const OpExecInfoPtr &op_exec_info) { static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
auto id = GetId(op_exec_info->py_primitive->GetPyObj()); auto id = GetId(op_exec_info->py_primitive->GetPyObj());
op_exec_info->prim_id = id;
return id; return id;
} }
...@@ -306,6 +307,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { ...@@ -306,6 +307,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
auto inst = PynativeExecutor::GetInstance(); auto inst = PynativeExecutor::GetInstance();
if (inst->grad_flag()) { if (inst->grad_flag()) {
op_exec_info->value = inst->GetForwardValue(op_exec_info); op_exec_info->value = inst->GetForwardValue(op_exec_info);
} else {
(void)GetOpId(op_exec_info);
} }
op_exec_info->op_inputs = args[PY_INPUTS]; op_exec_info->op_inputs = args[PY_INPUTS];
ConvertInputs(prim, args[PY_INPUTS], op_exec_info); ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
...@@ -317,23 +320,21 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, ...@@ -317,23 +320,21 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
std::string graph_info; std::string graph_info;
// get input tensor info // get input tensor info
size_t input_num = op_exec_info->op_inputs.size(); for (const auto &tensor : input_tensors) {
for (size_t index = 0; index < input_num; ++index) { MS_EXCEPTION_IF_NULL(tensor);
auto input = op_exec_info->op_inputs[index]; auto tensor_shape = tensor->shape();
if (py::isinstance<tensor::Tensor>(input)) { (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
auto tensor_ptr = py::cast<tensor::TensorPtr>(input); [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
(void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
}
} }
// get prim and abstract info // get prim and abstract info
MS_EXCEPTION_IF_NULL(op_exec_info->abstract); (void)graph_info.append(op_exec_info->prim_id + "_");
(void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
op_exec_info->abstract->ToString());
// get attr info // get attr info
auto attr_map = op_exec_info->py_primitive->evaluate_added_attrs(); const auto &op_prim = op_exec_info->py_primitive;
for (const auto &element : attr_map) { MS_EXCEPTION_IF_NULL(op_prim);
(void)graph_info.append(element.second->ToString() + " "); const auto &attr_map = op_prim->evaluate_added_attrs();
} (void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
return graph_info; return graph_info;
} }
......
...@@ -428,7 +428,7 @@ def test_pynative_resnet50(): ...@@ -428,7 +428,7 @@ def test_pynative_resnet50():
end_time = time.time() end_time = time.time()
cost_time = end_time - start_time cost_time = end_time - start_time
print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
if step > 1 and cost_time > 0.32: if step > 1 and cost_time > 0.21:
exceed_num = exceed_num + 1 exceed_num = exceed_num + 1
assert exceed_num < 10 assert exceed_num < 10
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册