Unverified commit 046553c7, authored by Chen Weihang, committed by GitHub

Support setting infershape function for custom grad op (#38776)

* unify infer_shape func calling

* support set grad infer shape fn for custom op

* unify infershape in new executor and eager

* remove todo comment

* revert infershape in operator
Parent: cd2855b0
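In short, the user-facing change: a custom grad op can now carry its own InferShapeFn, which is wired into the framework through the shared RunInferShapeFunc helper instead of being rejected at registration time. A minimal sketch, condensed from the custom_relu_op.cc test added in this diff (only the grad-op half is shown; ReluForward and ReluBackwardWithoutX are the kernel functions defined there):

#include "paddle/extension.h"

// Out and X have the same shape for relu, so X@GRAD can be inferred from Out
// even though the grad op never receives X as an input.
std::vector<std::vector<int64_t>> ReluBackwardWithoutXInferShape(
    const std::vector<int64_t>& out_shape,
    const std::vector<int64_t>& grad_out_shape) {
  return {out_shape};
}

PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
    .Inputs({"Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
    // Previously this call raised an Unimplemented error for grad ops;
    // with this commit it is dispatched through RunInferShapeFunc.
    .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));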
@@ -174,8 +174,7 @@ static void PreparedOpRunImpl(
   EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
                                          op.Type());
-  static_cast<const paddle::framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);
   func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
                              default_attrs));
...
@@ -94,7 +94,7 @@ std::vector<std::string> ParseAttrStr(const std::string& attr) {
   // 2. type
   rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
-  VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
+  VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
   return rlt;
 }
@@ -109,11 +109,11 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
                           const std::vector<std::string>& inputs,
                           const std::vector<std::string>& outputs,
                           const std::vector<std::string>& attrs) {
-  VLOG(1) << "Custom Operator: Start run KernelFunc.";
+  VLOG(3) << "Custom Operator: Start run KernelFunc.";
   std::vector<paddle::experimental::Tensor> custom_ins;
   std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
   for (auto& in_name : inputs) {
-    VLOG(1) << "Custom Operator: input name - " << in_name;
+    VLOG(3) << "Custom Operator: input name - " << in_name;
     if (detail::IsDuplicableVar(in_name)) {
       // return const std::vector<const Tensor*>
       auto vec_x = ctx.MultiInput<Tensor>(in_name);
@@ -185,11 +185,11 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
     }
   }
-  VLOG(1) << "Custom Operator: Run ComputeFunc.";
+  VLOG(3) << "Custom Operator: Run ComputeFunc.";
   try {
     auto outs = func(custom_ins, custom_vec_ins, custom_attrs);
-    VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
+    VLOG(3) << "Custom Operator: Share outputs into ExecutionContext.";
     for (size_t i = 0; i < outputs.size(); ++i) {
       auto out_name = outputs[i];
       if (detail::IsDuplicableVar(out_name)) {
@@ -230,6 +230,95 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
   }
 }

+static void RunInferShapeFunc(framework::InferShapeContext* ctx,
+                              const paddle::InferShapeFunc& func,
+                              const std::vector<std::string>& inputs,
+                              const std::vector<std::string>& outputs,
+                              const std::vector<std::string>& attrs) {
+  std::vector<std::vector<int64_t>> input_shapes;
+  std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
+
+  VLOG(3) << "Custom Operator: InferShape - get input ddim.";
+  for (auto& in_name : inputs) {
+    if (detail::IsDuplicableVar(in_name)) {
+      OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom");
+      auto vec_ddim = ctx->GetInputsDim(in_name);
+      std::vector<std::vector<int64_t>> vec_shape;
+      vec_shape.reserve(vec_ddim.size());
+      std::transform(vec_ddim.begin(), vec_ddim.end(),
+                     std::back_inserter(vec_shape),
+                     [&](const DDim& ddim) -> std::vector<int64_t> {
+                       return framework::vectorize(ddim);
+                     });
+      vec_input_shapes.emplace_back(vec_shape);
+    } else {
+      OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
+      auto ddim = ctx->GetInputDim(in_name);
+      input_shapes.emplace_back(framework::vectorize(ddim));
+    }
+  }
+
+  std::vector<paddle::any> custom_attrs;
+  for (auto& attr_str : attrs) {
+    auto attr_name_and_type = detail::ParseAttrStr(attr_str);
+    auto attr_name = attr_name_and_type[0];
+    auto attr_type_str = attr_name_and_type[1];
+    if (attr_type_str == "bool") {
+      custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
+    } else if (attr_type_str == "int") {
+      custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
+    } else if (attr_type_str == "float") {
+      custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
+    } else if (attr_type_str == "int64_t") {
+      custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
+    } else if (attr_type_str == "std::string") {
+      custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
+    } else if (attr_type_str == "std::vector<int>") {
+      custom_attrs.emplace_back(ctx->Attrs().Get<std::vector<int>>(attr_name));
+    } else if (attr_type_str == "std::vector<float>") {
+      custom_attrs.emplace_back(
+          ctx->Attrs().Get<std::vector<float>>(attr_name));
+    } else if (attr_type_str == "std::vector<int64_t>") {
+      // NOTE(chenweihang): InferShape can't support std::vector<int64_t>
+      // attr type, because the input type is std::vector<int64_t>, only
+      // can use one rule to parse std::vector<int64_t> parameter
+      continue;
+    } else if (attr_type_str == "std::vector<std::string>") {
+      custom_attrs.emplace_back(
+          ctx->Attrs().Get<std::vector<std::string>>(attr_name));
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported `%s` type value as custom attribute now. "
+          "Supported data types include `bool`, `int`, `float`, "
+          "`int64_t`, `std::string`, `std::vector<int>`, "
+          "`std::vector<float>`, `std::vector<std::string>`, "
+          "Please check whether the attribute data type and "
+          "data type string are matched.",
+          attr_type_str));
+    }
+  }
+
+  VLOG(3) << "Custom Operator: InferShape - calc output ddim.";
+  auto output_shapes = func(input_shapes, vec_input_shapes, custom_attrs);
+
+  VLOG(3) << "Custom Operator: InferShape - set output ddim.";
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    auto out_name = outputs[i];
+    if (detail::IsDuplicableVar(out_name)) {
+      std::vector<DDim> vec_ddim;
+      vec_ddim.reserve(output_shapes.size());
+      std::transform(output_shapes.begin(), output_shapes.end(),
+                     std::back_inserter(vec_ddim),
+                     [&](const std::vector<int64_t>& shape) -> DDim {
+                       return framework::make_ddim(shape);
+                     });
+      ctx->SetOutputsDim(out_name, vec_ddim);
+    } else {
+      ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i]));
+    }
+  }
+}
+
 //////////////////// Operator Define /////////////////
 class CustomOperator : public OperatorWithKernel {
@@ -239,7 +328,7 @@ class CustomOperator : public OperatorWithKernel {
   // Dummy infershape
   // Because it is a pure virtual function, it must be implemented
   void InferShape(framework::InferShapeContext* ctx) const override {
-    VLOG(1) << "Custom Operator: Dummy infer shape of custom operator.";
+    VLOG(3) << "Custom Operator: Dummy infer shape of custom operator.";
   }

   /**
@@ -381,7 +470,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
     auto fwd_op_outputs = this->OutputNames();
     for (auto& in_name : inputs_) {
-      VLOG(1) << "Custom Operator: GradOpDescMaker - input: " << in_name;
+      VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
       if (!detail::IsGradVar(in_name)) {
         if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
           grad_op->SetInput(in_name, this->Input(in_name));
@@ -398,7 +487,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
       }
     }
     for (auto& out_name : outputs_) {
-      VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
+      VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
       if (detail::IsDuplicableVar(out_name)) {
         grad_op->SetOutput(out_name,
                            this->InputGrad(detail::NoGrad(out_name),
@@ -447,7 +536,7 @@ class CustomGradOpMaker<imperative::OpBase>
     auto fwd_op_outputs = this->OutputNames();
     for (auto& in_name : inputs_) {
-      VLOG(1) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
+      VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
       if (!detail::IsGradVar(in_name)) {
         if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
           grad_op->SetInput(in_name, this->Input(in_name));
@@ -464,7 +553,7 @@ class CustomGradOpMaker<imperative::OpBase>
       }
     }
     for (auto& out_name : outputs_) {
-      VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
+      VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
       grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
     }
     grad_op->SetAttrMap(this->Attrs());
@@ -486,11 +575,11 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
                                      const std::vector<std::string>& outputs,
                                      const std::vector<std::string>& attrs) {
   OpKernelType key(type, experimental::ConvertExtPlaceToInnerPlace(place));
-  VLOG(1) << "Custom Operator: op kernel key: " << key;
+  VLOG(3) << "Custom Operator: op kernel key: " << key;
   OperatorWithKernel::AllOpKernels()[name][key] =
       [kernel_func, inputs, outputs,
        attrs](const framework::ExecutionContext& ctx) {
-        VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
+        VLOG(3) << "Custom Operator: run custom kernel func in lambda.";
         RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
       };
 }
@@ -500,7 +589,7 @@ void RegisterOperatorKernel(const std::string& name,
                             const std::vector<std::string>& inputs,
                             const std::vector<std::string>& outputs,
                             const std::vector<std::string>& attrs) {
-  VLOG(1) << "Custom Operator: op name in kernel: " << name;
+  VLOG(3) << "Custom Operator: op name in kernel: " << name;
   // NOTE [ Dummy Op Kernel Key ]
   // TODO(chenweihang): Because execute engine need get device context based
   // op_kernel_key.place_, so we should register kernel for each
@@ -535,12 +624,12 @@ void RegisterOperatorWithMetaInfo(
   auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta);
   auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta);
-  VLOG(1) << "Custom Operator: forward, op name: " << op_name;
-  VLOG(1) << "Custom Operator: forward, op inputs: "
+  VLOG(3) << "Custom Operator: forward, op name: " << op_name;
+  VLOG(3) << "Custom Operator: forward, op inputs: "
           << string::join_strings(op_inputs, ',');
-  VLOG(1) << "Custom Operator: forward, op outputs: "
+  VLOG(3) << "Custom Operator: forward, op outputs: "
           << string::join_strings(op_outputs, ',');
-  VLOG(1) << "Custom Operator: forward, op attrs: "
+  VLOG(3) << "Custom Operator: forward, op attrs: "
           << string::join_strings(op_attrs, ',');
   // Op
@@ -588,96 +677,13 @@ void RegisterOperatorWithMetaInfo(
               "Please set the InferShapeFn of custom "
               "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
-      VLOG(1) << "Custom Operator: Default InferShape - share ddim.";
+      VLOG(3) << "Custom Operator: Default InferShape - share ddim.";
       ctx->ShareDim(op_inputs[0], op_outputs[0]);
     };
   } else {
     info.infer_shape_ = [op_inputs, op_outputs, op_attrs,
                          infer_shape_func](InferShapeContext* ctx) {
-      std::vector<std::vector<int64_t>> input_shapes;
-      std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
-
-      VLOG(1) << "Custom Operator: InferShape - get input ddim.";
-      for (auto& in_name : op_inputs) {
-        if (detail::IsDuplicableVar(in_name)) {
-          OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom");
-          auto vec_ddim = ctx->GetInputsDim(in_name);
-          std::vector<std::vector<int64_t>> vec_shape;
-          vec_shape.reserve(vec_ddim.size());
-          std::transform(vec_ddim.begin(), vec_ddim.end(),
-                         std::back_inserter(vec_shape),
-                         [&](const DDim& ddim) -> std::vector<int64_t> {
-                           return framework::vectorize(ddim);
-                         });
-          vec_input_shapes.emplace_back(vec_shape);
-        } else {
-          OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
-          auto ddim = ctx->GetInputDim(in_name);
-          input_shapes.emplace_back(framework::vectorize(ddim));
-        }
-      }
-
-      std::vector<paddle::any> custom_attrs;
-      for (auto& attr_str : op_attrs) {
-        auto attr_name_and_type = detail::ParseAttrStr(attr_str);
-        auto attr_name = attr_name_and_type[0];
-        auto attr_type_str = attr_name_and_type[1];
-        if (attr_type_str == "bool") {
-          custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
-        } else if (attr_type_str == "int") {
-          custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
-        } else if (attr_type_str == "float") {
-          custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
-        } else if (attr_type_str == "int64_t") {
-          custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
-        } else if (attr_type_str == "std::string") {
-          custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
-        } else if (attr_type_str == "std::vector<int>") {
-          custom_attrs.emplace_back(
-              ctx->Attrs().Get<std::vector<int>>(attr_name));
-        } else if (attr_type_str == "std::vector<float>") {
-          custom_attrs.emplace_back(
-              ctx->Attrs().Get<std::vector<float>>(attr_name));
-        } else if (attr_type_str == "std::vector<int64_t>") {
-          // NOTE(chenweihang): InferShape can't support std::vector<int64_t>
-          // attr type, because the input type is std::vector<int64_t>, only
-          // can use one rule to parse std::vector<int64_t> parameter
-          continue;
-        } else if (attr_type_str == "std::vector<std::string>") {
-          custom_attrs.emplace_back(
-              ctx->Attrs().Get<std::vector<std::string>>(attr_name));
-        } else {
-          PADDLE_THROW(platform::errors::Unimplemented(
-              "Unsupported `%s` type value as custom attribute now. "
-              "Supported data types include `bool`, `int`, `float`, "
-              "`int64_t`, `std::string`, `std::vector<int>`, "
-              "`std::vector<float>`, `std::vector<std::string>`, "
-              "Please check whether the attribute data type and "
-              "data type string are matched.",
-              attr_type_str));
-        }
-      }
-
-      VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
-      auto output_shapes =
-          infer_shape_func(input_shapes, vec_input_shapes, custom_attrs);
-
-      VLOG(1) << "Custom Operator: InferShape - set output ddim.";
-      for (size_t i = 0; i < op_outputs.size(); ++i) {
-        auto out_name = op_outputs[i];
-        if (detail::IsDuplicableVar(out_name)) {
-          std::vector<DDim> vec_ddim;
-          vec_ddim.reserve(output_shapes.size());
-          std::transform(output_shapes.begin(), output_shapes.end(),
-                         std::back_inserter(vec_ddim),
-                         [&](const std::vector<int64_t>& shape) -> DDim {
-                           return framework::make_ddim(shape);
-                         });
-          ctx->SetOutputsDim(out_name, vec_ddim);
-        } else {
-          ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i]));
-        }
-      }
+      RunInferShapeFunc(ctx, infer_shape_func, op_inputs, op_outputs, op_attrs);
     };
   }
@@ -706,7 +712,7 @@ void RegisterOperatorWithMetaInfo(
               "Please set the InferDtypeFn of custom "
               "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
-      VLOG(1) << "Custom Operator: InferDtype - share dtype.";
+      VLOG(3) << "Custom Operator: InferDtype - share dtype.";
       auto dtype = ctx->GetInputDataType(op_inputs[0]);
       ctx->SetOutputDataType(op_outputs[0], dtype);
     };
@@ -716,7 +722,7 @@ void RegisterOperatorWithMetaInfo(
       std::vector<DataType> input_dtypes;
       std::vector<std::vector<DataType>> vec_input_dtypes;
-      VLOG(1) << "Custom Operator: InferDtype - get input dtype.";
+      VLOG(3) << "Custom Operator: InferDtype - get input dtype.";
       for (auto& in_name : op_inputs) {
         if (detail::IsDuplicableVar(in_name)) {
           std::vector<DataType> vec_custom_dtype;
@@ -731,10 +737,10 @@ void RegisterOperatorWithMetaInfo(
         }
       }
-      VLOG(1) << "Custom Operator: InferDtype - infer output dtype.";
+      VLOG(3) << "Custom Operator: InferDtype - infer output dtype.";
       auto output_dtypes = infer_dtype_func(input_dtypes, vec_input_dtypes);
-      VLOG(1) << "Custom Operator: InferDtype - set output dtype.";
+      VLOG(3) << "Custom Operator: InferDtype - set output dtype.";
       for (size_t i = 0; i < op_outputs.size(); ++i) {
         auto out_name = op_outputs[i];
         if (detail::IsDuplicableVar(out_name)) {
@@ -763,11 +769,12 @@ void RegisterOperatorWithMetaInfo(
     auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
     auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
     auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
+    auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);
-    VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
-    VLOG(1) << "Custom Operator: backward, op inputs: "
+    VLOG(3) << "Custom Operator: backward, op name: " << grad_op_name;
+    VLOG(3) << "Custom Operator: backward, op inputs: "
             << string::join_strings(grad_op_inputs, ',');
-    VLOG(1) << "Custom Operator: backward, op outputs: "
+    VLOG(3) << "Custom Operator: backward, op outputs: "
             << string::join_strings(grad_op_outputs, ',');
     // GradOpDescMaker
@@ -809,40 +816,52 @@ void RegisterOperatorWithMetaInfo(
     };

     // Grad InferShape
-    grad_info.infer_shape_ = [grad_op_inputs,
-                              grad_op_outputs](InferShapeContext* ctx) {
-      // 1. if forward input exists, gradient's shape is same with forward input
-      // default
-      //    [Suitable for most situations]
-      // 2. if forward input not exists, and only contains one grad input and
-      // output,
-      //    use grad input shape as grad output shape
-      //    [Suitable for the situation that forward input is not used as
-      //    backward input]
-      // TODO(chenweihang): support set grad op infershape func if needed
-      for (auto& out_name : grad_op_outputs) {
-        auto fwd_name = detail::NoGrad(out_name);
-        if (detail::IsDuplicableVar(fwd_name)) {
-          // Duplicable forward var must as backward input
-          ctx->ShareDim(fwd_name, out_name);
-        } else {
-          if (ctx->HasInput(fwd_name)) {
-            ctx->ShareDim(fwd_name, out_name);
-          } else {
-            PADDLE_ENFORCE_EQ(
-                grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
-                true,
-                platform::errors::Unavailable(
-                    "Custom grad operator infershape error. "
-                    "If a custom grad operator contains only one input and "
-                    "only one output, the input shape will be directly set to "
-                    "the output shape. Otherwise, Please set the forward input "
-                    "as the grad operator's input."));
-            ctx->ShareDim(grad_op_inputs[0], out_name);
-          }
-        }
-      }
-    };
+    if (grad_infer_shape_fn == nullptr) {
+      grad_info.infer_shape_ = [grad_op_inputs,
+                                grad_op_outputs](InferShapeContext* ctx) {
+        // 1. if forward input exists, gradient's shape is same with forward
+        // input
+        // default
+        //    [Suitable for most situations]
+        // 2. if forward input not exists, and only contains one grad input and
+        // output,
+        //    use grad input shape as grad output shape
+        //    [Suitable for the situation that forward input is not used as
+        //    backward input]
+        for (auto& out_name : grad_op_outputs) {
+          auto fwd_name = detail::NoGrad(out_name);
+          if (detail::IsDuplicableVar(fwd_name)) {
+            // Duplicable forward var must as backward input
+            ctx->ShareDim(fwd_name, out_name);
+          } else {
+            if (ctx->HasInput(fwd_name)) {
+              ctx->ShareDim(fwd_name, out_name);
+            } else {
+              PADDLE_ENFORCE_EQ(
+                  grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
+                  true,
+                  platform::errors::Unavailable(
+                      "Custom grad operator infershape error. "
+                      "If a custom grad operator contains only one input and "
+                      "only one output, the input shape will be directly set "
+                      "to "
+                      "the output shape. Otherwise, Please set the forward "
+                      "input "
+                      "as the grad operator's input or set the InferShapeFn "
+                      "of custom grad operator by "
+                      ".SetInferShapeFn(PD_INFER_SHAPE(...))"));
+              ctx->ShareDim(grad_op_inputs[0], out_name);
+            }
+          }
+        }
+      };
+    } else {
+      grad_info.infer_shape_ = [grad_op_inputs, grad_op_outputs, grad_op_attrs,
+                                grad_infer_shape_fn](InferShapeContext* ctx) {
+        RunInferShapeFunc(ctx, grad_infer_shape_fn, grad_op_inputs,
+                          grad_op_outputs, grad_op_attrs);
+      };
+    }

     // Kernel func
     RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
@@ -860,11 +879,11 @@ void RegisterOperatorWithMetaInfo(
 void RegisterOperatorWithMetaInfoMap(
     const paddle::OpMetaInfoMap& op_meta_info_map) {
   auto& meta_info_map = op_meta_info_map.GetMap();
-  VLOG(1) << "Custom Operator: size of op meta info map - "
+  VLOG(3) << "Custom Operator: size of op meta info map - "
           << meta_info_map.size();
   // pair: {op_type, OpMetaInfo}
   for (auto& pair : meta_info_map) {
-    VLOG(1) << "Custom Operator: pair first -> op name: " << pair.first;
+    VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first;
     RegisterOperatorWithMetaInfo(pair.second);
   }
 }
@@ -874,7 +893,7 @@ void RegisterOperatorWithMetaInfoMap(
 // load op api
 void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
   void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name);
-  VLOG(1) << "load custom_op lib: " << dso_name;
+  VLOG(3) << "load custom_op lib: " << dso_name;
   typedef OpMetaInfoMap& get_op_meta_info_map_t();
   auto* get_op_meta_info_map =
       detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
...
@@ -94,8 +94,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
   // 2. Execute infer shape and choose kernel
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
-  static_cast<const framework::OperatorWithKernel*>(op.get())->InferShape(
-      &infer_shape_ctx);
+  op.get()->Info().infer_shape_(&infer_shape_ctx);
   auto kernels_iter = all_op_kernels.find(op_type);
   PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
                     platform::errors::Unavailable(
...
@@ -355,7 +355,7 @@ void build_op_func_list(const platform::Place& place,
         // TODO(Aurelius84): In case of control flow ops, they are NOT
         // inheritted
         // from OperatorWithKernel.
-        op_with_kernel->InferShape(&infer_shape_ctx);
+        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
       }
       auto kernels_iter = all_op_kernels.find(op->Type());
...
@@ -1090,7 +1090,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
                                            const platform::Place& place,
                                            const RuntimeContext& ctx) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
-  this->InferShape(&infer_shape_ctx);
+  this->Info().infer_shape_(&infer_shape_ctx);
 }

 void OperatorWithKernel::RunImpl(const Scope& scope,
@@ -1178,6 +1178,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("infer_shape",
                                        platform::EventRole::kInnerOp);
     RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
+    // TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
+    // in some mkldnn infershape functions, such conv2d infershape
     this->InferShape(&infer_shape_ctx);
   }
...
@@ -491,8 +491,7 @@ static void PreparedOpRunImpl(
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
                                                     &default_attrs, op.Type());
-  static_cast<const framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);
   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                         attrs, default_attrs));
@@ -537,8 +536,7 @@ static void PreparedOpRunPtImpl(
     const framework::AttributeMap& default_attrs) {
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
                                                     &default_attrs, op.Type());
-  static_cast<const framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);
   BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
                                          outs, attrs, default_attrs, dev_ctx,
...
@@ -122,13 +122,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
 }

 OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
-  PADDLE_ENFORCE_EQ(
-      index_,
-      0UL,
-      platform::errors::Unimplemented(
-          "Currently, the InferShapeFn setting of Grad Op is not supported, "
-          "And backward Tensor `X@GRAD` will use the shape of forward Tensor "
-          "`X` by default."));
   info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
   return *this;
 }
...
@@ -105,3 +105,49 @@ PD_BUILD_GRAD_OP(custom_relu)
     .Inputs({"X", "Out", paddle::Grad("Out")})
     .Outputs({paddle::Grad("X")})
     .SetKernelFn(PD_KERNEL(ReluBackward));
+
+std::vector<paddle::Tensor> relu_cpu_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
+                               relu_cpu_backward_kernel<data_t>(
+                                   grad_out.data<data_t>(),
+                                   out.data<data_t>(),
+                                   grad_x.mutable_data<data_t>(out.place()),
+                                   out.size());
+                             }));
+
+  return {grad_x};
+}
+
+std::vector<paddle::Tensor> relu_cuda_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out);
+
+std::vector<paddle::Tensor> ReluBackwardWithoutX(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  if (out.place() == paddle::PlaceType::kCPU) {
+    return relu_cpu_backward_without_x(out, grad_out);
+  } else if (out.place() == paddle::PlaceType::kGPU) {
+    return relu_cuda_backward_without_x(out, grad_out);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+std::vector<std::vector<int64_t>> ReluBackwardWithoutXInferShape(
+    const std::vector<int64_t>& out_shape,
+    const std::vector<int64_t>& grad_out_shape) {
+  return {out_shape};
+}
+
+PD_BUILD_OP(custom_relu_no_x_in_backward)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(ReluForward));
+
+PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
+    .Inputs({"Out", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
+    .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
@@ -70,3 +70,22 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   return {grad_x};
 }
+
+std::vector<paddle::Tensor> relu_cuda_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
+
+  int numel = out.size();
+  int block = 512;
+  int grid = (numel + block - 1) / block;
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      out.type(), "relu_cuda_backward_kernel", ([&] {
+        relu_cuda_backward_kernel<data_t><<<grid, block, 0, out.stream()>>>(
+            grad_out.data<data_t>(),
+            out.data<data_t>(),
+            grad_x.mutable_data<data_t>(out.place()),
+            numel);
+      }));
+
+  return {grad_x};
+}
@@ -49,7 +49,8 @@ custom_module = load(
 class TestJITLoad(unittest.TestCase):
     def setUp(self):
         self.custom_ops = [
-            custom_module.custom_relu, custom_module.custom_relu_dup
+            custom_module.custom_relu, custom_module.custom_relu_dup,
+            custom_module.custom_relu_no_x_in_backward
         ]
         self.dtypes = ['float32', 'float64']
         if paddle.is_compiled_with_cuda():
...