From 382e460ba4c776620f4e83d96d6b0962d9eddce5 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Mon, 21 Mar 2022 11:22:06 +0800 Subject: [PATCH] [Phi]add pad3d kernel into phi (#40701) * add pad3d kernel into phi * add pad3d infermeta * fix build error * remove raw pad3d infershape function --- paddle/fluid/operators/pad3d_op.cc | 737 +----------------- paddle/fluid/operators/pad3d_op.cu | 793 -------------------- paddle/phi/infermeta/unary.cc | 71 ++ paddle/phi/infermeta/unary.h | 8 + paddle/phi/kernels/cpu/pad3d_grad_kernel.cc | 480 ++++++++++++ paddle/phi/kernels/cpu/pad3d_kernel.cc | 578 ++++++++++++++ paddle/phi/kernels/gpu/pad3d_grad_kernel.cu | 507 +++++++++++++ paddle/phi/kernels/gpu/pad3d_kernel.cu | 588 +++++++++++++++ paddle/phi/kernels/pad3d_grad_kernel.h | 32 + paddle/phi/kernels/pad3d_kernel.h | 31 + paddle/phi/ops/compat/pad3d_sig.cc | 45 ++ 11 files changed, 2347 insertions(+), 1523 deletions(-) delete mode 100644 paddle/fluid/operators/pad3d_op.cu create mode 100644 paddle/phi/kernels/cpu/pad3d_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/pad3d_kernel.cc create mode 100644 paddle/phi/kernels/gpu/pad3d_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/pad3d_kernel.cu create mode 100644 paddle/phi/kernels/pad3d_grad_kernel.h create mode 100644 paddle/phi/kernels/pad3d_kernel.h create mode 100644 paddle/phi/ops/compat/pad3d_sig.cc diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc index 7b9a4ab155..e4952a2432 100644 --- a/paddle/fluid/operators/pad3d_op.cc +++ b/paddle/fluid/operators/pad3d_op.cc @@ -16,7 +16,9 @@ limitations under the License. 
*/ #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -24,734 +26,10 @@ namespace operators { using framework::Tensor; -template -void ConstPad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, const int out_d, - const int out_h, const int out_w, const T value) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - out_data[out_d * out_height * out_width + out_h * out_width + out_w] = - (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width) - ? value - : in_data[in_d * in_height * in_width + in_h * in_width + in_w]; -} - -template -void ConstPad3DFuncNDHWC(const T* in_data, T* out_data, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, const int out_h, - const int out_w, const T value) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - if (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width) { - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = value; - } - } else { - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = in_data[in_index + c]; - } - } -} - -template -void ReflectPad3DFuncNCDHW(const T* in_data, T* out_data, const 
int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const int out_d, const int out_h, const int out_w, - const T value) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = std::max(in_d, -in_d); // reflect by 0 - in_d = std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth - in_h = std::max(in_h, -in_h); // reflect by 0 - in_h = std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height - in_w = std::max(in_w, -in_w); // reflect by 0 - in_w = std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width - - out_data[out_d * out_height * out_width + out_h * out_width + out_w] = - in_data[in_d * in_height * in_width + in_h * in_width + in_w]; -} - -template -void ReflectPad3DFuncNDHWC(const T* in_data, T* out_data, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, const int out_h, - const int out_w, const T value) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = std::max(in_d, -in_d); - in_d = std::min(in_d, 2 * in_depth - in_d - 2); - in_h = std::max(in_h, -in_h); - in_h = std::min(in_h, 2 * in_height - in_h - 2); - in_w = std::max(in_w, -in_w); - in_w = std::min(in_w, 2 * in_width - in_w - 2); - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = in_data[in_index + c]; - } -} - -template -void ReplicatePad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth, - const int in_height, const int in_width, - const int 
out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const int out_d, const int out_h, const int out_w, - const T value) { - int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); - int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); - int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); - - out_data[out_d * out_height * out_width + out_h * out_width + out_w] = - in_data[in_d * in_height * in_width + in_h * in_width + in_w]; -} - -template -void ReplicatePad3DFuncNDHWC(const T* in_data, T* out_data, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, - const int out_h, const int out_w, const T value) { - int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); - int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); - int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = in_data[in_index + c]; - } -} - -template -void CircularPad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const int out_d, const int out_h, const int out_w, - const T value) { - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - out_data[out_d * out_height * out_width + out_h * out_width 
+ out_w] = - in_data[in_d * in_height * in_width + in_h * in_width + in_w]; -} - -template -void CircularPad3DFuncNDHWC(const T* in_data, T* out_data, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, - const int out_h, const int out_w, const T value) { - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = in_data[in_index + c]; - } -} - -template -void Pad3DNCDHW(const T* in_data, const int num, const int channels, - const int in_depth, const int in_height, const int in_width, - const int out_depth, const int out_height, const int out_width, - const int pad_front, const int pad_top, const int pad_left, - T value, T* out_data, - void (*pad_func)(const T*, T*, const int, const int, const int, - const int, const int, const int, const int, - const int, const int, const int, const int, - const int, const T)) { - for (int n = 0; n < num; ++n) { - for (int c = 0; c < channels; ++c) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - pad_func(in_data, out_data, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, - pad_left, out_d, out_h, out_w, value); - } - } - } - in_data += in_depth * in_height * in_width; - out_data += out_depth * out_height * out_width; - } - } -} - -template -void Pad3DNDHWC(const T* in_data, const int num, const int 
channels, - const int in_depth, const int in_height, const int in_width, - const int out_depth, const int out_height, const int out_width, - const int pad_front, const int pad_top, const int pad_left, - T value, T* out_data, - void (*pad_func)(const T*, T*, const int, const int, const int, - const int, const int, const int, const int, - const int, const int, const int, const int, - const int, const int, const T)) { - for (int n = 0; n < num; ++n) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - pad_func(in_data, out_data, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, - pad_left, out_d, out_h, out_w, value); - } - } - } - in_data += in_depth * in_height * in_width * channels; - out_data += out_depth * out_height * out_width * channels; - } -} - -template -void ConstPad3DGradNCDHW(T* d_in_data, const T* d_out_data, const int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, const int out_d, - const int out_h, const int out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width)) { - d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] = - d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; - } -} - -template -void ConstPad3DGradNDHWC(T* d_in_data, const T* d_out_data, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, const int out_h, - const int out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int 
in_w = out_w - pad_left; - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width)) { - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - d_in_data[in_index + c] = d_out_data[out_index + c]; - } - } -} - -template -void ReflectPad3DGradNCDHW(T* d_in_data, const T* d_out_data, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, const int out_h, - const int out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = std::max(in_d, -in_d); // reflect by 0 - in_d = std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth - in_h = std::max(in_h, -in_h); // reflect by 0 - in_h = std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height - in_w = std::max(in_w, -in_w); // reflect by 0 - in_w = std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width - - d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += - d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; -} - -template -void ReflectPad3DGradNDHWC(T* d_in_data, const T* d_out_data, - const int channels, const int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const int out_d, const int out_h, const int out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = std::max(in_d, -in_d); - in_d = std::min(in_d, 2 * in_depth - in_d - 2); - in_h = std::max(in_h, -in_h); - in_h = std::min(in_h, 2 * in_height - in_h - 2); - in_w = std::max(in_w, 
-in_w); - in_w = std::min(in_w, 2 * in_width - in_w - 2); - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - d_in_data[in_index + c] += d_out_data[out_index + c]; - } -} - -template -void ReplicatePad3DGradNCDHW(T* d_in_data, const T* d_out_data, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, - const int out_h, const int out_w) { - int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); - int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); - int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); - - d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += - d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; -} - -template -void ReplicatePad3DGradNDHWC(T* d_in_data, const T* d_out_data, - const int channels, const int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const int out_d, const int out_h, - const int out_w) { - int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); - int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); - int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - d_in_data[in_index + c] += d_out_data[out_index + c]; - } -} - -template -void CircularPad3DGradNCDHW(T* d_in_data, const T* d_out_data, - const int in_depth, const int 
in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const int out_d, - const int out_h, const int out_w) { - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += - d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; -} - -template -void CircularPad3DGradNDHWC(T* d_in_data, const T* d_out_data, - const int channels, const int in_depth, - const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const int out_d, const int out_h, const int out_w) { - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - d_in_data[in_index + c] += d_out_data[out_index + c]; - } -} - -template -void Pad3DGradNCDHW(T* d_in_data, const int num, const int channels, - const int in_depth, const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, const int pad_top, - const int pad_left, const T* d_out_data, - void (*pad_func)(T*, const T*, const int, const int, - const int, const int, const int, const int, - const int, const int, const int, const int, - const int, const int)) { - for (int n = 0; n < num; ++n) { - for (int c = 0; c < channels; ++c) { - for (int out_d = 0; out_d < 
out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - pad_func(d_in_data, d_out_data, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, - pad_left, out_d, out_h, out_w); - } - } - } - d_in_data += in_depth * in_height * in_width; - d_out_data += out_depth * out_height * out_width; - } - } -} - -template -void Pad3DGradNDHWC(T* d_in_data, const int num, const int channels, - const int in_depth, const int in_height, const int in_width, - const int out_depth, const int out_height, - const int out_width, const int pad_front, const int pad_top, - const int pad_left, const T* d_out_data, - void (*pad_func)(T*, const T*, const int, const int, - const int, const int, const int, const int, - const int, const int, const int, const int, - const int, const int, const int)) { - for (int n = 0; n < num; ++n) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - pad_func(d_in_data, d_out_data, channels, in_depth, in_height, - in_width, out_depth, out_height, out_width, pad_front, - pad_top, pad_left, out_d, out_h, out_w); - } - } - } - d_in_data += in_depth * in_height * in_width * channels; - d_out_data += out_depth * out_height * out_width * channels; - } -} - -static inline std::vector GetPaddings( - const framework::ExecutionContext& context) { - std::vector paddings(6); - auto* paddings_t = context.Input("Paddings"); - if (paddings_t) { - auto paddings_data = paddings_t->data(); - std::memcpy(paddings.data(), paddings_data, paddings.size() * sizeof(int)); - } else { - auto pads = context.Attr>("paddings"); - std::copy(pads.begin(), pads.end(), paddings.data()); - } - return paddings; -} - -template -class Pad3dCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - std::vector pads = 
GetPaddings(context); - auto mode = context.Attr("mode"); - auto data_format = context.Attr("data_format"); - T value = static_cast(context.Attr("value")); - - auto* x = context.Input("X"); - auto in_dims = x->dims(); - const T* in_data = x->data(); - - auto* out = context.Output("Out"); - if (data_format == "NCDHW") { - out->Resize({in_dims[0], in_dims[1], in_dims[2] + pads[4] + pads[5], - in_dims[3] + pads[2] + pads[3], - in_dims[4] + pads[0] + pads[1]}); - } else { - out->Resize({in_dims[0], in_dims[1] + pads[4] + pads[5], - in_dims[2] + pads[2] + pads[3], - in_dims[3] + pads[0] + pads[1], in_dims[4]}); - } - auto out_dims = out->dims(); - T* out_data = out->mutable_data(context.GetPlace()); - - int channels = in_dims[1]; - int in_depth = in_dims[2]; - int in_height = in_dims[3]; - int in_width = in_dims[4]; - int out_depth = out_dims[2]; - int out_height = out_dims[3]; - int out_width = out_dims[4]; - if (data_format == "NDHWC") { - channels = in_dims[4]; - in_depth = in_dims[1]; - in_height = in_dims[2]; - in_width = in_dims[3]; - out_depth = out_dims[1]; - out_height = out_dims[2]; - out_width = out_dims[3]; - } - - if (mode == "reflect") { - PADDLE_ENFORCE_GT(in_depth, pads[4], - platform::errors::InvalidArgument( - "The depth of Input(X)'s dimension should be " - "greater than pad_front" - " in reflect mode" - ", but received depth(%d) and pad_front(%d).", - in_depth, pads[4])); - PADDLE_ENFORCE_GT(in_depth, pads[5], - platform::errors::InvalidArgument( - "The depth of Input(X)'s dimension should be " - "greater than pad_back" - " in reflect mode" - ", but received depth(%d) and pad_back(%d).", - in_depth, pads[5])); - - PADDLE_ENFORCE_GT(in_height, pads[2], - platform::errors::InvalidArgument( - "The height of Input(X)'s dimension should be " - "greater than pad_top" - " in reflect mode" - ", but received depth(%d) and pad_top(%d).", - in_height, pads[2])); - PADDLE_ENFORCE_GT(in_height, pads[3], - platform::errors::InvalidArgument( - "The height of 
Input(X)'s dimension should be " - "greater than pad_bottom" - " in reflect mode" - ", but received depth(%d) and pad_bottom(%d).", - in_height, pads[3])); - - PADDLE_ENFORCE_GT(in_width, pads[0], - platform::errors::InvalidArgument( - "The width of Input(X)'s dimension should be " - "greater than pad_left" - " in reflect mode" - ", but received depth(%d) and pad_left(%d).", - in_width, pads[0])); - PADDLE_ENFORCE_GT(in_width, pads[1], - platform::errors::InvalidArgument( - "The width of Input(X)'s dimension should be " - "greater than pad_right" - " in reflect mode" - ", but received depth(%d) and pad_right(%d).", - in_width, pads[1])); - } else if (mode == "circular" || mode == "replicate") { - PADDLE_ENFORCE_NE(in_depth * in_height * in_width, 0, - platform::errors::InvalidArgument( - "The input tensor size can not be 0 for circular " - "or replicate padding mode.")); - } - - const int pad_left = pads[0]; - const int pad_top = pads[2]; - const int pad_front = pads[4]; - const int num = in_dims[0]; - if (data_format == "NCDHW") { - std::map - func_map; - - func_map["reflect"] = ReflectPad3DFuncNCDHW; - func_map["replicate"] = ReplicatePad3DFuncNCDHW; - func_map["circular"] = CircularPad3DFuncNCDHW; - func_map["constant"] = ConstPad3DFuncNCDHW; - Pad3DNCDHW(in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - value, out_data, func_map[mode]); - } else { - std::map - func_map; - - func_map["reflect"] = ReflectPad3DFuncNDHWC; - func_map["replicate"] = ReplicatePad3DFuncNDHWC; - func_map["circular"] = CircularPad3DFuncNDHWC; - func_map["constant"] = ConstPad3DFuncNDHWC; - Pad3DNDHWC(in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - value, out_data, func_map[mode]); - } - } -}; - -template -class Pad3dGradCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override 
{ - std::vector pads = GetPaddings(context); - auto mode = context.Attr("mode"); - auto data_format = context.Attr("data_format"); - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_in = context.Output(framework::GradVarName("X")); - auto d_in_dims = d_in->dims(); - auto d_out_dims = d_out->dims(); - const T* d_out_data = d_out->data(); - T* d_in_data = d_in->mutable_data(context.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(context.template device_context(), - d_in, static_cast(0)); - const int pad_left = pads[0]; - const int pad_top = pads[2]; - const int pad_front = pads[4]; - const int num = d_in_dims[0]; - if (data_format == "NCDHW") { - const int channels = d_in_dims[1]; - const int in_depth = d_in_dims[2]; - const int in_height = d_in_dims[3]; - const int in_width = d_in_dims[4]; - const int out_depth = d_out_dims[2]; - const int out_height = d_out_dims[3]; - const int out_width = d_out_dims[4]; - - std::map - func_map; - - func_map["reflect"] = ReflectPad3DGradNCDHW; - func_map["replicate"] = ReplicatePad3DGradNCDHW; - func_map["circular"] = CircularPad3DGradNCDHW; - func_map["constant"] = ConstPad3DGradNCDHW; - - Pad3DGradNCDHW(d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, - pad_left, d_out_data, func_map[mode]); - } else { - const int channels = d_in_dims[4]; - const int in_depth = d_in_dims[1]; - const int in_height = d_in_dims[2]; - const int in_width = d_in_dims[3]; - const int out_depth = d_out_dims[1]; - const int out_height = d_out_dims[2]; - const int out_width = d_out_dims[3]; - - std::map - func_map; - - func_map["reflect"] = ReflectPad3DGradNDHWC; - func_map["replicate"] = ReplicatePad3DGradNDHWC; - func_map["circular"] = CircularPad3DGradNDHWC; - func_map["constant"] = ConstPad3DGradNDHWC; - - Pad3DGradNDHWC(d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, - pad_left, d_out_data, 
func_map[mode]); - } - } -}; - class Pad3dOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad3d"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad3d"); - - auto x_dim = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(x_dim.size(), 5, - platform::errors::InvalidArgument( - "The size of Input(X)'s dimension should be equal to " - "5, but received %d. ", - x_dim.size())); - - std::vector out_dims(x_dim.size()); - auto data_format = ctx->Attrs().Get("data_format"); - out_dims[0] = x_dim[0]; - if (ctx->HasInput("Paddings")) { - auto paddings_dim = ctx->GetInputDim("Paddings"); - PADDLE_ENFORCE_EQ(paddings_dim.size(), 1, - platform::errors::InvalidArgument( - "Size of Input(Paddings)'s dimension should be " - "equal to 1, but received %d.", - paddings_dim.size())); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(paddings_dim[0], 6, - platform::errors::InvalidArgument( - "Shape of Input(Paddings) should be equal to " - "[6], but received [%d].", - paddings_dim[0])); - } - out_dims[1] = x_dim[1]; - out_dims[2] = x_dim[2]; - out_dims[3] = x_dim[3]; - } else { - auto paddings = ctx->Attrs().Get>("paddings"); - PADDLE_ENFORCE_EQ( - paddings.size(), 6, - platform::errors::InvalidArgument( - "Size of paddings should be equal to 4, but received %d.", - static_cast(paddings.size()))); - if (data_format == "NCDHW") { - out_dims[1] = x_dim[1]; // channel - out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0)) - ? x_dim[2] - : (x_dim[2] + paddings[4] + paddings[5]); // depth - - out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0)) - ? x_dim[3] - : (x_dim[3] + paddings[2] + paddings[3]); // height - - out_dims[4] = ((!ctx->IsRuntime()) && (x_dim[4] < 0)) - ? 
x_dim[4] - : (x_dim[4] + paddings[0] + paddings[1]); // width - } else { // NDHWC - out_dims[4] = x_dim[4]; // channel - - out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0)) - ? x_dim[1] - : (x_dim[1] + paddings[4] + paddings[5]); // depth - out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0)) - ? x_dim[2] - : (x_dim[2] + paddings[2] + paddings[3]); // height - out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0)) - ? x_dim[3] - : (x_dim[3] + paddings[0] + paddings[1]); // width - } - } - - ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); - ctx->ShareLoD("X", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -921,15 +199,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad3dOpGradNoNeedBufferVarsInferer, "X"); namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(pad3d, Pad3dInferShapeFunctor, + PD_INFER_META(phi::Pad3dInferMeta)); + REGISTER_OPERATOR(pad3d, ops::Pad3dOp, ops::Pad3dOpMaker, ops::Pad3dOpGradMaker, - ops::Pad3dOpGradMaker); + ops::Pad3dOpGradMaker, + Pad3dInferShapeFunctor); REGISTER_OPERATOR(pad3d_grad, ops::Pad3dOpGrad, ops::Pad3dOpDoubleGradMaker, ops::Pad3dOpDoubleGradMaker, ops::Pad3dOpGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(pad3d, ops::Pad3dCPUKernel, - ops::Pad3dCPUKernel, ops::Pad3dCPUKernel, - ops::Pad3dCPUKernel); -REGISTER_OP_CPU_KERNEL(pad3d_grad, ops::Pad3dGradCPUKernel, - ops::Pad3dGradCPUKernel); diff --git a/paddle/fluid/operators/pad3d_op.cu b/paddle/fluid/operators/pad3d_op.cu deleted file mode 100644 index 9ab0eb9d44..0000000000 --- a/paddle/fluid/operators/pad3d_op.cu +++ /dev/null @@ -1,793 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; - -using framework::Tensor; - -template -__global__ void Pad3DConstNCDHW(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T value, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int nc = index / out_width; - - const int out_w = index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - out_data[index] = - (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width) - ? 
value - : in_data[nc * in_depth * in_height * in_width + - in_d * in_height * in_width + in_h * in_width + in_w]; - } -} - -template -__global__ void Pad3DConstNDHWC(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T value, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int n = index / channels; - const int c = index % channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - const int in_d = out_d - pad_front; - const int in_h = out_h - pad_top; - const int in_w = out_w - pad_left; - - out_data[index] = - (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width) - ? value - : in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c]; - } -} - -template -__global__ void Pad3DReflectNCDHW(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int nc = index / out_width; - - const int out_w = index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = max(in_d, -in_d); // reflect by 0 - in_d = min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth - in_h = max(in_h, -in_h); // reflect by 0 - in_h = min(in_h, 2 * in_height - in_h - 2); // reflect by 
in_height - in_w = max(in_w, -in_w); // reflect by 0 - in_w = min(in_w, 2 * in_width - in_w - 2); // reflect by in_width - out_data[index] = - in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * - in_width + - in_w]; - } -} - -template -__global__ void Pad3DReflectNDHWC(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int n = index / channels; - const int c = index % channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = max(in_d, -in_d); - in_d = min(in_d, 2 * in_depth - in_d - 2); - in_h = max(in_h, -in_h); - in_h = min(in_h, 2 * in_height - in_h - 2); - in_w = max(in_w, -in_w); - in_w = min(in_w, 2 * in_width - in_w - 2); - - out_data[index] = in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c]; - } -} - -template -__global__ void Pad3DReplicateNCDHW(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int nc = index / out_width; - - const int out_w = index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); - int in_h = min(in_height - 1, max(out_h - pad_top, 0)); - 
int in_w = min(in_width - 1, max(out_w - pad_left, 0)); - - out_data[index] = - in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * - in_width + - in_w]; - } -} - -template -__global__ void Pad3DReplicateNDHWC(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int n = index / channels; - const int c = index % channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - - int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); - int in_h = min(in_height - 1, max(out_h - pad_top, 0)); - int in_w = min(in_width - 1, max(out_w - pad_left, 0)); - - out_data[index] = in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c]; - } -} - -template -__global__ void Pad3DCircularNCDHW(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int nc = index / out_width; - - const int out_w = index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - out_data[index] = - in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * - in_width + - in_w]; 
- } -} - -template -__global__ void Pad3DCircularNDHWC(const int nthreads, const T* in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - CUDA_KERNEL_LOOP(index, nthreads) { - int n = index / channels; - const int c = index % channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - out_data[index] = in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c]; - } -} - -template -__global__ void Pad3DGradConstNCDHW(const int in_size, T* d_in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const T* d_out_data) { - CUDA_KERNEL_LOOP(in_index, in_size) { - const int in_w = in_index % in_width; - - int nc = in_index / in_width; - const int in_h = nc % in_height; - - nc /= in_height; - const int in_d = nc % in_depth; - - nc /= in_depth; - - const int out_d = in_d + pad_front; - const int out_h = in_h + pad_top; - const int out_w = in_w + pad_left; - d_in_data[in_index] = - d_out_data[nc * out_depth * out_height * out_width + - out_d * out_height * out_width + out_h * out_width + out_w]; - } -} - -template -__global__ void Pad3DGradConstNDHWC(const int in_size, T* d_in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, 
const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const T* d_out_data) { - CUDA_KERNEL_LOOP(in_index, in_size) { - const int c = in_index % channels; - int n = in_index / channels; - - const int in_w = n % in_width; - n /= in_width; - - const int in_h = n % in_height; - n /= in_height; - - const int in_d = n % in_depth; - n /= in_depth; - - const int out_d = in_d + pad_front; - const int out_h = in_h + pad_top; - const int out_w = in_w + pad_left; - - d_in_data[in_index] = - d_out_data[n * out_depth * out_height * out_width * channels + - out_d * out_height * out_width * channels + - out_h * out_width * channels + out_w * channels + c]; - } -} - -template -__global__ void Pad3DGradReflectNCDHW(const int out_size, T* d_in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const T* d_out_data) { - CUDA_KERNEL_LOOP(out_index, out_size) { - int nc = out_index / out_width; - const int out_w = out_index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = max(in_d, -in_d); - in_h = max(in_h, -in_h); - in_w = max(in_w, -in_w); - - in_d = min(in_d, 2 * in_depth - in_d - 2); - in_h = min(in_h, 2 * in_height - in_h - 2); - in_w = min(in_w, 2 * in_width - in_w - 2); - - platform::CudaAtomicAdd( - &d_in_data[nc * in_depth * in_height * in_width + - in_d * in_height * in_width + in_h * in_width + in_w], - d_out_data[out_index]); - } -} - -template -__global__ void Pad3DGradReflectNDHWC(const int out_size, T* d_in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - 
const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, const T* d_out_data) { - CUDA_KERNEL_LOOP(out_index, out_size) { - const int c = out_index % channels; - int n = out_index / channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = max(in_d, -in_d); - in_h = max(in_h, -in_h); - in_w = max(in_w, -in_w); - - in_d = min(in_d, in_depth * 2 - in_d - 2); - in_h = min(in_h, in_height * 2 - in_h - 2); - in_w = min(in_w, in_width * 2 - in_w - 2); - platform::CudaAtomicAdd( - &d_in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c], - d_out_data[out_index]); - } -} - -template -__global__ void Pad3DGradReplicateNCDHW( - const int out_size, T* d_in_data, const int num, const int channels, - const int in_depth, const int in_height, const int in_width, - const int out_depth, const int out_height, const int out_width, - const int pad_front, const int pad_top, const int pad_left, - const T* d_out_data) { - CUDA_KERNEL_LOOP(out_index, out_size) { - int nc = out_index / out_width; - const int out_w = out_index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - const int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); - const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); - const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); - - platform::CudaAtomicAdd( - &d_in_data[nc * in_depth * in_height * in_width + - in_d * in_height * in_width + in_h * in_width + in_w], - d_out_data[out_index]); - } -} - -template -__global__ void Pad3DGradReplicateNDHWC( - const int out_size, T* d_in_data, const int num, const int channels, 
- const int in_depth, const int in_height, const int in_width, - const int out_depth, const int out_height, const int out_width, - const int pad_front, const int pad_top, const int pad_left, - const T* d_out_data) { - CUDA_KERNEL_LOOP(out_index, out_size) { - const int c = out_index % channels; - int n = out_index / channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - - const int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); - const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); - const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); - - platform::CudaAtomicAdd( - &d_in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c], - d_out_data[out_index]); - } -} - -template -__global__ void Pad3DGradCircularNCDHW(const int out_size, T* d_in_data, - const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const T* d_out_data) { - CUDA_KERNEL_LOOP(out_index, out_size) { - int nc = out_index / out_width; - const int out_w = out_index % out_width; - const int out_h = nc % out_height; - nc /= out_height; - const int out_d = nc % out_depth; - nc /= out_depth; - - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - platform::CudaAtomicAdd( - &d_in_data[nc * in_depth * in_height * in_width + - in_d * in_height * in_width + in_h * in_width + in_w], - d_out_data[out_index]); - } -} - -template -__global__ void Pad3DGradCircularNDHWC(const int out_size, T* d_in_data, - const int num, const int channels, - const int 
in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, - const int out_width, const int pad_front, - const int pad_top, const int pad_left, - const T* d_out_data) { - CUDA_KERNEL_LOOP(out_index, out_size) { - const int c = out_index % channels; - int n = out_index / channels; - const int out_w = n % out_width; - n /= out_width; - const int out_h = n % out_height; - n /= out_height; - const int out_d = n % out_depth; - n /= out_depth; - - int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; - int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; - int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; - - platform::CudaAtomicAdd( - &d_in_data[n * in_depth * in_height * in_width * channels + - in_d * in_height * in_width * channels + - in_h * in_width * channels + in_w * channels + c], - d_out_data[out_index]); - } -} - -static inline std::vector GetPaddings( - const framework::ExecutionContext& context) { - std::vector paddings(6); - auto* paddings_data = context.Input("Paddings"); - if (paddings_data) { - Tensor pads; - framework::TensorCopySync(*paddings_data, platform::CPUPlace(), &pads); - auto pads_data = pads.data(); - std::memcpy(paddings.data(), pads_data, paddings.size() * sizeof(int)); - } else { - auto pads = context.Attr>("paddings"); - std::copy(pads.begin(), pads.end(), paddings.data()); - } - return paddings; -} - -template -class Pad3dCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - std::vector pads = GetPaddings(context); - auto mode = context.Attr("mode"); - auto data_format = context.Attr("data_format"); - T value = static_cast(context.Attr("value")); - - auto* x = context.Input("X"); - auto in_dims = x->dims(); - const T* in_data = x->data(); - auto* out = context.Output("Out"); - auto out_dims = out->dims(); - if (data_format == "NCDHW") { - out_dims[0] = in_dims[0]; - out_dims[1] = 
in_dims[1]; - out_dims[2] = in_dims[2] + pads[4] + pads[5]; - out_dims[3] = in_dims[3] + pads[2] + pads[3]; - out_dims[4] = in_dims[4] + pads[0] + pads[1]; - } else { - out_dims[0] = in_dims[0]; - out_dims[1] = in_dims[1] + pads[4] + pads[5]; - out_dims[2] = in_dims[2] + pads[2] + pads[3]; - out_dims[3] = in_dims[3] + pads[0] + pads[1]; - out_dims[4] = in_dims[4]; - } - T* out_data = out->mutable_data(out_dims, context.GetPlace()); - - int channels = in_dims[1]; - int in_depth = in_dims[2]; - int in_height = in_dims[3]; - int in_width = in_dims[4]; - int out_depth = out_dims[2]; - int out_height = out_dims[3]; - int out_width = out_dims[4]; - if (data_format == "NDHWC") { - channels = in_dims[4]; - in_depth = in_dims[1]; - in_height = in_dims[2]; - in_width = in_dims[3]; - out_depth = out_dims[1]; - out_height = out_dims[2]; - out_width = out_dims[3]; - } - - if (mode == "reflect") { - PADDLE_ENFORCE_GT(in_depth, pads[4], - platform::errors::InvalidArgument( - "The depth of Input(X)'s dimension should be " - "greater than pad_front" - " in reflect mode" - ", but received depth(%d) and pad_front(%d).", - in_depth, pads[4])); - PADDLE_ENFORCE_GT(in_depth, pads[5], - platform::errors::InvalidArgument( - "The depth of Input(X)'s dimension should be " - "greater than pad_back" - " in reflect mode" - ", but received depth(%d) and pad_back(%d).", - in_depth, pads[5])); - - PADDLE_ENFORCE_GT(in_height, pads[2], - platform::errors::InvalidArgument( - "The height of Input(X)'s dimension should be " - "greater than pad_top" - " in reflect mode" - ", but received depth(%d) and pad_top(%d).", - in_height, pads[2])); - PADDLE_ENFORCE_GT(in_height, pads[3], - platform::errors::InvalidArgument( - "The height of Input(X)'s dimension should be " - "greater than pad_bottom" - " in reflect mode" - ", but received depth(%d) and pad_bottom(%d).", - in_height, pads[3])); - - PADDLE_ENFORCE_GT(in_width, pads[0], - platform::errors::InvalidArgument( - "The width of Input(X)'s dimension 
should be " - "greater than pad_left" - " in reflect mode" - ", but received depth(%d) and pad_left(%d).", - in_width, pads[0])); - PADDLE_ENFORCE_GT(in_width, pads[1], - platform::errors::InvalidArgument( - "The width of Input(X)'s dimension should be " - "greater than pad_right" - " in reflect mode" - ", but received depth(%d) and pad_right(%d).", - in_width, pads[1])); - } else if (mode == "circular" || mode == "replicate") { - PADDLE_ENFORCE_NE(in_depth * in_height * in_width, 0, - platform::errors::InvalidArgument( - "The input tensor size can not be 0 for circular " - "or replicate padding mode.")); - } - - const int pad_left = pads[0]; - const int pad_top = pads[2]; - const int pad_front = pads[4]; - const int num = in_dims[0]; - - auto stream = context.cuda_device_context().stream(); - int block = PADDLE_CUDA_NUM_THREADS; - const int out_size = out->numel(); - int grid = (out_size + block - 1) / block; - - if (data_format == "NCDHW") { - if (mode == "reflect") { - Pad3DReflectNCDHW<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - out_data); - } else if (mode == "replicate") { - Pad3DReplicateNCDHW<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - out_data); - } else if (mode == "circular") { - Pad3DCircularNCDHW<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - out_data); - } else { - Pad3DConstNCDHW<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - value, out_data); - } - } else { - if (mode == "reflect") { - Pad3DReflectNDHWC<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - out_data); - } else if (mode == 
"replicate") { - Pad3DReplicateNDHWC<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - out_data); - } else if (mode == "circular") { - Pad3DCircularNDHWC<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - out_data); - } else { - Pad3DConstNDHWC<<>>( - out_size, in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - value, out_data); - } - } - } -}; - -template -class Pad3dGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - std::vector pads = GetPaddings(context); - auto mode = context.Attr("mode"); - auto data_format = context.Attr("data_format"); - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_in = context.Output(framework::GradVarName("X")); - auto d_in_dims = d_in->dims(); - auto d_out_dims = d_out->dims(); - const T* d_out_data = d_out->data(); - T* d_in_data = d_in->mutable_data(context.GetPlace()); - - phi::funcs::SetConstant set_zero; - set_zero(context.template device_context(), - d_in, static_cast(0)); - - const int pad_left = pads[0]; - const int pad_top = pads[2]; - const int pad_front = pads[4]; - - const int num = d_in_dims[0]; - - auto stream = context.cuda_device_context().stream(); - int block = PADDLE_CUDA_NUM_THREADS; - const int out_size = d_out->numel(); - const int in_size = d_in->numel(); - int grid = (out_size + block - 1) / block; - - if (data_format == "NCDHW") { - const int channels = d_in_dims[1]; - const int in_depth = d_in_dims[2]; - const int in_height = d_in_dims[3]; - const int in_width = d_in_dims[4]; - const int out_depth = d_out_dims[2]; - const int out_height = d_out_dims[3]; - const int out_width = d_out_dims[4]; - - if (mode == "reflect") { - Pad3DGradReflectNCDHW<<>>( - out_size, 
d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } else if (mode == "replicate") { - Pad3DGradReplicateNCDHW<<>>( - out_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } else if (mode == "circular") { - Pad3DGradCircularNCDHW<<>>( - out_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } else { - grid = (in_size + block - 1) / block; - Pad3DGradConstNCDHW<<>>( - in_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } - } else { - const int channels = d_in_dims[4]; - const int in_depth = d_in_dims[1]; - const int in_height = d_in_dims[2]; - const int in_width = d_in_dims[3]; - const int out_depth = d_out_dims[1]; - const int out_height = d_out_dims[2]; - const int out_width = d_out_dims[3]; - if (mode == "reflect") { - Pad3DGradReflectNDHWC<<>>( - out_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } else if (mode == "replicate") { - Pad3DGradReplicateNDHWC<<>>( - out_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } else if (mode == "circular") { - Pad3DGradCircularNDHWC<<>>( - out_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } else { - grid = (in_size + block - 1) / block; - Pad3DGradConstNDHWC<<>>( - in_size, d_in_data, num, channels, in_depth, in_height, in_width, - out_depth, out_height, out_width, pad_front, pad_top, pad_left, - d_out_data); - } - } - } -}; - -} // namespace operators -} // 
namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL(pad3d, ops::Pad3dCUDAKernel, - ops::Pad3dCUDAKernel, - ops::Pad3dCUDAKernel, ops::Pad3dCUDAKernel, - ops::Pad3dCUDAKernel); -REGISTER_OP_CUDA_KERNEL(pad3d_grad, ops::Pad3dGradCUDAKernel, - ops::Pad3dGradCUDAKernel, - ops::Pad3dGradCUDAKernel); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index bcbc8f5262..7c5f38744f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -877,6 +877,77 @@ void PadInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void Pad3dInferMeta(const MetaTensor& x, + const ScalarArray& paddings_scalar_array, + const std::string& mode, + float value, + const std::string& data_format, + MetaTensor* out, + MetaConfig config) { + auto x_dim = x.dims(); + PADDLE_ENFORCE_EQ(x_dim.size(), + 5, + errors::InvalidArgument( + "The size of Input(X)'s dimension should be equal to " + "5, but received %d. ", + x_dim.size())); + + std::vector out_dims(x_dim.size()); + out_dims[0] = x_dim[0]; + if (paddings_scalar_array.FromTensor()) { + if (config.is_runtime) { + PADDLE_ENFORCE_EQ( + paddings_scalar_array.GetData().size(), + 6, + errors::InvalidArgument("Shape of Input(Paddings) should be equal to " + "[6], but received [%d].", + paddings_scalar_array.GetData().size())); + } + out_dims[1] = x_dim[1]; + out_dims[2] = x_dim[2]; + out_dims[3] = x_dim[3]; + } else { + auto paddings = paddings_scalar_array.GetData(); + + PADDLE_ENFORCE_EQ( + paddings.size(), + 6, + errors::InvalidArgument( + "Size of paddings should be equal to 6, but received %d.", + static_cast(paddings.size()))); + if (data_format == "NCDHW") { + out_dims[1] = x_dim[1]; // channel + out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0)) + ? x_dim[2] + : (x_dim[2] + paddings[4] + paddings[5]); // depth + + out_dims[3] = ((!config.is_runtime) && (x_dim[3] < 0)) + ? 
x_dim[3] + : (x_dim[3] + paddings[2] + paddings[3]); // height + + out_dims[4] = ((!config.is_runtime) && (x_dim[4] < 0)) + ? x_dim[4] + : (x_dim[4] + paddings[0] + paddings[1]); // width + } else { // NDHWC + out_dims[4] = x_dim[4]; // channel + + out_dims[1] = ((!config.is_runtime) && (x_dim[1] < 0)) + ? x_dim[1] + : (x_dim[1] + paddings[4] + paddings[5]); // depth + out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0)) + ? x_dim[2] + : (x_dim[2] + paddings[2] + paddings[3]); // height + out_dims[3] = ((!config.is_runtime) && (x_dim[3] < 0)) + ? x_dim[3] + : (x_dim[3] + paddings[0] + paddings[1]); // width + } + } + + out->set_dims(phi::make_ddim(out_dims)); + out->set_dtype(x.dtype()); + out->share_lod(x); +} + void PixelShuffleInferMeta(const MetaTensor& x, int upscale_factor, const std::string& data_format, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1b4ff7c69a..d84283a65c 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -147,6 +147,14 @@ void PadInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void Pad3dInferMeta(const MetaTensor& x, + const ScalarArray& paddings, + const std::string& mode, + float value, + const std::string& data_format, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void PixelShuffleInferMeta(const MetaTensor& x, int upscale_factor, const std::string& data_format, diff --git a/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc b/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc new file mode 100644 index 0000000000..b1adb3e206 --- /dev/null +++ b/paddle/phi/kernels/cpu/pad3d_grad_kernel.cc @@ -0,0 +1,480 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pad3d_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void ConstPad3DGradNCDHW(T* d_in_data, + const T* d_out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width)) { + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] = + d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; + } +} + +template +void ConstPad3DGradNDHWC(T* d_in_data, + const T* d_out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width)) { + 
const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] = d_out_data[out_index + c]; + } + } +} + +template +void ReflectPad3DGradNCDHW(T* d_in_data, + const T* d_out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); // reflect by 0 + in_d = std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = std::max(in_h, -in_h); // reflect by 0 + in_h = std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height + in_w = std::max(in_w, -in_w); // reflect by 0 + in_w = std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += + d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; +} + +template +void ReflectPad3DGradNDHWC(T* d_in_data, + const T* d_out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); + in_d = std::min(in_d, 2 * in_depth - in_d - 2); + in_h = std::max(in_h, -in_h); + in_h = std::min(in_h, 2 * in_height - in_h - 2); + in_w = std::max(in_w, -in_w); + in_w = std::min(in_w, 2 * in_width - in_w - 2); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + const int in_index = + (in_d * in_height * in_width 
+ in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } +} + +template +void ReplicatePad3DGradNCDHW(T* d_in_data, + const T* d_out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += + d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; +} + +template +void ReplicatePad3DGradNDHWC(T* d_in_data, + const T* d_out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } +} + +template +void CircularPad3DGradNCDHW(T* d_in_data, + const T* d_out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + 
const int out_h, + const int out_w) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += + d_out_data[out_d * out_height * out_width + out_h * out_width + out_w]; +} + +template +void CircularPad3DGradNDHWC(T* d_in_data, + const T* d_out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } +} + +template +void Pad3DGradNCDHW(T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data, + void (*pad_func)(T*, + const T*, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int)) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + 
pad_func(d_in_data, + d_out_data, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_d, + out_h, + out_w); + } + } + } + d_in_data += in_depth * in_height * in_width; + d_out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DGradNDHWC(T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data, + void (*pad_func)(T*, + const T*, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int)) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + pad_func(d_in_data, + d_out_data, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_d, + out_h, + out_w); + } + } + } + d_in_data += in_depth * in_height * in_width * channels; + d_out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3dGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* x_grad) { + std::vector pads = paddings.GetData(); + + auto* d_out = &out_grad; + auto* d_in = x_grad; + auto d_in_dims = d_in->dims(); + auto d_out_dims = d_out->dims(); + const T* d_out_data = d_out->data(); + T* d_in_data = dev_ctx.template Alloc(d_in); + phi::funcs::SetConstant()(dev_ctx, d_in, static_cast(0)); + + const int pad_left = pads[0]; + const int pad_top = pads[2]; + const int pad_front = pads[4]; + const 
int num = d_in_dims[0]; + if (data_format == "NCDHW") { + const int channels = d_in_dims[1]; + const int in_depth = d_in_dims[2]; + const int in_height = d_in_dims[3]; + const int in_width = d_in_dims[4]; + const int out_depth = d_out_dims[2]; + const int out_height = d_out_dims[3]; + const int out_width = d_out_dims[4]; + + std::map + func_map; + + func_map["reflect"] = ReflectPad3DGradNCDHW; + func_map["replicate"] = ReplicatePad3DGradNCDHW; + func_map["circular"] = CircularPad3DGradNCDHW; + func_map["constant"] = ConstPad3DGradNCDHW; + + Pad3DGradNCDHW(d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data, + func_map[mode]); + } else { + const int channels = d_in_dims[4]; + const int in_depth = d_in_dims[1]; + const int in_height = d_in_dims[2]; + const int in_width = d_in_dims[3]; + const int out_depth = d_out_dims[1]; + const int out_height = d_out_dims[2]; + const int out_width = d_out_dims[3]; + + std::map + func_map; + + func_map["reflect"] = ReflectPad3DGradNDHWC; + func_map["replicate"] = ReplicatePad3DGradNDHWC; + func_map["circular"] = CircularPad3DGradNDHWC; + func_map["constant"] = ConstPad3DGradNDHWC; + + Pad3DGradNDHWC(d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data, + func_map[mode]); + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + pad3d_grad, CPU, ALL_LAYOUT, phi::Pad3dGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/pad3d_kernel.cc b/paddle/phi/kernels/cpu/pad3d_kernel.cc new file mode 100644 index 0000000000..0dc01f485f --- /dev/null +++ b/paddle/phi/kernels/cpu/pad3d_kernel.cc @@ -0,0 +1,578 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pad3d_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ConstPad3DFuncNCDHW(const T* in_data, + T* out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + out_data[out_d * out_height * out_width + out_h * out_width + out_w] = + (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) + ? 
value + : in_data[in_d * in_height * in_width + in_h * in_width + in_w]; +} + +template +void ConstPad3DFuncNDHWC(const T* in_data, + T* out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + if (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) { + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = value; + } + } else { + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } +} + +template +void ReflectPad3DFuncNCDHW(const T* in_data, + T* out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); // reflect by 0 + in_d = std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = std::max(in_h, -in_h); // reflect by 0 + in_h = std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height + in_w = std::max(in_w, -in_w); // reflect by 0 + in_w = std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + + out_data[out_d * out_height * out_width + out_h * out_width + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; +} + +template +void 
ReflectPad3DFuncNDHWC(const T* in_data, + T* out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); + in_d = std::min(in_d, 2 * in_depth - in_d - 2); + in_h = std::max(in_h, -in_h); + in_h = std::min(in_h, 2 * in_height - in_h - 2); + in_w = std::max(in_w, -in_w); + in_w = std::min(in_w, 2 * in_width - in_w - 2); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } +} + +template +void ReplicatePad3DFuncNCDHW(const T* in_data, + T* out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + out_data[out_d * out_height * out_width + out_h * out_width + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; +} + +template +void ReplicatePad3DFuncNDHWC(const T* in_data, + T* out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const 
int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } +} + +template +void CircularPad3DFuncNCDHW(const T* in_data, + T* out_data, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + out_data[out_d * out_height * out_width + out_h * out_width + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; +} + +template +void CircularPad3DFuncNDHWC(const T* in_data, + T* out_data, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const int out_d, + const int out_h, + const int out_w, + const T value) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + const int in_index = + (in_d * in_height * 
in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } +} + +template +void Pad3DNCDHW(const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T value, + T* out_data, + void (*pad_func)(const T*, + T*, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const T)) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + pad_func(in_data, + out_data, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_d, + out_h, + out_w, + value); + } + } + } + in_data += in_depth * in_height * in_width; + out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DNDHWC(const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T value, + T* out_data, + void (*pad_func)(const T*, + T*, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const int, + const T)) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + pad_func(in_data, + out_data, + channels, + in_depth, + in_height, + in_width, 
+ out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_d, + out_h, + out_w, + value); + } + } + } + in_data += in_depth * in_height * in_width * channels; + out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3dKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* out) { + T value = static_cast(pad_value); + std::vector pads = paddings.GetData(); + + auto in_dims = x.dims(); + const T* in_data = x.data(); + + if (data_format == "NCDHW") { + out->Resize({in_dims[0], + in_dims[1], + in_dims[2] + pads[4] + pads[5], + in_dims[3] + pads[2] + pads[3], + in_dims[4] + pads[0] + pads[1]}); + } else { + out->Resize({in_dims[0], + in_dims[1] + pads[4] + pads[5], + in_dims[2] + pads[2] + pads[3], + in_dims[3] + pads[0] + pads[1], + in_dims[4]}); + } + + auto out_dims = out->dims(); + T* out_data = dev_ctx.template Alloc(out); + + int channels = in_dims[1]; + int in_depth = in_dims[2]; + int in_height = in_dims[3]; + int in_width = in_dims[4]; + int out_depth = out_dims[2]; + int out_height = out_dims[3]; + int out_width = out_dims[4]; + if (data_format == "NDHWC") { + channels = in_dims[4]; + in_depth = in_dims[1]; + in_height = in_dims[2]; + in_width = in_dims[3]; + out_depth = out_dims[1]; + out_height = out_dims[2]; + out_width = out_dims[3]; + } + + if (mode == "reflect") { + PADDLE_ENFORCE_GT( + in_depth, + pads[4], + errors::InvalidArgument("The depth of Input(X)'s dimension should be " + "greater than pad_front" + " in reflect mode" + ", but received depth(%d) and pad_front(%d).", + in_depth, + pads[4])); + PADDLE_ENFORCE_GT( + in_depth, + pads[5], + errors::InvalidArgument("The depth of Input(X)'s dimension should be " + "greater than pad_back" + " in reflect mode" + ", but received depth(%d) and pad_back(%d).", + in_depth, + pads[5])); + + PADDLE_ENFORCE_GT( + in_height, + pads[2], 
+ errors::InvalidArgument("The height of Input(X)'s dimension should be " + "greater than pad_top" + " in reflect mode" + ", but received height(%d) and pad_top(%d).", + in_height, + pads[2])); + PADDLE_ENFORCE_GT( + in_height, + pads[3], + errors::InvalidArgument("The height of Input(X)'s dimension should be " + "greater than pad_bottom" + " in reflect mode" + ", but received height(%d) and pad_bottom(%d).", + in_height, + pads[3])); + + PADDLE_ENFORCE_GT( + in_width, + pads[0], + errors::InvalidArgument("The width of Input(X)'s dimension should be " + "greater than pad_left" + " in reflect mode" + ", but received width(%d) and pad_left(%d).", + in_width, + pads[0])); + PADDLE_ENFORCE_GT( + in_width, + pads[1], + errors::InvalidArgument("The width of Input(X)'s dimension should be " + "greater than pad_right" + " in reflect mode" + ", but received width(%d) and pad_right(%d).", + in_width, + pads[1])); + } else if (mode == "circular" || mode == "replicate") { + PADDLE_ENFORCE_NE(in_depth * in_height * in_width, + 0, + errors::InvalidArgument( + "The input tensor size can not be 0 for circular " + "or replicate padding mode.")); + } + + const int pad_left = pads[0]; + const int pad_top = pads[2]; + const int pad_front = pads[4]; + const int num = in_dims[0]; + if (data_format == "NCDHW") { + std::map + func_map; + + func_map["reflect"] = ReflectPad3DFuncNCDHW; + func_map["replicate"] = ReplicatePad3DFuncNCDHW; + func_map["circular"] = CircularPad3DFuncNCDHW; + func_map["constant"] = ConstPad3DFuncNCDHW; + Pad3DNCDHW(in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + value, + out_data, + func_map[mode]); + } else { + std::map + func_map; + + func_map["reflect"] = ReflectPad3DFuncNDHWC; + func_map["replicate"] = ReplicatePad3DFuncNDHWC; + func_map["circular"] = CircularPad3DFuncNDHWC; + func_map["constant"] = ConstPad3DFuncNDHWC; + Pad3DNDHWC(in_data, + num, + channels, + 
in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + value, + out_data, + func_map[mode]); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + pad3d, CPU, ALL_LAYOUT, phi::Pad3dKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu new file mode 100644 index 0000000000..5ca8f3d73d --- /dev/null +++ b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu @@ -0,0 +1,507 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/pad3d_grad_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void Pad3DGradConstNCDHW(const int in_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(in_index, in_size) { + const int in_w = in_index % in_width; + + int nc = in_index / in_width; + const int in_h = nc % in_height; + + nc /= in_height; + const int in_d = nc % in_depth; + + nc /= in_depth; + + const int out_d = in_d + pad_front; + const int out_h = in_h + pad_top; + const int out_w = in_w + pad_left; + d_in_data[in_index] = + d_out_data[nc * out_depth * out_height * out_width + + out_d * out_height * out_width + out_h * out_width + out_w]; + } +} + +template +__global__ void Pad3DGradConstNDHWC(const int in_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(in_index, in_size) { + const int c = in_index % channels; + int n = in_index / channels; + + const int in_w = n % in_width; + n /= in_width; + + const int in_h = n % in_height; + n /= in_height; + + const int in_d = n % in_depth; + n /= in_depth; + + const int out_d = in_d + pad_front; + const int out_h = in_h + pad_top; + const int out_w = in_w + pad_left; + + d_in_data[in_index] = + 
d_out_data[n * out_depth * out_height * out_width * channels + + out_d * out_height * out_width * channels + + out_h * out_width * channels + out_w * channels + c]; + } +} + +template +__global__ void Pad3DGradReflectNCDHW(const int out_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); + in_h = max(in_h, -in_h); + in_w = max(in_w, -in_w); + + in_d = min(in_d, 2 * in_depth - in_d - 2); + in_h = min(in_h, 2 * in_height - in_h - 2); + in_w = min(in_w, 2 * in_width - in_w - 2); + + paddle::platform::CudaAtomicAdd( + &d_in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradReflectNDHWC(const int out_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, 
-in_d); + in_h = max(in_h, -in_h); + in_w = max(in_w, -in_w); + + in_d = min(in_d, in_depth * 2 - in_d - 2); + in_h = min(in_h, in_height * 2 - in_h - 2); + in_w = min(in_w, in_width * 2 - in_w - 2); + paddle::platform::CudaAtomicAdd( + &d_in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradReplicateNCDHW(const int out_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + const int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + paddle::platform::CudaAtomicAdd( + &d_in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradReplicateNDHWC(const int out_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + 
n /= out_depth; + + const int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + paddle::platform::CudaAtomicAdd( + &d_in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradCircularNCDHW(const int out_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + paddle::platform::CudaAtomicAdd( + &d_in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradCircularNDHWC(const int out_size, + T* d_in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int 
out_d = n % out_depth; + n /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + paddle::platform::CudaAtomicAdd( + &d_in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c], + d_out_data[out_index]); + } +} + +template +void Pad3dGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* x_grad) { + std::vector pads = paddings.GetData(); + auto* d_out = &out_grad; + auto* d_in = x_grad; + auto d_in_dims = d_in->dims(); + auto d_out_dims = d_out->dims(); + const T* d_out_data = d_out->data(); + T* d_in_data = dev_ctx.template Alloc(d_in); + + phi::funcs::SetConstant()(dev_ctx, d_in, static_cast(0)); + + const int pad_left = pads[0]; + const int pad_top = pads[2]; + const int pad_front = pads[4]; + + const int num = d_in_dims[0]; + + auto stream = dev_ctx.stream(); + int block = PADDLE_CUDA_NUM_THREADS; + const int out_size = d_out->numel(); + const int in_size = d_in->numel(); + int grid = (out_size + block - 1) / block; + + if (data_format == "NCDHW") { + const int channels = d_in_dims[1]; + const int in_depth = d_in_dims[2]; + const int in_height = d_in_dims[3]; + const int in_width = d_in_dims[4]; + const int out_depth = d_out_dims[2]; + const int out_height = d_out_dims[3]; + const int out_width = d_out_dims[4]; + + if (mode == "reflect") { + Pad3DGradReflectNCDHW<<>>(out_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } else if (mode == "replicate") { + Pad3DGradReplicateNCDHW<<>>(out_size, + d_in_data, + num, + channels, + 
in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } else if (mode == "circular") { + Pad3DGradCircularNCDHW<<>>(out_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } else { + grid = (in_size + block - 1) / block; + Pad3DGradConstNCDHW<<>>(in_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } + } else { + const int channels = d_in_dims[4]; + const int in_depth = d_in_dims[1]; + const int in_height = d_in_dims[2]; + const int in_width = d_in_dims[3]; + const int out_depth = d_out_dims[1]; + const int out_height = d_out_dims[2]; + const int out_width = d_out_dims[3]; + if (mode == "reflect") { + Pad3DGradReflectNDHWC<<>>(out_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } else if (mode == "replicate") { + Pad3DGradReplicateNDHWC<<>>(out_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } else if (mode == "circular") { + Pad3DGradCircularNDHWC<<>>(out_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } else { + grid = (in_size + block - 1) / block; + Pad3DGradConstNDHWC<<>>(in_size, + d_in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + d_out_data); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + pad3d_grad, GPU, ALL_LAYOUT, phi::Pad3dGradKernel, float, double) {} diff --git 
a/paddle/phi/kernels/gpu/pad3d_kernel.cu b/paddle/phi/kernels/gpu/pad3d_kernel.cu new file mode 100644 index 0000000000..2cef77cc0e --- /dev/null +++ b/paddle/phi/kernels/gpu/pad3d_kernel.cu @@ -0,0 +1,588 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pad3d_kernel.h" + +#include + +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void Pad3DConstNCDHW(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T value, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + out_data[index] = + (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) + ? 
value + : in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w]; + } +} + +template +__global__ void Pad3DConstNDHWC(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T value, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + const int in_d = out_d - pad_front; + const int in_h = out_h - pad_top; + const int in_w = out_w - pad_left; + + out_data[index] = + (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) + ? value + : in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DReflectNCDHW(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); // reflect by 0 + in_d = min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = max(in_h, -in_h); // reflect by 0 + in_h = min(in_h, 2 * in_height - 
in_h - 2); // reflect by in_height + in_w = max(in_w, -in_w); // reflect by 0 + in_w = min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + out_data[index] = + in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * + in_width + + in_w]; + } +} + +template +__global__ void Pad3DReflectNDHWC(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); + in_d = min(in_d, 2 * in_depth - in_d - 2); + in_h = max(in_h, -in_h); + in_h = min(in_h, 2 * in_height - in_h - 2); + in_w = max(in_w, -in_w); + in_w = min(in_w, 2 * in_width - in_w - 2); + + out_data[index] = in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DReplicateNCDHW(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + int 
in_h = min(in_height - 1, max(out_h - pad_top, 0)); + int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + out_data[index] = + in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * + in_width + + in_w]; + } +} + +template +__global__ void Pad3DReplicateNDHWC(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + out_data[index] = in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DCircularNCDHW(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + out_data[index] = + 
in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * + in_width + + in_w]; + } +} + +template +__global__ void Pad3DCircularNDHWC(const int nthreads, + const T* in_data, + const int num, + const int channels, + const int in_depth, + const int in_height, + const int in_width, + const int out_depth, + const int out_height, + const int out_width, + const int pad_front, + const int pad_top, + const int pad_left, + T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + out_data[index] = in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +void Pad3dKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* out) { + std::vector pads = paddings.GetData(); + + auto in_dims = x.dims(); + const T* in_data = x.data(); + auto out_dims = out->dims(); + T value = static_cast(pad_value); + + if (data_format == "NCDHW") { + out_dims[0] = in_dims[0]; + out_dims[1] = in_dims[1]; + out_dims[2] = in_dims[2] + pads[4] + pads[5]; + out_dims[3] = in_dims[3] + pads[2] + pads[3]; + out_dims[4] = in_dims[4] + pads[0] + pads[1]; + } else { + out_dims[0] = in_dims[0]; + out_dims[1] = in_dims[1] + pads[4] + pads[5]; + out_dims[2] = in_dims[2] + pads[2] + pads[3]; + out_dims[3] = in_dims[3] + pads[0] + pads[1]; + out_dims[4] = in_dims[4]; + } + out->Resize(out_dims); + T* out_data = dev_ctx.template Alloc(out); + + int 
channels = in_dims[1]; + int in_depth = in_dims[2]; + int in_height = in_dims[3]; + int in_width = in_dims[4]; + int out_depth = out_dims[2]; + int out_height = out_dims[3]; + int out_width = out_dims[4]; + if (data_format == "NDHWC") { + channels = in_dims[4]; + in_depth = in_dims[1]; + in_height = in_dims[2]; + in_width = in_dims[3]; + out_depth = out_dims[1]; + out_height = out_dims[2]; + out_width = out_dims[3]; + } + + if (mode == "reflect") { + PADDLE_ENFORCE_GT( + in_depth, + pads[4], + errors::InvalidArgument("The depth of Input(X)'s dimension should be " + "greater than pad_front" + " in reflect mode" + ", but received depth(%d) and pad_front(%d).", + in_depth, + pads[4])); + PADDLE_ENFORCE_GT( + in_depth, + pads[5], + errors::InvalidArgument("The depth of Input(X)'s dimension should be " + "greater than pad_back" + " in reflect mode" + ", but received depth(%d) and pad_back(%d).", + in_depth, + pads[5])); + + PADDLE_ENFORCE_GT( + in_height, + pads[2], + errors::InvalidArgument("The height of Input(X)'s dimension should be " + "greater than pad_top" + " in reflect mode" + ", but received depth(%d) and pad_top(%d).", + in_height, + pads[2])); + PADDLE_ENFORCE_GT( + in_height, + pads[3], + errors::InvalidArgument("The height of Input(X)'s dimension should be " + "greater than pad_bottom" + " in reflect mode" + ", but received depth(%d) and pad_bottom(%d).", + in_height, + pads[3])); + + PADDLE_ENFORCE_GT( + in_width, + pads[0], + errors::InvalidArgument("The width of Input(X)'s dimension should be " + "greater than pad_left" + " in reflect mode" + ", but received depth(%d) and pad_left(%d).", + in_width, + pads[0])); + PADDLE_ENFORCE_GT( + in_width, + pads[1], + errors::InvalidArgument("The width of Input(X)'s dimension should be " + "greater than pad_right" + " in reflect mode" + ", but received depth(%d) and pad_right(%d).", + in_width, + pads[1])); + } else if (mode == "circular" || mode == "replicate") { + PADDLE_ENFORCE_NE(in_depth * in_height * 
in_width, + 0, + errors::InvalidArgument( + "The input tensor size can not be 0 for circular " + "or replicate padding mode.")); + } + + const int pad_left = pads[0]; + const int pad_top = pads[2]; + const int pad_front = pads[4]; + const int num = in_dims[0]; + + auto stream = dev_ctx.stream(); + int block = PADDLE_CUDA_NUM_THREADS; + const int out_size = out->numel(); + int grid = (out_size + block - 1) / block; + + if (data_format == "NCDHW") { + if (mode == "reflect") { + Pad3DReflectNCDHW<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_data); + } else if (mode == "replicate") { + Pad3DReplicateNCDHW<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_data); + } else if (mode == "circular") { + Pad3DCircularNCDHW<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_data); + } else { + Pad3DConstNCDHW<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + value, + out_data); + } + } else { + if (mode == "reflect") { + Pad3DReflectNDHWC<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_data); + } else if (mode == "replicate") { + Pad3DReplicateNDHWC<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + out_data); + } else if (mode == "circular") { + Pad3DCircularNDHWC<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + 
pad_left, + out_data); + } else { + Pad3DConstNDHWC<<>>(out_size, + in_data, + num, + channels, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + pad_front, + pad_top, + pad_left, + value, + out_data); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(pad3d, + GPU, + ALL_LAYOUT, + phi::Pad3dKernel, + phi::dtype::float16, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/pad3d_grad_kernel.h b/paddle/phi/kernels/pad3d_grad_kernel.h new file mode 100644 index 0000000000..38f1e5335e --- /dev/null +++ b/paddle/phi/kernels/pad3d_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void Pad3dGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/pad3d_kernel.h b/paddle/phi/kernels/pad3d_kernel.h new file mode 100644 index 0000000000..d8876c3e7b --- /dev/null +++ b/paddle/phi/kernels/pad3d_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void Pad3dKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& paddings, + const std::string& mode, + float pad_value, + const std::string& data_format, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/pad3d_sig.cc b/paddle/phi/ops/compat/pad3d_sig.cc new file mode 100644 index 0000000000..c43b98fa27 --- /dev/null +++ b/paddle/phi/ops/compat/pad3d_sig.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature Pad3dOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Paddings")) { + return KernelSignature( + "pad3d", {"X"}, {"Paddings", "mode", "value", "data_format"}, {"Out"}); + } + + return KernelSignature( + "pad3d", {"X"}, {"paddings", "mode", "value", "data_format"}, {"Out"}); +} + +KernelSignature Pad3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Paddings")) { + return KernelSignature("pad3d_grad", + {"X", GradVarName("Out")}, + {"Paddings", "mode", "value", "data_format"}, + {GradVarName("X")}); + } + return KernelSignature("pad3d_grad", + {"X", GradVarName("Out")}, + {"paddings", "mode", "value", "data_format"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(pad3d, phi::Pad3dOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(pad3d_grad, phi::Pad3dGradOpArgumentMapping); -- GitLab