提交 1835ff33 编写于 作者: L lvliang

optimize-the-time-of-producting-cacha-key-in-pynative

上级 0b407dfe
......@@ -49,8 +49,9 @@ enum PynativeStatusCode {
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
PrimitivePyPtr py_primitive;
std::string op_name;
std::string prim_id;
PrimitivePyPtr py_primitive;
AbstractBasePtr abstract;
ValuePtr value = nullptr;
......
......@@ -144,6 +144,7 @@ static std::string GetId(const py::object &obj) {
static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
auto id = GetId(op_exec_info->py_primitive->GetPyObj());
op_exec_info->prim_id = id;
return id;
}
......@@ -306,6 +307,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
auto inst = PynativeExecutor::GetInstance();
if (inst->grad_flag()) {
op_exec_info->value = inst->GetForwardValue(op_exec_info);
} else {
(void)GetOpId(op_exec_info);
}
op_exec_info->op_inputs = args[PY_INPUTS];
ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
......@@ -317,23 +320,21 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(op_exec_info);
std::string graph_info;
// get input tensor info
size_t input_num = op_exec_info->op_inputs.size();
for (size_t index = 0; index < input_num; ++index) {
auto input = op_exec_info->op_inputs[index];
if (py::isinstance<tensor::Tensor>(input)) {
auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
(void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
}
for (const auto &tensor : input_tensors) {
MS_EXCEPTION_IF_NULL(tensor);
auto tensor_shape = tensor->shape();
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
[&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
(void)graph_info.append(std::to_string(tensor->data_type()) + "_");
}
// get prim and abstract info
MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
(void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
op_exec_info->abstract->ToString());
(void)graph_info.append(op_exec_info->prim_id + "_");
// get attr info
auto attr_map = op_exec_info->py_primitive->evaluate_added_attrs();
for (const auto &element : attr_map) {
(void)graph_info.append(element.second->ToString() + " ");
}
const auto &op_prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(op_prim);
const auto &attr_map = op_prim->evaluate_added_attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
return graph_info;
}
......
......@@ -428,7 +428,7 @@ def test_pynative_resnet50():
end_time = time.time()
cost_time = end_time - start_time
print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
if step > 1 and cost_time > 0.32:
if step > 1 and cost_time > 0.21:
exceed_num = exceed_num + 1
assert exceed_num < 10
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册