From 7f14f78cac6dfd8730832b268bb853a446f3b57b Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Thu, 28 Apr 2022 19:40:11 +0800
Subject: [PATCH] optimize the pybind in dygraph (#42343)

---
 paddle/fluid/framework/data_transform.cc     |  1 -
 paddle/fluid/imperative/tracer.cc            | 22 ++++++++++-------
 paddle/fluid/pybind/op_function_common.cc    |  3 +++
 paddle/fluid/pybind/op_function_generator.cc | 25 ++++++++++----------
 paddle/phi/core/compat/arg_map_context.h     |  6 ++---
 5 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc
index 63e289af45..99e786d3b0 100644
--- a/paddle/fluid/framework/data_transform.cc
+++ b/paddle/fluid/framework/data_transform.cc
@@ -125,7 +125,6 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
 #ifdef PADDLE_WITH_MKLDNN
     tran_lod_tensor->set_mem_desc(in_lod_tensor.mem_desc());
 #endif
-    tran_lod_tensor->set_layout(in_lod_tensor.layout());
     tran_lod_tensor->ShareDataWith(tensor);
   } else if (in_var.IsType<phi::SelectedRows>()) {
     auto &in_selected_rows = in_var.Get<phi::SelectedRows>();
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 6c31b02550..7b274339e3 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -220,30 +220,34 @@ void Tracer::TraceOpImpl(const std::string& type,
       attr_checker == nullptr ? empty_attrs_map
                               : attr_checker->GetDefaultAttrMap();
 
-  NameVarMap<VarType> new_ins = ins;
+  std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
   if (amp_level_ == AmpLevel::O1) {
     if (amp_dtype_ == phi::DataType::FLOAT16) {
       const auto& tracer = imperative::GetCurrentTracer();
-      new_ins =
-          imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
       VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
-      new_ins = AutoCastInputs<VarType>(type, new_ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          AutoCastInputs<VarType>(type, imperative::AutoTuneLayout<VarType>(
+                                            type, ins, outs, &attrs, tracer)));
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
-      new_ins = AutoCastBF16Inputs<VarType>(type, ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          AutoCastBF16Inputs<VarType>(type, ins));
     }
   } else if (amp_level_ == AmpLevel::O2) {
    if (amp_dtype_ == phi::DataType::FLOAT16) {
       const auto& tracer = imperative::GetCurrentTracer();
-      new_ins =
-          imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
       VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
-      new_ins = CastPureFp16Inputs<VarType>(type, new_ins);
+      ins_amp =
+          std::make_unique<NameVarMap<VarType>>(CastPureFp16Inputs<VarType>(
+              type, imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs,
+                                                        tracer)));
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
-      new_ins = CastPureBf16Inputs<VarType>(type, ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          CastPureBf16Inputs<VarType>(type, ins));
     }
   }
+  const auto& new_ins = ins_amp == nullptr ?
ins : *ins_amp;
 
   try {
     if (platform::is_gpu_place(place)) {
diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index 5eed63d080..0e9c08cff2 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -282,6 +282,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
   std::vector<int> value;
   if (PyList_Check(obj)) {
     Py_ssize_t len = PyList_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PyList_GetItem(obj, i);
@@ -298,6 +299,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
     }
   } else if (PyTuple_Check(obj)) {
     Py_ssize_t len = PyTuple_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PyTuple_GetItem(obj, i);
@@ -314,6 +316,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
     }
   } else if (PySequence_Check(obj)) {
     Py_ssize_t len = PySequence_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PySequence_GetItem(obj, i);
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index 9d5bcfac49..6bbaa147ac 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -81,13 +81,13 @@
 const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
 const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
 
 const char* CAST_VAR_TEMPLATE = R"(
-  auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";
 
 const char* CAST_VAR_LIST_TEMPLATE = R"(
-  auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";
 
 const char* CAST_SIZE_T_TEMPLATE = R"(
-  auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";
 
 const char* ARG_TEMPLATE = R"(const %s& %s)";
@@ -126,16 +126,17 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
   PyThreadState *tstate = nullptr;
   try
   {
+    std::string op_type = "%s";
     platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
 %s
     framework::AttributeMap attrs;
-    ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
+    ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
     tstate = PyEval_SaveThread();
 %s
     imperative::NameVarBaseMap outs = %s;
     imperative::NameVarBaseMap ins = %s;
 %s
-    imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
+    imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
     PyEval_RestoreThread(tstate);
     tstate = nullptr;
 %s
@@ -208,8 +209,8 @@ std::string GenerateOpFunctionsBody(
     const auto in_cast_type =
         input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
     auto dispensable = input.dispensable() ? "true" : "false";
-    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
-                                            in_name, arg_idx++, dispensable);
+    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
+                                            arg_idx++, dispensable);
 
     if (input.dispensable()) {
       const auto in_template = input.duplicable()
@@ -279,8 +280,8 @@ std::string GenerateOpFunctionsBody(
       const auto in_cast_type =
           output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
       auto dispensable = output.dispensable() ?
"true" : "false"; - ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type, - out_name, arg_idx++, dispensable); + ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name, + arg_idx++, dispensable); } else if (use_inplace_strategy && inplace_map.count(out_name)) { PADDLE_ENFORCE_NE( inplace_map[out_name], "", @@ -329,7 +330,7 @@ std::string GenerateOpFunctionsBody( auto dispensable = output.dispensable() ? "true" : "false"; ins_cast_str += - paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type, + paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, out_num_str, arg_idx++, dispensable); } else { outs_initializer += @@ -375,11 +376,11 @@ std::string GenerateOpFunctionsBody( // generate op funtcion body auto op_function_str = paddle::string::Sprintf( - OP_FUNCTION_TEMPLATE, func_name, op_type, ins_cast_str, op_type, + OP_FUNCTION_TEMPLATE, func_name, op_type, op_type, ins_cast_str, input_args_num, inplace_strategy_str, outs_initializer, ins_initializer, ins_initializer_with_null + outs_initializer_with_null + view_strategy_str, - op_type, inplace_mapping_str, return_str); + inplace_mapping_str, return_str); return op_function_str; } diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 0c6fdcb139..f47e8d550e 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -80,9 +80,9 @@ struct KernelSignature { KernelSignature& operator=(KernelSignature&& other) noexcept { name = other.name; - input_names.swap(other.input_names); - attr_names.swap(other.attr_names); - output_names.swap(other.output_names); + input_names = std::move(other.input_names); + attr_names = std::move(other.attr_names); + output_names = std::move(other.output_names); return *this; } }; -- GitLab