Unverified commit 046553c7, authored by Chen Weihang, committed by GitHub

Support setting infershape function for custom grad op (#38776)

* unify infer_shape func calling

* support set grad infer shape fn for custom op

* unify infershape in new executor and eager

* remove todo comment

* revert infershape in operator
Parent commit: cd2855b0
......@@ -174,8 +174,7 @@ static void PreparedOpRunImpl(
EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
op.Type());
static_cast<const paddle::framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
op.Info().infer_shape_(&infer_shape_ctx);
func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
default_attrs));
......
......@@ -94,7 +94,7 @@ std::vector<std::string> ParseAttrStr(const std::string& attr) {
// 2. type
rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
return rlt;
}
......@@ -109,11 +109,11 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: Start run KernelFunc.";
VLOG(3) << "Custom Operator: Start run KernelFunc.";
std::vector<paddle::experimental::Tensor> custom_ins;
std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
for (auto& in_name : inputs) {
VLOG(1) << "Custom Operator: input name - " << in_name;
VLOG(3) << "Custom Operator: input name - " << in_name;
if (detail::IsDuplicableVar(in_name)) {
// return const std::vector<const Tensor*>
auto vec_x = ctx.MultiInput<Tensor>(in_name);
......@@ -185,11 +185,11 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
}
}
VLOG(1) << "Custom Operator: Run ComputeFunc.";
VLOG(3) << "Custom Operator: Run ComputeFunc.";
try {
auto outs = func(custom_ins, custom_vec_ins, custom_attrs);
VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
VLOG(3) << "Custom Operator: Share outputs into ExecutionContext.";
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_name = outputs[i];
if (detail::IsDuplicableVar(out_name)) {
......@@ -230,6 +230,95 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
}
}
// Bridges a user-supplied custom-op InferShapeFunc to the framework's
// InferShapeContext: gathers input dims as plain int64 shape vectors,
// collects attributes into paddle::any, invokes the user function, and
// writes the computed shapes back to the op's outputs.
// Shared by both the forward and the grad infershape registrations.
static void RunInferShapeFunc(framework::InferShapeContext* ctx,
                              const paddle::InferShapeFunc& func,
                              const std::vector<std::string>& inputs,
                              const std::vector<std::string>& outputs,
                              const std::vector<std::string>& attrs) {
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;

  VLOG(3) << "Custom Operator: InferShape - get input ddim.";
  for (auto& in_name : inputs) {
    if (detail::IsDuplicableVar(in_name)) {
      // A duplicable input names a list of tensors; collect one shape each.
      OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom");
      auto ddims = ctx->GetInputsDim(in_name);
      std::vector<std::vector<int64_t>> shapes;
      shapes.reserve(ddims.size());
      for (const auto& ddim : ddims) {
        shapes.emplace_back(framework::vectorize(ddim));
      }
      vec_input_shapes.emplace_back(std::move(shapes));
    } else {
      OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
      input_shapes.emplace_back(
          framework::vectorize(ctx->GetInputDim(in_name)));
    }
  }

  // Decode each "name:type" attribute string and fetch its typed value.
  std::vector<paddle::any> custom_attrs;
  for (auto& attr_str : attrs) {
    auto parsed = detail::ParseAttrStr(attr_str);
    const auto& attr_name = parsed[0];
    const auto& attr_type_str = parsed[1];
    if (attr_type_str == "bool") {
      custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
    } else if (attr_type_str == "int") {
      custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
    } else if (attr_type_str == "float") {
      custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
    } else if (attr_type_str == "int64_t") {
      custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
    } else if (attr_type_str == "std::string") {
      custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
    } else if (attr_type_str == "std::vector<int>") {
      custom_attrs.emplace_back(ctx->Attrs().Get<std::vector<int>>(attr_name));
    } else if (attr_type_str == "std::vector<float>") {
      custom_attrs.emplace_back(
          ctx->Attrs().Get<std::vector<float>>(attr_name));
    } else if (attr_type_str == "std::vector<int64_t>") {
      // NOTE(chenweihang): InferShape can't support std::vector<int64_t>
      // attr type, because the input type is std::vector<int64_t>, only
      // can use one rule to parse std::vector<int64_t> parameter
      continue;
    } else if (attr_type_str == "std::vector<std::string>") {
      custom_attrs.emplace_back(
          ctx->Attrs().Get<std::vector<std::string>>(attr_name));
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported `%s` type value as custom attribute now. "
          "Supported data types include `bool`, `int`, `float`, "
          "`int64_t`, `std::string`, `std::vector<int>`, "
          "`std::vector<float>`, `std::vector<std::string>`, "
          "Please check whether the attribute data type and "
          "data type string are matched.",
          attr_type_str));
    }
  }

  VLOG(3) << "Custom Operator: InferShape - calc output ddim.";
  auto output_shapes = func(input_shapes, vec_input_shapes, custom_attrs);

  VLOG(3) << "Custom Operator: InferShape - set output ddim.";
  for (size_t i = 0; i < outputs.size(); ++i) {
    const auto& out_name = outputs[i];
    if (detail::IsDuplicableVar(out_name)) {
      // Duplicable output: every returned shape becomes one DDim in the list.
      // NOTE(review): this assumes all of `output_shapes` belongs to this one
      // duplicable output — TODO confirm mixed output lists are unsupported.
      std::vector<DDim> out_ddims;
      out_ddims.reserve(output_shapes.size());
      for (const auto& shape : output_shapes) {
        out_ddims.emplace_back(framework::make_ddim(shape));
      }
      ctx->SetOutputsDim(out_name, out_ddims);
    } else {
      ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i]));
    }
  }
}
//////////////////// Operator Define /////////////////
class CustomOperator : public OperatorWithKernel {
......@@ -239,7 +328,7 @@ class CustomOperator : public OperatorWithKernel {
// Dummy infershape
// Because it is a pure virtual function, it must be implemented
void InferShape(framework::InferShapeContext* ctx) const override {
VLOG(1) << "Custom Operator: Dummy infer shape of custom operator.";
VLOG(3) << "Custom Operator: Dummy infer shape of custom operator.";
}
/**
......@@ -381,7 +470,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
auto fwd_op_outputs = this->OutputNames();
for (auto& in_name : inputs_) {
VLOG(1) << "Custom Operator: GradOpDescMaker - input: " << in_name;
VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
if (!detail::IsGradVar(in_name)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
......@@ -398,7 +487,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
}
}
for (auto& out_name : outputs_) {
VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
if (detail::IsDuplicableVar(out_name)) {
grad_op->SetOutput(out_name,
this->InputGrad(detail::NoGrad(out_name),
......@@ -447,7 +536,7 @@ class CustomGradOpMaker<imperative::OpBase>
auto fwd_op_outputs = this->OutputNames();
for (auto& in_name : inputs_) {
VLOG(1) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
if (!detail::IsGradVar(in_name)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
......@@ -464,7 +553,7 @@ class CustomGradOpMaker<imperative::OpBase>
}
}
for (auto& out_name : outputs_) {
VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
grad_op->SetAttrMap(this->Attrs());
......@@ -486,11 +575,11 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
OpKernelType key(type, experimental::ConvertExtPlaceToInnerPlace(place));
VLOG(1) << "Custom Operator: op kernel key: " << key;
VLOG(3) << "Custom Operator: op kernel key: " << key;
OperatorWithKernel::AllOpKernels()[name][key] =
[kernel_func, inputs, outputs,
attrs](const framework::ExecutionContext& ctx) {
VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
VLOG(3) << "Custom Operator: run custom kernel func in lambda.";
RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
};
}
......@@ -500,7 +589,7 @@ void RegisterOperatorKernel(const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: op name in kernel: " << name;
VLOG(3) << "Custom Operator: op name in kernel: " << name;
// NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because execute engine need get device context based
// op_kernel_key.place_, so we should register kernel for each
......@@ -535,12 +624,12 @@ void RegisterOperatorWithMetaInfo(
auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta);
auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta);
VLOG(1) << "Custom Operator: forward, op name: " << op_name;
VLOG(1) << "Custom Operator: forward, op inputs: "
VLOG(3) << "Custom Operator: forward, op name: " << op_name;
VLOG(3) << "Custom Operator: forward, op inputs: "
<< string::join_strings(op_inputs, ',');
VLOG(1) << "Custom Operator: forward, op outputs: "
VLOG(3) << "Custom Operator: forward, op outputs: "
<< string::join_strings(op_outputs, ',');
VLOG(1) << "Custom Operator: forward, op attrs: "
VLOG(3) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ',');
// Op
......@@ -588,96 +677,13 @@ void RegisterOperatorWithMetaInfo(
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
VLOG(1) << "Custom Operator: Default InferShape - share ddim.";
VLOG(3) << "Custom Operator: Default InferShape - share ddim.";
ctx->ShareDim(op_inputs[0], op_outputs[0]);
};
} else {
info.infer_shape_ = [op_inputs, op_outputs, op_attrs,
infer_shape_func](InferShapeContext* ctx) {
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
VLOG(1) << "Custom Operator: InferShape - get input ddim.";
for (auto& in_name : op_inputs) {
if (detail::IsDuplicableVar(in_name)) {
OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom");
auto vec_ddim = ctx->GetInputsDim(in_name);
std::vector<std::vector<int64_t>> vec_shape;
vec_shape.reserve(vec_ddim.size());
std::transform(vec_ddim.begin(), vec_ddim.end(),
std::back_inserter(vec_shape),
[&](const DDim& ddim) -> std::vector<int64_t> {
return framework::vectorize(ddim);
});
vec_input_shapes.emplace_back(vec_shape);
} else {
OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
auto ddim = ctx->GetInputDim(in_name);
input_shapes.emplace_back(framework::vectorize(ddim));
}
}
std::vector<paddle::any> custom_attrs;
for (auto& attr_str : op_attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
// NOTE(chenweihang): InferShape can't support std::vector<int64_t>
// attr type, because the input type is std::vector<int64_t>, only
// can use one rule to parse std::vector<int64_t> parameter
continue;
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<std::string>`, "
"Please check whether the attribute data type and "
"data type string are matched.",
attr_type_str));
}
}
VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
auto output_shapes =
infer_shape_func(input_shapes, vec_input_shapes, custom_attrs);
VLOG(1) << "Custom Operator: InferShape - set output ddim.";
for (size_t i = 0; i < op_outputs.size(); ++i) {
auto out_name = op_outputs[i];
if (detail::IsDuplicableVar(out_name)) {
std::vector<DDim> vec_ddim;
vec_ddim.reserve(output_shapes.size());
std::transform(output_shapes.begin(), output_shapes.end(),
std::back_inserter(vec_ddim),
[&](const std::vector<int64_t>& shape) -> DDim {
return framework::make_ddim(shape);
});
ctx->SetOutputsDim(out_name, vec_ddim);
} else {
ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i]));
}
}
RunInferShapeFunc(ctx, infer_shape_func, op_inputs, op_outputs, op_attrs);
};
}
......@@ -706,7 +712,7 @@ void RegisterOperatorWithMetaInfo(
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
VLOG(1) << "Custom Operator: InferDtype - share dtype.";
VLOG(3) << "Custom Operator: InferDtype - share dtype.";
auto dtype = ctx->GetInputDataType(op_inputs[0]);
ctx->SetOutputDataType(op_outputs[0], dtype);
};
......@@ -716,7 +722,7 @@ void RegisterOperatorWithMetaInfo(
std::vector<DataType> input_dtypes;
std::vector<std::vector<DataType>> vec_input_dtypes;
VLOG(1) << "Custom Operator: InferDtype - get input dtype.";
VLOG(3) << "Custom Operator: InferDtype - get input dtype.";
for (auto& in_name : op_inputs) {
if (detail::IsDuplicableVar(in_name)) {
std::vector<DataType> vec_custom_dtype;
......@@ -731,10 +737,10 @@ void RegisterOperatorWithMetaInfo(
}
}
VLOG(1) << "Custom Operator: InferDtype - infer output dtype.";
VLOG(3) << "Custom Operator: InferDtype - infer output dtype.";
auto output_dtypes = infer_dtype_func(input_dtypes, vec_input_dtypes);
VLOG(1) << "Custom Operator: InferDtype - set output dtype.";
VLOG(3) << "Custom Operator: InferDtype - set output dtype.";
for (size_t i = 0; i < op_outputs.size(); ++i) {
auto out_name = op_outputs[i];
if (detail::IsDuplicableVar(out_name)) {
......@@ -763,11 +769,12 @@ void RegisterOperatorWithMetaInfo(
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);
VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
VLOG(1) << "Custom Operator: backward, op inputs: "
VLOG(3) << "Custom Operator: backward, op name: " << grad_op_name;
VLOG(3) << "Custom Operator: backward, op inputs: "
<< string::join_strings(grad_op_inputs, ',');
VLOG(1) << "Custom Operator: backward, op outputs: "
VLOG(3) << "Custom Operator: backward, op outputs: "
<< string::join_strings(grad_op_outputs, ',');
// GradOpDescMaker
......@@ -809,40 +816,52 @@ void RegisterOperatorWithMetaInfo(
};
// Grad InferShape
grad_info.infer_shape_ = [grad_op_inputs,
grad_op_outputs](InferShapeContext* ctx) {
// 1. if forward input exists, gradient's shape is same with forward input
// default
// [Suitable for most situations]
// 2. if forward input not exists, and only contains one grad input and
// output,
// use grad input shape as grad output shape
// [Suitable for the situation that forward input is not used as
// backward input]
// TODO(chenweihang): support set grad op infershape func if needed
for (auto& out_name : grad_op_outputs) {
auto fwd_name = detail::NoGrad(out_name);
if (detail::IsDuplicableVar(fwd_name)) {
// Duplicable forward var must as backward input
ctx->ShareDim(fwd_name, out_name);
} else {
if (ctx->HasInput(fwd_name)) {
if (grad_infer_shape_fn == nullptr) {
grad_info.infer_shape_ = [grad_op_inputs,
grad_op_outputs](InferShapeContext* ctx) {
// 1. if forward input exists, gradient's shape is same with forward
// input
// default
// [Suitable for most situations]
// 2. if forward input not exists, and only contains one grad input and
// output,
// use grad input shape as grad output shape
// [Suitable for the situation that forward input is not used as
// backward input]
for (auto& out_name : grad_op_outputs) {
auto fwd_name = detail::NoGrad(out_name);
if (detail::IsDuplicableVar(fwd_name)) {
// Duplicable forward var must as backward input
ctx->ShareDim(fwd_name, out_name);
} else {
PADDLE_ENFORCE_EQ(
grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
true,
platform::errors::Unavailable(
"Custom grad operator infershape error. "
"If a custom grad operator contains only one input and "
"only one output, the input shape will be directly set to "
"the output shape. Otherwise, Please set the forward input "
"as the grad operator's input."));
ctx->ShareDim(grad_op_inputs[0], out_name);
if (ctx->HasInput(fwd_name)) {
ctx->ShareDim(fwd_name, out_name);
} else {
PADDLE_ENFORCE_EQ(
grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
true,
platform::errors::Unavailable(
"Custom grad operator infershape error. "
"If a custom grad operator contains only one input and "
"only one output, the input shape will be directly set "
"to "
"the output shape. Otherwise, Please set the forward "
"input "
"as the grad operator's input or set the InferShapeFn "
"of custom grad operator by "
".SetInferShapeFn(PD_INFER_SHAPE(...))"));
ctx->ShareDim(grad_op_inputs[0], out_name);
}
}
}
}
};
};
} else {
grad_info.infer_shape_ = [grad_op_inputs, grad_op_outputs, grad_op_attrs,
grad_infer_shape_fn](InferShapeContext* ctx) {
RunInferShapeFunc(ctx, grad_infer_shape_fn, grad_op_inputs,
grad_op_outputs, grad_op_attrs);
};
}
// Kernel func
RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
......@@ -860,11 +879,11 @@ void RegisterOperatorWithMetaInfo(
void RegisterOperatorWithMetaInfoMap(
const paddle::OpMetaInfoMap& op_meta_info_map) {
auto& meta_info_map = op_meta_info_map.GetMap();
VLOG(1) << "Custom Operator: size of op meta info map - "
VLOG(3) << "Custom Operator: size of op meta info map - "
<< meta_info_map.size();
// pair: {op_type, OpMetaInfo}
for (auto& pair : meta_info_map) {
VLOG(1) << "Custom Operator: pair first -> op name: " << pair.first;
VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first;
RegisterOperatorWithMetaInfo(pair.second);
}
}
......@@ -874,7 +893,7 @@ void RegisterOperatorWithMetaInfoMap(
// load op api
void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name);
VLOG(1) << "load custom_op lib: " << dso_name;
VLOG(3) << "load custom_op lib: " << dso_name;
typedef OpMetaInfoMap& get_op_meta_info_map_t();
auto* get_op_meta_info_map =
detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
......
......@@ -94,8 +94,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
// 2. Execute infer shape and choose kernel
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
static_cast<const framework::OperatorWithKernel*>(op.get())->InferShape(
&infer_shape_ctx);
op.get()->Info().infer_shape_(&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op_type);
PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
......
......@@ -355,7 +355,7 @@ void build_op_func_list(const platform::Place& place,
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inheritted
// from OperatorWithKernel.
op_with_kernel->InferShape(&infer_shape_ctx);
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
auto kernels_iter = all_op_kernels.find(op->Type());
......
......@@ -1090,7 +1090,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {
RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
this->InferShape(&infer_shape_ctx);
this->Info().infer_shape_(&infer_shape_ctx);
}
void OperatorWithKernel::RunImpl(const Scope& scope,
......@@ -1178,6 +1178,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("infer_shape",
platform::EventRole::kInnerOp);
RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
// TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
// in some mkldnn infershape functions, such conv2d infershape
this->InferShape(&infer_shape_ctx);
}
......
......@@ -491,8 +491,7 @@ static void PreparedOpRunImpl(
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
&default_attrs, op.Type());
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
op.Info().infer_shape_(&infer_shape_ctx);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs, default_attrs));
......@@ -537,8 +536,7 @@ static void PreparedOpRunPtImpl(
const framework::AttributeMap& default_attrs) {
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
&default_attrs, op.Type());
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
op.Info().infer_shape_(&infer_shape_ctx);
BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
outs, attrs, default_attrs, dev_ctx,
......
......@@ -122,13 +122,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
PADDLE_ENFORCE_EQ(
index_,
0UL,
platform::errors::Unimplemented(
"Currently, the InferShapeFn setting of Grad Op is not supported, "
"And backward Tensor `X@GRAD` will use the shape of forward Tensor "
"`X` by default."));
info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
return *this;
}
......
......@@ -105,3 +105,49 @@ PD_BUILD_GRAD_OP(custom_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
// CPU relu grad kernel that consumes only the forward output `out` (and
// grad_out) — the forward input x is intentionally not required.
std::vector<paddle::Tensor> relu_cpu_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  // grad_x lives on CPU and mirrors the shape of the forward output.
  paddle::Tensor grad_x(paddle::PlaceType::kCPU, out.shape());
  PD_DISPATCH_FLOATING_TYPES(
      out.type(), "relu_cpu_backward", ([&] {
        relu_cpu_backward_kernel<data_t>(
            grad_out.data<data_t>(), out.data<data_t>(),
            grad_x.mutable_data<data_t>(out.place()), out.size());
      }));
  return {grad_x};
}
// Forward declaration; the CUDA implementation is compiled separately.
std::vector<paddle::Tensor> relu_cuda_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out);
// Dispatches the x-free relu backward pass to the CPU or GPU implementation
// depending on where `out` resides; other placements are rejected.
std::vector<paddle::Tensor> ReluBackwardWithoutX(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  const auto place = out.place();
  if (place == paddle::PlaceType::kCPU) {
    return relu_cpu_backward_without_x(out, grad_out);
  }
  if (place == paddle::PlaceType::kGPU) {
    return relu_cuda_backward_without_x(out, grad_out);
  }
  PD_THROW("Not implemented.");
}
// InferShape for the x-free relu grad op: X@GRAD always has the same shape
// as the forward output Out, so grad_out's shape is not consulted.
std::vector<std::vector<int64_t>> ReluBackwardWithoutXInferShape(
    const std::vector<int64_t>& out_shape,
    const std::vector<int64_t>& grad_out_shape) {
  (void)grad_out_shape;  // unused by design
  return {out_shape};
}
// Forward op registration: reuses the shared ReluForward kernel. No
// InferShapeFn is set here, so the framework's default is used.
PD_BUILD_OP(custom_relu_no_x_in_backward)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));

// Grad op registration: takes only Out and Out@GRAD (X is deliberately
// absent), so an explicit InferShapeFn is supplied to size X@GRAD — this
// exercises the new grad-op SetInferShapeFn support.
PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
    .Inputs({"Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
    .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
......@@ -70,3 +70,22 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
return {grad_x};
}
// GPU relu grad kernel that consumes only the forward output `out` (no x).
// Launches relu_cuda_backward_kernel on out's stream with a 1-D grid.
std::vector<paddle::Tensor> relu_cuda_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  // grad_x lives on GPU and mirrors the shape of the forward output.
  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
  int numel = out.size();
  int block = 512;                         // threads per block
  int grid = (numel + block - 1) / block;  // ceil(numel / block) blocks
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        relu_cuda_backward_kernel<data_t><<<grid, block, 0, out.stream()>>>(
            grad_out.data<data_t>(),
            out.data<data_t>(),
            grad_x.mutable_data<data_t>(out.place()),
            numel);
      }));
  return {grad_x};
}
......@@ -49,7 +49,8 @@ custom_module = load(
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_ops = [
custom_module.custom_relu, custom_module.custom_relu_dup
custom_module.custom_relu, custom_module.custom_relu_dup,
custom_module.custom_relu_no_x_in_backward
]
self.dtypes = ['float32', 'float64']
if paddle.is_compiled_with_cuda():
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.