Unverified · Commit 7f14f78c authored by zyfncg, committed by GitHub

optimize the pybind in dygraph (#42343)

Parent 66f1e82f
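The diff touches five spots to cut per-call overhead on the dygraph Python-binding path: a `set_layout` call is dropped from `SetTensorToVariable`; `Tracer::TraceOpImpl` now materializes an AMP-casted copy of its input map only when AMP is enabled, instead of copying unconditionally; `CastPyArg2Ints` pre-reserves its result vector; the generated pybind functions bind the op name once to a local `op_type` string rather than re-expanding it as a literal at every use; and `KernelSignature`'s move assignment replaces `swap` with `std::move`.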
......@@ -125,7 +125,6 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
 #ifdef PADDLE_WITH_MKLDNN
     tran_lod_tensor->set_mem_desc(in_lod_tensor.mem_desc());
 #endif
-    tran_lod_tensor->set_layout(in_lod_tensor.layout());
     tran_lod_tensor->ShareDataWith(tensor);
   } else if (in_var.IsType<phi::SelectedRows>()) {
     auto &in_selected_rows = in_var.Get<phi::SelectedRows>();
......
......@@ -220,30 +220,34 @@ void Tracer::TraceOpImpl(const std::string& type,
       attr_checker == nullptr ? empty_attrs_map
                               : attr_checker->GetDefaultAttrMap();
-  NameVarMap<VarType> new_ins = ins;
+  std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
   if (amp_level_ == AmpLevel::O1) {
     if (amp_dtype_ == phi::DataType::FLOAT16) {
       const auto& tracer = imperative::GetCurrentTracer();
-      new_ins =
-          imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
       VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
-      new_ins = AutoCastInputs<VarType>(type, new_ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          AutoCastInputs<VarType>(type, imperative::AutoTuneLayout<VarType>(
+                                            type, ins, outs, &attrs, tracer)));
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
-      new_ins = AutoCastBF16Inputs<VarType>(type, ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          AutoCastBF16Inputs<VarType>(type, ins));
     }
   } else if (amp_level_ == AmpLevel::O2) {
     if (amp_dtype_ == phi::DataType::FLOAT16) {
       const auto& tracer = imperative::GetCurrentTracer();
-      new_ins =
-          imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
       VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
-      new_ins = CastPureFp16Inputs<VarType>(type, new_ins);
+      ins_amp =
+          std::make_unique<NameVarMap<VarType>>(CastPureFp16Inputs<VarType>(
+              type, imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs,
+                                                        tracer)));
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
-      new_ins = CastPureBf16Inputs<VarType>(type, ins);
+      ins_amp = std::make_unique<NameVarMap<VarType>>(
+          CastPureBf16Inputs<VarType>(type, ins));
     }
   }
+  const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;
   try {
     if (platform::is_gpu_place(place)) {
......
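The point of the `ins_amp` indirection above: `NameVarMap<VarType> new_ins = ins;` used to copy the whole input map on every traced op even when AMP was off, while the new code copies only inside the AMP branches and lets the common path bind a const reference straight to `ins`. A minimal self-contained sketch of the pattern, with hypothetical `Map` and `Transform` standing in for `NameVarMap<VarType>` and the AMP cast helpers:

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>

using Map = std::map<std::string, int>;  // stand-in for NameVarMap<VarType>

// Stand-in for AutoCastInputs and friends: builds a transformed copy.
Map Transform(const Map& in) {
  Map out = in;
  for (auto& kv : out) kv.second *= 2;
  return out;
}

void TraceOpImpl(const Map& ins, bool amp_enabled) {
  // Materialize a copy only when it is actually needed.
  std::unique_ptr<Map> ins_amp;
  if (amp_enabled) {
    ins_amp = std::make_unique<Map>(Transform(ins));
  }
  // Fast path: alias the caller's map, no allocation, no copy.
  const Map& new_ins = ins_amp == nullptr ? ins : *ins_amp;
  std::cout << "X = " << new_ins.at("X") << "\n";
}

int main() {
  Map ins{{"X", 1}};
  TraceOpImpl(ins, false);  // prints X = 1, zero copies
  TraceOpImpl(ins, true);   // prints X = 2, one copy
}
```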
......@@ -282,6 +282,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
   std::vector<int> value;
   if (PyList_Check(obj)) {
     Py_ssize_t len = PyList_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PyList_GetItem(obj, i);
......@@ -298,6 +299,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
     }
   } else if (PyTuple_Check(obj)) {
     Py_ssize_t len = PyTuple_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PyTuple_GetItem(obj, i);
......@@ -314,6 +316,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
     }
   } else if (PySequence_Check(obj)) {
     Py_ssize_t len = PySequence_Size(obj);
+    value.reserve(len);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
       item = PySequence_GetItem(obj, i);
......
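All three branches know `len` before the loop, so the added `value.reserve(len)` lets the vector allocate its buffer once instead of growing geometrically through repeated `push_back` reallocations. The effect in isolation (plain C++, nothing Paddle-specific):

```cpp
#include <cstdio>
#include <vector>

int main() {
  const std::size_t len = 1000;
  std::vector<int> value;
  value.reserve(len);  // single up-front allocation, as in the patch
  std::size_t reallocations = 0;
  std::size_t cap = value.capacity();
  for (std::size_t i = 0; i < len; ++i) {
    value.push_back(static_cast<int>(i));
    if (value.capacity() != cap) {  // capacity change == buffer reallocated
      cap = value.capacity();
      ++reallocations;
    }
  }
  std::printf("reallocations: %zu\n", reallocations);  // 0 with reserve
}
```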
......@@ -81,13 +81,13 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
 const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
 const char* CAST_VAR_TEMPLATE = R"(
-  auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetVarBaseFromArgs(op_type, "%s", args, %d, %s);)";
 const char* CAST_VAR_LIST_TEMPLATE = R"(
-  auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetVarBaseListFromArgs(op_type, "%s", args, %d, %s);)";
 const char* CAST_SIZE_T_TEMPLATE = R"(
-  auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
+  auto %s = GetUnsignedLongFromArgs(op_type, "%s", args, %d, %s);)";
 const char* ARG_TEMPLATE = R"(const %s& %s)";
......@@ -126,16 +126,17 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
   PyThreadState *tstate = nullptr;
   try
   {
+    std::string op_type = "%s";
     platform::RecordEvent op_type_record_event("%s pybind_imperative_func");
     %s
     framework::AttributeMap attrs;
-    ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
+    ConstructAttrMapFromPyArgs(op_type, args, %d, PyTuple_GET_SIZE(args) , attrs);
     tstate = PyEval_SaveThread();
     %s
     imperative::NameVarBaseMap outs = %s;
     imperative::NameVarBaseMap ins = %s;
     %s
-    imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
+    imperative::GetCurrentTracer()->TraceOp(op_type, ins, outs, attrs, {%s});
     PyEval_RestoreThread(tstate);
     tstate = nullptr;
     %s
......@@ -208,8 +209,8 @@ std::string GenerateOpFunctionsBody(
     const auto in_cast_type =
         input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
     auto dispensable = input.dispensable() ? "true" : "false";
-    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
-                                            in_name, arg_idx++, dispensable);
+    ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
+                                            arg_idx++, dispensable);
     if (input.dispensable()) {
       const auto in_template = input.duplicable()
......@@ -279,8 +280,8 @@ std::string GenerateOpFunctionsBody(
       const auto in_cast_type =
           output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
       auto dispensable = output.dispensable() ? "true" : "false";
-      ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
-                                              out_name, arg_idx++, dispensable);
+      ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name,
+                                              arg_idx++, dispensable);
     } else if (use_inplace_strategy && inplace_map.count(out_name)) {
       PADDLE_ENFORCE_NE(
           inplace_map[out_name], "",
......@@ -329,7 +330,7 @@ std::string GenerateOpFunctionsBody(
       auto dispensable = output.dispensable() ? "true" : "false";
       ins_cast_str +=
-          paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
+          paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str,
                                   out_num_str, arg_idx++, dispensable);
     } else {
       outs_initializer +=
......@@ -375,11 +376,11 @@ std::string GenerateOpFunctionsBody(
   // generate op function body
   auto op_function_str = paddle::string::Sprintf(
-      OP_FUNCTION_TEMPLATE, func_name, op_type, ins_cast_str, op_type,
+      OP_FUNCTION_TEMPLATE, func_name, op_type, op_type, ins_cast_str,
       input_args_num, inplace_strategy_str, outs_initializer, ins_initializer,
       ins_initializer_with_null + outs_initializer_with_null +
           view_strategy_str,
-      op_type, inplace_mapping_str, return_str);
+      inplace_mapping_str, return_str);
   return op_function_str;
 }
......
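All the generator changes above feed the one new template line, `std::string op_type = "%s";`: the op name is built into a `std::string` once per generated call and then passed by variable to helpers such as `GetVarBaseFromArgs`, `ConstructAttrMapFromPyArgs`, and `TraceOp`, whereas before each helper call converted its own `%s`-expanded string literal into a fresh temporary `std::string`. A minimal sketch of that saving, with `Helper` as a hypothetical stand-in for those `const std::string&` helpers:

```cpp
#include <iostream>
#include <string>

// Hypothetical stand-in for GetVarBaseFromArgs / ConstructAttrMapFromPyArgs /
// TraceOp, which take the op name as const std::string&.
void Helper(const std::string& op_type) {
  std::cout << "dispatch " << op_type << "\n";
}

int main() {
  // Before: every call site built a fresh temporary std::string
  // from the expanded literal.
  Helper("depthwise_conv2d_transpose");
  Helper("depthwise_conv2d_transpose");

  // After: the generated body constructs the string once and reuses it.
  std::string op_type = "depthwise_conv2d_transpose";
  Helper(op_type);
  Helper(op_type);
}
```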
......@@ -80,9 +80,9 @@ struct KernelSignature {
   KernelSignature& operator=(KernelSignature&& other) noexcept {
     name = other.name;
-    input_names.swap(other.input_names);
-    attr_names.swap(other.attr_names);
-    output_names.swap(other.output_names);
+    input_names = std::move(other.input_names);
+    attr_names = std::move(other.attr_names);
+    output_names = std::move(other.output_names);
     return *this;
   }
 };
......
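On the `KernelSignature` change: `swap` parks the left-hand side's old name lists inside `other`, keeping that storage alive until `other` is destroyed; direct move assignment releases the old storage immediately and states the intent. A minimal sketch of the same member-by-member pattern, using `std::vector<std::string>` where the real struct uses its own container type:

```cpp
#include <string>
#include <utility>
#include <vector>

struct Signature {
  const char* name;
  std::vector<std::string> input_names;
  std::vector<std::string> attr_names;
  std::vector<std::string> output_names;

  Signature& operator=(Signature&& other) noexcept {
    name = other.name;
    // swap() would leave *this's old vectors inside `other`;
    // move assignment frees them here and now.
    input_names = std::move(other.input_names);
    attr_names = std::move(other.attr_names);
    output_names = std::move(other.output_names);
    return *this;
  }
};

int main() {
  Signature a{"matmul", {"X", "Y"}, {"trans_x"}, {"Out"}};
  Signature b{"relu", {"Input"}, {}, {"Out"}};
  a = std::move(b);  // a's old {"X","Y"} buffers are released immediately
}
```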