Unverified commit 046553c7, authored by Chen Weihang, committed by GitHub

Support setting infershape function for custom grad op (#38776)

* unify infer_shape func calling

* support set grad infer shape fn for custom op

* unify infershape in new executor and eager

* remove todo comment

* revert infershape in operator
Parent cd2855b0
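
What the unification amounts to, repeated across the execution paths touched below (eager, dygraph, and the new executor): stop downcasting to OperatorWithKernel and calling its virtual InferShape, and instead dispatch through the infershape functor registered in the operator's OpInfo, which is the slot that custom ops (now including custom grad ops) populate. The recurring pattern, lifted from the hunks themselves:

  // before: downcast the operator and invoke the virtual member directly
  static_cast<const framework::OperatorWithKernel&>(op).InferShape(&infer_shape_ctx);

  // after: go through the functor registered in OpInfo
  op.Info().infer_shape_(&infer_shape_ctx);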
@@ -174,8 +174,7 @@ static void PreparedOpRunImpl(
   EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
                                          op.Type());
-  static_cast<const paddle::framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);

   func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
                              default_attrs));
......
@@ -94,8 +94,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
   // 2. Execute infer shape and choose kernel
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
-  static_cast<const framework::OperatorWithKernel*>(op.get())->InferShape(
-      &infer_shape_ctx);
+  op.get()->Info().infer_shape_(&infer_shape_ctx);
   auto kernels_iter = all_op_kernels.find(op_type);
   PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
                     platform::errors::Unavailable(
......
@@ -355,7 +355,7 @@ void build_op_func_list(const platform::Place& place,
         // TODO(Aurelius84): In case of control flow ops, they are NOT
         // inherited from OperatorWithKernel.
-        op_with_kernel->InferShape(&infer_shape_ctx);
+        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
       }

       auto kernels_iter = all_op_kernels.find(op->Type());
......
@@ -1090,7 +1090,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
                                            const platform::Place& place,
                                            const RuntimeContext& ctx) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
-  this->InferShape(&infer_shape_ctx);
+  this->Info().infer_shape_(&infer_shape_ctx);
 }

 void OperatorWithKernel::RunImpl(const Scope& scope,
@@ -1178,6 +1178,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("infer_shape",
                                        platform::EventRole::kInnerOp);
     RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
+    // TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
+    // in some mkldnn infershape functions, such as conv2d infershape
     this->InferShape(&infer_shape_ctx);
   }
......
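
Note this RunImpl call site is the one place that keeps the virtual this->InferShape(...). As the new TODO explains, some MKLDNN infershape functions (conv2d's among them) still call this->IsMKLDNNType() on the operator instance, and the OpInfo functor is invoked without one, so this path cannot switch over until that dependency is removed. An illustrative sketch of the kind of dependency meant (hypothetical op, not the actual conv2d source):

  // Sketch only: an InferShape that branches on the operator object itself,
  // which op.Info().infer_shape_(&ctx) has no way to supply.
  void SomeMkldnnOp::InferShape(framework::InferShapeContext* ctx) const {
    if (this->IsMKLDNNType()) {
      // MKLDNN-specific layout/shape handling would go here.
    }
  }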
@@ -491,8 +491,7 @@ static void PreparedOpRunImpl(
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
                                                     &default_attrs, op.Type());
-  static_cast<const framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);

   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                         attrs, default_attrs));
@@ -537,8 +536,7 @@ static void PreparedOpRunPtImpl(
     const framework::AttributeMap& default_attrs) {
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
                                                     &default_attrs, op.Type());
-  static_cast<const framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);

   BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
                                          outs, attrs, default_attrs, dev_ctx,
......
@@ -122,13 +122,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
 }

 OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
-  PADDLE_ENFORCE_EQ(
-      index_,
-      0UL,
-      platform::errors::Unimplemented(
-          "Currently, the InferShapeFn setting of Grad Op is not supported, "
-          "And backward Tensor `X@GRAD` will use the shape of forward Tensor "
-          "`X` by default."));
   info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
   return *this;
 }
......
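
Deleting this PADDLE_ENFORCE_EQ is the change that actually lifts the restriction: index_ is 0 for the forward-op builder, so the old check rejected SetInferShapeFn on any grad-op builder and, per the removed message, X@GRAD simply inherited the forward X's shape. Both of the following now register (sketch: my_op, MyInferShape, and MyGradInferShape are placeholder names):

  PD_BUILD_OP(my_op)
      .SetInferShapeFn(PD_INFER_SHAPE(MyInferShape));      // index_ == 0, accepted before
  PD_BUILD_GRAD_OP(my_op)
      .SetInferShapeFn(PD_INFER_SHAPE(MyGradInferShape));  // previously threw Unimplemented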
@@ -105,3 +105,49 @@ PD_BUILD_GRAD_OP(custom_relu)
     .Inputs({"X", "Out", paddle::Grad("Out")})
     .Outputs({paddle::Grad("X")})
     .SetKernelFn(PD_KERNEL(ReluBackward));
+
+std::vector<paddle::Tensor> relu_cpu_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
+                               relu_cpu_backward_kernel<data_t>(
+                                   grad_out.data<data_t>(),
+                                   out.data<data_t>(),
+                                   grad_x.mutable_data<data_t>(out.place()),
+                                   out.size());
+                             }));
+
+  return {grad_x};
+}
+
+std::vector<paddle::Tensor> relu_cuda_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out);
+
+std::vector<paddle::Tensor> ReluBackwardWithoutX(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  if (out.place() == paddle::PlaceType::kCPU) {
+    return relu_cpu_backward_without_x(out, grad_out);
+  } else if (out.place() == paddle::PlaceType::kGPU) {
+    return relu_cuda_backward_without_x(out, grad_out);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+std::vector<std::vector<int64_t>> ReluBackwardWithoutXInferShape(
+    const std::vector<int64_t>& out_shape,
+    const std::vector<int64_t>& grad_out_shape) {
+  return {out_shape};
+}
+
+PD_BUILD_OP(custom_relu_no_x_in_backward)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(ReluForward));
+
+PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
+    .Inputs({"Out", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
+    .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
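
The design of the new test op is the point of the test: its grad op takes only Out and Grad("Out"), with the forward input X deliberately absent, so the old fallback (give X@GRAD the shape of forward X) has nothing to read, and the explicit InferShapeFn returning out_shape is what makes the registration work. Note that ReluBackwardWithoutXInferShape operates purely on shapes; grad_out_shape is accepted to match the input list but goes unused, which is fine because out and Grad("Out") share a shape for an elementwise relu.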
@@ -70,3 +70,22 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   return {grad_x};
 }
+
+std::vector<paddle::Tensor> relu_cuda_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
+
+  int numel = out.size();
+  int block = 512;
+  int grid = (numel + block - 1) / block;
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      out.type(), "relu_cuda_backward_kernel", ([&] {
+        relu_cuda_backward_kernel<data_t><<<grid, block, 0, out.stream()>>>(
+            grad_out.data<data_t>(),
+            out.data<data_t>(),
+            grad_x.mutable_data<data_t>(out.place()),
+            numel);
+      }));
+
+  return {grad_x};
+}
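
Two small points about the CUDA addition: grid = (numel + block - 1) / block is the standard ceiling division, so a final partial block still covers the tail elements, and the kernel launches on the tensor's own stream via out.stream() rather than the default stream.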
@@ -49,7 +49,8 @@ custom_module = load(
 class TestJITLoad(unittest.TestCase):
     def setUp(self):
         self.custom_ops = [
-            custom_module.custom_relu, custom_module.custom_relu_dup
+            custom_module.custom_relu, custom_module.custom_relu_dup,
+            custom_module.custom_relu_no_x_in_backward
         ]
         self.dtypes = ['float32', 'float64']
         if paddle.is_compiled_with_cuda():
......