未验证 提交 7f14f78c 编写于 作者: Z zyfncg 提交者: GitHub

optimize the pybind in dygraph (#42343)

上级 66f1e82f
...@@ -125,7 +125,6 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor, ...@@ -125,7 +125,6 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
tran_lod_tensor->set_mem_desc(in_lod_tensor.mem_desc()); tran_lod_tensor->set_mem_desc(in_lod_tensor.mem_desc());
#endif #endif
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor); tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<phi::SelectedRows>()) { } else if (in_var.IsType<phi::SelectedRows>()) {
auto &in_selected_rows = in_var.Get<phi::SelectedRows>(); auto &in_selected_rows = in_var.Get<phi::SelectedRows>();
......
...@@ -220,30 +220,34 @@ void Tracer::TraceOpImpl(const std::string& type, ...@@ -220,30 +220,34 @@ void Tracer::TraceOpImpl(const std::string& type,
attr_checker == nullptr ? empty_attrs_map attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap(); : attr_checker->GetDefaultAttrMap();
NameVarMap<VarType> new_ins = ins; std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
if (amp_level_ == AmpLevel::O1) { if (amp_level_ == AmpLevel::O1) {
if (amp_dtype_ == phi::DataType::FLOAT16) { if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer(); const auto& tracer = imperative::GetCurrentTracer();
new_ins =
imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type; VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, new_ins); ins_amp = std::make_unique<NameVarMap<VarType>>(
AutoCastInputs<VarType>(type, imperative::AutoTuneLayout<VarType>(
type, ins, outs, &attrs, tracer)));
} else if (amp_dtype_ == phi::DataType::BFLOAT16) { } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type; VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastBF16Inputs<VarType>(type, ins); ins_amp = std::make_unique<NameVarMap<VarType>>(
AutoCastBF16Inputs<VarType>(type, ins));
} }
} else if (amp_level_ == AmpLevel::O2) { } else if (amp_level_ == AmpLevel::O2) {
if (amp_dtype_ == phi::DataType::FLOAT16) { if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer(); const auto& tracer = imperative::GetCurrentTracer();
new_ins =
imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type; VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, new_ins); ins_amp =
std::make_unique<NameVarMap<VarType>>(CastPureFp16Inputs<VarType>(
type, imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs,
tracer)));
} else if (amp_dtype_ == phi::DataType::BFLOAT16) { } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type; VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureBf16Inputs<VarType>(type, ins); ins_amp = std::make_unique<NameVarMap<VarType>>(
CastPureBf16Inputs<VarType>(type, ins));
} }
} }
const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;
try { try {
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
......
...@@ -282,6 +282,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type, ...@@ -282,6 +282,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
std::vector<int> value; std::vector<int> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
value.reserve(len);
PyObject* item = nullptr; PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
...@@ -298,6 +299,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type, ...@@ -298,6 +299,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
} }
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
value.reserve(len);
PyObject* item = nullptr; PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
...@@ -314,6 +316,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type, ...@@ -314,6 +316,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
} }
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
value.reserve(len);
PyObject* item = nullptr; PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
......
...@@ -81,13 +81,13 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)"; ...@@ -81,13 +81,13 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"( const char* CAST_VAR_TEMPLATE = R"(
auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)"; auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"( const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)"; auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";
const char* CAST_SIZE_T_TEMPLATE = R"( const char* CAST_SIZE_T_TEMPLATE = R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)"; auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)"; const char* ARG_TEMPLATE = R"(const %s& %s)";
...@@ -126,16 +126,17 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs) ...@@ -126,16 +126,17 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
PyThreadState *tstate = nullptr; PyThreadState *tstate = nullptr;
try try
{ {
std::string op_type = "%s";
platform::RecordEvent op_type_record_event("%s pybind_imperative_func"); platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
%s %s
framework::AttributeMap attrs; framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs); ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread(); tstate = PyEval_SaveThread();
%s %s
imperative::NameVarBaseMap outs = %s; imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s; imperative::NameVarBaseMap ins = %s;
%s %s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s}); imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate); PyEval_RestoreThread(tstate);
tstate = nullptr; tstate = nullptr;
%s %s
...@@ -208,8 +209,8 @@ std::string GenerateOpFunctionsBody( ...@@ -208,8 +209,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type, ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
in_name, arg_idx++, dispensable); arg_idx++, dispensable);
if (input.dispensable()) { if (input.dispensable()) {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
...@@ -279,8 +280,8 @@ std::string GenerateOpFunctionsBody( ...@@ -279,8 +280,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type = const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false"; auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type, ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name,
out_name, arg_idx++, dispensable); arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) { } else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
inplace_map[out_name], "", inplace_map[out_name], "",
...@@ -329,7 +330,7 @@ std::string GenerateOpFunctionsBody( ...@@ -329,7 +330,7 @@ std::string GenerateOpFunctionsBody(
auto dispensable = output.dispensable() ? "true" : "false"; auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += ins_cast_str +=
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type, paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str,
out_num_str, arg_idx++, dispensable); out_num_str, arg_idx++, dispensable);
} else { } else {
outs_initializer += outs_initializer +=
...@@ -375,11 +376,11 @@ std::string GenerateOpFunctionsBody( ...@@ -375,11 +376,11 @@ std::string GenerateOpFunctionsBody(
// generate op funtcion body // generate op funtcion body
auto op_function_str = paddle::string::Sprintf( auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, func_name, op_type, ins_cast_str, op_type, OP_FUNCTION_TEMPLATE, func_name, op_type, op_type, ins_cast_str,
input_args_num, inplace_strategy_str, outs_initializer, ins_initializer, input_args_num, inplace_strategy_str, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null + ins_initializer_with_null + outs_initializer_with_null +
view_strategy_str, view_strategy_str,
op_type, inplace_mapping_str, return_str); inplace_mapping_str, return_str);
return op_function_str; return op_function_str;
} }
......
...@@ -80,9 +80,9 @@ struct KernelSignature { ...@@ -80,9 +80,9 @@ struct KernelSignature {
KernelSignature& operator=(KernelSignature&& other) noexcept { KernelSignature& operator=(KernelSignature&& other) noexcept {
name = other.name; name = other.name;
input_names.swap(other.input_names); input_names = std::move(other.input_names);
attr_names.swap(other.attr_names); attr_names = std::move(other.attr_names);
output_names.swap(other.output_names); output_names = std::move(other.output_names);
return *this; return *this;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册