Unverified commit 9f74b84e, authored by xiongkun, committed by GitHub

[phi] transfer pad kernel into phi and pass the test_pad_op (#40012)

* add pad forward

* fix error

* transfer pad and pass the test_pad_op
Parent b565b349
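The change repeated across the cuDNN conv kernels below is mechanical: the pad helpers moved from paddle::operators::math (fluid) to phi::funcs, and they now take the typed device context directly instead of the framework::ExecutionContext. A minimal before/after sketch of the call-site rewrite, using names from the hunks below:

    // Before: the fluid helper received the whole ExecutionContext `ctx`.
    math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
        ctx, input_pad, transformed_input_channel, pad_value,
        &transformed_input);

    // After: the phi helper receives the device context `dev_ctx` directly.
    phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
        dev_ctx, input_pad, transformed_input_channel, pad_value,
        &transformed_input);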
paddle/fluid/operators/conv_cudnn_op.cu

@@ -25,10 +25,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_cudnn_helper.h"
 #endif
 #include "paddle/fluid/operators/conv_op.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 DECLARE_bool(cudnn_deterministic);
 DECLARE_uint64(conv_workspace_size_limit);
@@ -148,7 +148,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
                               in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_input;
     std::vector<int> padding_common(data_dim, 0);
@@ -196,13 +196,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
         } break;
         default:
@@ -488,7 +488,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     // cuDNN only supports padding the same amount on every dimension.
     // So we create a new padded input tensor.
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
     Tensor transformed_input(input->type());
     Tensor transformed_input_grad(input->type());
     std::vector<int> padding_common(data_dim, 0);
@@ -544,13 +544,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
         } break;
         default:
@@ -956,7 +956,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
                               in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_X(X->type());
     Tensor transformed_ddX(X->type());
@@ -1004,20 +1004,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
...
paddle/fluid/operators/conv_transpose_cudnn_op.cu

@@ -21,8 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_cudnn_helper.h"
 #endif
 #include "paddle/fluid/operators/conv_transpose_op.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 namespace paddle {
 namespace operators {
@@ -108,7 +108,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
                               in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
     Tensor transformed_input;
@@ -139,12 +139,14 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, input_transpose, pad_value, &transformed_input);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, input_transpose, pad_value,
+              &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, input_transpose, pad_value, &transformed_input);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, input_transpose, pad_value,
+              &transformed_input);
         } break;
         default:
           PADDLE_THROW(platform::errors::InvalidArgument(
@@ -375,7 +377,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
                               in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
     Tensor transformed_output_grad;
@@ -407,13 +409,13 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, output_grad_transpose, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, output_grad_transpose, pad_value,
               &transformed_output_grad);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, output_grad_transpose, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, output_grad_transpose, pad_value,
               &transformed_output_grad);
         } break;
         default:
@@ -735,7 +737,7 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
                               in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
     Tensor transformed_X(X->type());
     Tensor transformed_ddX(X->type());
@@ -794,26 +796,28 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (dO) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-                ctx, input_pad, transformed_dO_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                dev_ctx, input_pad, transformed_dO_channel, pad_value,
                 &transformed_dO);
           }
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
...
paddle/fluid/operators/fused/conv_fusion_op.cu

@@ -17,8 +17,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/conv_op.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 DECLARE_int64(cudnn_exhaustive_search_times);
@@ -86,7 +86,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
                               in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_input;
     std::vector<int> padding_common(data_dim, 0);
@@ -118,13 +118,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
              &transformed_input);
         } break;
         default:
...
paddle/fluid/operators/pad_constant_like_op.h

@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/operators/math/padding.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 namespace paddle {
 namespace operators {
@@ -50,7 +50,8 @@ class PadConstantLikeKernel : public framework::OpKernel<T> {
       pads[j * 2 + 1] = static_cast<int>(in_x->dims()[j] - in_y->dims()[j]);
     }
 
-    math::PaddingFunctor<DeviceContext, T>(rank, context, pads, pad_value,
-                                           *in_y, out);
+    phi::funcs::PaddingFunctor<DeviceContext, T>(
+        rank, context.template device_context<DeviceContext>(), pads, pad_value,
+        *in_y, out);
   }
 };
@@ -82,7 +83,8 @@ class PadConstantLikeGradKernel : public framework::OpKernel<T> {
       pads[j * 2 + 1] = static_cast<int>(in_dout->dims()[j] - in_y->dims()[j]);
     }
 
-    math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *in_dout,
-                                               d_y);
+    phi::funcs::PaddingGradFunctor<DeviceContext, T>(
+        rank, context.template device_context<DeviceContext>(), pads, *in_dout,
+        d_y);
   }
 };
...
paddle/fluid/operators/pad_op.cc

@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/pad_op.h"
 #include <memory>
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
@@ -167,40 +167,3 @@ REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker,
 REGISTER_OPERATOR(pad_grad, ops::PadOpGrad,
                   ops::PadOpDoubleGradMaker<paddle::framework::OpDesc>,
                   ops::PadOpDoubleGradMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(
-    pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex<float>>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::PadGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex<float>>,
-    ops::PadGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(
-    pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::float16>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex<float>>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(
-    pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::float16>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex<float>>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex<double>>);
paddle/fluid/operators/pad_op.h (deleted file; its kernels now live in phi)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <utility>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/padding.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class PadKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto pads = context.Attr<std::vector<int>>("paddings");
    float pad_value = context.Attr<float>("pad_value");
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    int rank = x->dims().size();
    math::PaddingFunctor<DeviceContext, T>(rank, context, pads,
                                           static_cast<T>(pad_value), *x, out);
  }
};

template <typename DeviceContext, typename T>
class PadGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto pads = context.Attr<std::vector<int>>("paddings");
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    if (d_x == nullptr) {
      return;
    }

    d_x->mutable_data<T>(context.GetPlace());
    int rank = d_out->dims().size();
    math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *d_out,
                                               d_x);
  }
};

}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/spectral_op.h

@@ -23,9 +23,9 @@
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/conj_op.h"
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
 #include "thrust/device_vector.h"
 #endif
@@ -389,8 +389,9 @@ class FFTR2CGradKernel : public framework::OpKernel<T> {
       std::vector<int> pads(rank * 2, 0);
       pads[axes.back() * 2 + 1] = zero_length;
 
-      paddle::operators::math::PaddingFunctor<DeviceContext, C>(
-          rank, ctx, pads, static_cast<C>(0), *dy, &full_dy);
+      phi::funcs::PaddingFunctor<DeviceContext, C>(
+          rank, ctx.template device_context<DeviceContext>(), pads,
+          static_cast<C>(0), *dy, &full_dy);
       fft_c2c_func(dev_ctx, &full_dy, &complex_dx, axes, normalization,
                    !forward);
     }
...
paddle/phi/kernels/cpu/pad_grad_kernel.cc (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/pad_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h"

PD_REGISTER_KERNEL(pad_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::PadGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
paddle/phi/kernels/cpu/pad_kernel.cc (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/pad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_kernel_impl.h"

PD_REGISTER_KERNEL(pad,
                   CPU,
                   ALL_LAYOUT,
                   phi::PadKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
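For readers new to the phi registry: PD_REGISTER_KERNEL above registers one kernel per listed dtype under the (CPU, ALL_LAYOUT) key. As an illustration only (not the literal macro expansion), the float entry corresponds roughly to an explicit instantiation like:

    // Rough sketch of one registered instantiation; illustrative, not the
    // actual macro output.
    template void phi::PadKernel<float, phi::CPUContext>(
        const phi::CPUContext& dev_ctx,
        const phi::DenseTensor& x,
        const std::vector<int>& paddings,
        float pad_value,
        phi::DenseTensor* out);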
paddle/phi/kernels/funcs/padding.h (moved from paddle/fluid/operators/math/padding.h)

@@ -15,21 +15,26 @@ limitations under the License. */
 
 #pragma once
 
 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+template <typename T,
+          size_t D,
+          int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+using EigenTensor = EigenTensor<T, D, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T, size_t D>
-void PadFunction(const framework::ExecutionContext& context,
-                 const std::vector<int>& pads, const framework::Tensor& src,
-                 T pad_value, framework::Tensor* out) {
+void PadFunction(const DeviceContext& context,
+                 const std::vector<int>& pads,
+                 const DenseTensor& src,
+                 T pad_value,
+                 DenseTensor* out) {
   std::array<std::pair<int64_t, int64_t>, D> paddings;
 
   for (size_t i = 0; i < paddings.size(); ++i) {
@@ -40,16 +45,16 @@ void PadFunction(const framework::ExecutionContext& context,
   auto src_tensor = EigenTensor<T, D>::From(src);
   auto out_tensor = EigenTensor<T, D>::From(*out);
 
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
+  auto& place = *(context.eigen_device());
   EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
       place, out_tensor, src_tensor, paddings, pad_value);
 }
 
 template <typename DeviceContext, typename T, size_t D>
-void PadGradFunction(const framework::ExecutionContext& context,
-                     const std::vector<int>& pads, const framework::Tensor& src,
-                     framework::Tensor* d_out) {
+void PadGradFunction(const DeviceContext& context,
+                     const std::vector<int>& pads,
+                     const DenseTensor& src,
+                     DenseTensor* d_out) {
   std::array<std::pair<int64_t, int64_t>, D> paddings;
   for (size_t i = 0; i < paddings.size(); ++i) {
     paddings[i].first = -pads[i * 2];
@@ -58,16 +63,18 @@ void PadGradFunction(const framework::ExecutionContext& context,
   auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
   auto src_tensor = EigenTensor<T, D>::From(src);
 
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
+  auto& place = *(context.eigen_device());
   EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
       place, d_out_tensor, src_tensor, paddings, static_cast<T>(0));
 }
 
 template <typename DeviceContext, typename T>
-void PaddingFunctor(int rank, const framework::ExecutionContext& context,
-                    const std::vector<int>& pads, T pad_value,
-                    const framework::Tensor& src, framework::Tensor* out) {
+void PaddingFunctor(int rank,
+                    const DeviceContext& context,
+                    const std::vector<int>& pads,
+                    T pad_value,
+                    const DenseTensor& src,
+                    DenseTensor* out) {
   switch (rank) {
     case 1:
       PadFunction<DeviceContext, T, 1>(context, pads, src, pad_value, out);
@@ -88,16 +95,18 @@ void PaddingFunctor(int rank, const framework::ExecutionContext& context,
       PadFunction<DeviceContext, T, 6>(context, pads, src, pad_value, out);
       break;
     default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "PadOp only support tensors with no more"
-          " than 6 dimensions currently."));
+      PADDLE_THROW(
+          phi::errors::Unimplemented("PadOp only support tensors with no more"
+                                     " than 6 dimensions currently."));
   }
 }
 
 template <typename DeviceContext, typename T>
-void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
+void PaddingGradFunctor(int rank,
+                        const DeviceContext& context,
                         const std::vector<int>& pads,
-                        const framework::Tensor& src, framework::Tensor* out) {
+                        const DenseTensor& src,
+                        DenseTensor* out) {
   switch (rank) {
     case 1:
       PadGradFunction<DeviceContext, T, 1>(context, pads, src, out);
@@ -118,8 +127,8 @@ void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
       PadGradFunction<DeviceContext, T, 6>(context, pads, src, out);
       break;
     default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "PadOp only support tensors with no more"
-          " than 6 dimensions currently."));
+      PADDLE_THROW(
+          phi::errors::Unimplemented("PadOp only support tensors with no more"
+                                     " than 6 dimensions currently."));
   }
 }
@@ -137,6 +146,5 @@ inline bool IsSymmetricPadding(const std::vector<int>& pads,
   }
   return is_sys_pad;
 }
 
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
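The pads layout is unchanged by the move: two entries per dimension, {before_0, after_0, before_1, after_1, ...}, as the paddings[i].first/.second assignments above show. A minimal usage sketch of the relocated functor, assuming a ready phi::CPUContext and an output tensor already sized and allocated for the padded shape:

    #include <vector>

    #include "paddle/phi/backends/cpu/cpu_context.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/funcs/padding.h"

    // Pads a rank-2 float tensor: one row above/below and two columns on each
    // side, so `out` must be pre-sized to (H + 2, W + 4) and allocated.
    void PadExample(const phi::CPUContext& dev_ctx,
                    const phi::DenseTensor& src,
                    phi::DenseTensor* out) {
      const std::vector<int> pads = {1, 1, 2, 2};
      const int rank = src.dims().size();  // dispatches to PadFunction<..., 2>
      phi::funcs::PaddingFunctor<phi::CPUContext, float>(
          rank, dev_ctx, pads, /*pad_value=*/0.0f, src, out);
    }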
paddle/phi/kernels/gpu/pad_grad_kernel.cu (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/pad_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h"

PD_REGISTER_KERNEL(pad_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::PadGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
paddle/phi/kernels/gpu/pad_kernel.cu (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_kernel_impl.h"
#include "paddle/phi/kernels/pad_kernel.h"

PD_REGISTER_KERNEL(pad,
                   GPU,
                   ALL_LAYOUT,
                   phi::PadKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
paddle/phi/kernels/impl/pad_grad_kernel_impl.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/padding.h"

namespace phi {

template <typename T, typename Context>
void PadGradKernel(const Context& dev_ctx,
                   const DenseTensor& d_out,
                   const std::vector<int>& paddings,
                   float pad_value,
                   DenseTensor* d_x) {
  if (d_x == nullptr) {
    return;
  }
  dev_ctx.template Alloc<T>(d_x);
  int rank = d_out.dims().size();
  phi::funcs::PaddingGradFunctor<Context, T>(
      rank, dev_ctx, paddings, d_out, d_x);
}

}  // namespace phi
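Because PadGradFunction negates each pads entry (paddings[i].first = -pads[i * 2] in funcs/padding.h), the backward kernel is effectively a slice that drops the padded border: for a forward pad of {1, 1, 2, 2} on an (H, W) input, d_out has shape (H + 2, W + 4) and d_x recovers (H, W). A hedged sketch of invoking the impl directly on CPU:

    #include <vector>

    #include "paddle/phi/backends/cpu/cpu_context.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h"

    // d_x's dims must already be set to the unpadded shape; the kernel
    // allocates its storage via dev_ctx.template Alloc<T>().
    void PadGradExample(const phi::CPUContext& dev_ctx,
                        const phi::DenseTensor& d_out,
                        phi::DenseTensor* d_x) {
      const std::vector<int> paddings = {1, 1, 2, 2};
      // pad_value does not affect the gradient; zero is passed by convention.
      phi::PadGradKernel<float, phi::CPUContext>(
          dev_ctx, d_out, paddings, /*pad_value=*/0.0f, d_x);
    }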
paddle/phi/kernels/impl/pad_kernel_impl.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <utility>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/padding.h"

namespace phi {

template <typename T, typename Context>
void PadKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<int>& paddings,
               float pad_value,
               DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  int rank = x.dims().size();
  funcs::PaddingFunctor<Context, T>(
      rank, dev_ctx, paddings, static_cast<T>(pad_value), x, out);
}

}  // namespace phi
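The forward impl mirrors the old fluid PadKernel: allocate the output, read the rank, and hand off to funcs::PaddingFunctor, with the pad_value attribute arriving as a float and cast to T. A direct-call sketch under the same assumptions as above (the output's padded dims are already set in its meta):

    #include <vector>

    #include "paddle/phi/backends/cpu/cpu_context.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/impl/pad_kernel_impl.h"

    void RunPad(const phi::CPUContext& dev_ctx,
                const phi::DenseTensor& x,
                phi::DenseTensor* out) {
      // Pad only the last dimension of a rank-2 tensor by one on each side.
      const std::vector<int> paddings = {0, 0, 1, 1};
      phi::PadKernel<float, phi::CPUContext>(
          dev_ctx, x, paddings, /*pad_value=*/-1.0f, out);
    }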
paddle/phi/kernels/pad_grad_kernel.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void PadGradKernel(const Context& dev_ctx,
                   const DenseTensor& d_out,
                   const std::vector<int>& paddings,
                   float pad_value,
                   DenseTensor* d_x);

}  // namespace phi
paddle/phi/kernels/pad_kernel.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void PadKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<int>& paddings,
               float pad_value,
               DenseTensor* out);

}  // namespace phi
paddle/phi/ops/compat/pad_sig.cc (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature PadGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("pad_grad",
                         {GradVarName("Out")},
                         {"paddings", "pad_value"},
                         {GradVarName("X")});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(pad_grad, phi::PadGradOpArgumentMapping);
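This mapping ties the legacy fluid op to the new phi kernel: the grad op's Out@GRAD input feeds the kernel's d_out parameter, the paddings/pad_value attributes pass through, and X@GRAD receives d_x. No mapping for the forward pad op appears in this commit, presumably because its arguments already line up one-to-one; if a forward mapping were spelled out, it would look like this hypothetical sketch:

    // Hypothetical forward mapping, for illustration only; not part of this
    // commit.
    KernelSignature PadOpArgumentMapping(const ArgumentMappingContext& ctx) {
      return KernelSignature("pad", {"X"}, {"paddings", "pad_value"}, {"Out"});
    }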