Unverified commit 9f74b84e authored by xiongkun, committed by GitHub

[phi] transfer pad kernel into phi and pass the test_pad_op (#40012)

* add pad forward

* fix error

* transfer pad and pass the test_pad_op
Parent b565b349
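Every fluid call site in this diff changes the same way: the padding helpers move from paddle::operators::math, which was driven by the framework::ExecutionContext, to phi::funcs, which takes the device context directly. A minimal before/after sketch of the pattern, using identifiers from the hunks below:

// before: the functor pulled the device context out of the ExecutionContext
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
    ctx, input_pad, transformed_input_channel, pad_value, &transformed_input);

// after: the caller forwards the device context it already holds
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
    dev_ctx, input_pad, transformed_input_channel, pad_value, &transformed_input);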
@@ -25,10 +25,10 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#endif
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/kernels/funcs/padding.h"
DECLARE_bool(cudnn_deterministic);
DECLARE_uint64(conv_workspace_size_limit);
@@ -148,7 +148,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
@@ -196,13 +196,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
@@ -488,7 +488,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// cuDNN only supports padding the same amount on every dimension.
// So we create a new padded input tensor.
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input(input->type());
Tensor transformed_input_grad(input->type());
std::vector<int> padding_common(data_dim, 0);
@@ -544,13 +544,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
@@ -956,7 +956,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_X(X->type());
Tensor transformed_ddX(X->type());
@@ -1004,20 +1004,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_X_channel, pad_value,
&transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_ddX_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_X_channel, pad_value,
&transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_ddX_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
......
@@ -21,8 +21,8 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#endif
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/padding.h"
namespace paddle {
namespace operators {
@@ -108,7 +108,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
Tensor transformed_input;
@@ -139,12 +139,14 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, input_transpose, pad_value, &transformed_input);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, input_transpose, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, input_transpose, pad_value, &transformed_input);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, input_transpose, pad_value,
&transformed_input);
} break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
@@ -375,7 +377,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
Tensor transformed_output_grad;
@@ -407,13 +409,13 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, output_grad_transpose, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, output_grad_transpose, pad_value,
&transformed_output_grad);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, output_grad_transpose, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, output_grad_transpose, pad_value,
&transformed_output_grad);
} break;
default:
@@ -735,7 +737,7 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_X(X->type());
Tensor transformed_ddX(X->type());
@@ -794,26 +796,28 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_X_channel, pad_value,
&transformed_X);
if (dO) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_dO_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_dO_channel, pad_value,
&transformed_dO);
}
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_ddX_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_X_channel, pad_value,
&transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_ddX_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
......
@@ -17,8 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/kernels/funcs/padding.h"
DECLARE_int64(cudnn_exhaustive_search_times);
@@ -86,7 +86,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
@@ -118,13 +118,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
......
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/phi/kernels/funcs/padding.h"
namespace paddle {
namespace operators {
@@ -50,8 +50,9 @@ class PadConstantLikeKernel : public framework::OpKernel<T> {
pads[j * 2 + 1] = static_cast<int>(in_x->dims()[j] - in_y->dims()[j]);
}
math::PaddingFunctor<DeviceContext, T>(rank, context, pads, pad_value,
*in_y, out);
phi::funcs::PaddingFunctor<DeviceContext, T>(
rank, context.template device_context<DeviceContext>(), pads, pad_value,
*in_y, out);
}
};
@@ -82,8 +83,9 @@ class PadConstantLikeGradKernel : public framework::OpKernel<T> {
pads[j * 2 + 1] = static_cast<int>(in_dout->dims()[j] - in_y->dims()[j]);
}
math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *in_dout,
d_y);
phi::funcs::PaddingGradFunctor<DeviceContext, T>(
rank, context.template device_context<DeviceContext>(), pads, *in_dout,
d_y);
}
};
......
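The change above shows how a fluid kernel, which only holds a framework::ExecutionContext, now feeds the phi functor: resolve the device context once, then forward it. A hedged sketch of that adapter step (hypothetical wrapper, not part of this patch):

template <typename DeviceContext, typename T>
void RunPaddingFromFluid(const framework::ExecutionContext& context,
                         const std::vector<int>& pads, T pad_value,
                         const framework::Tensor& src, framework::Tensor* out) {
  // phi::funcs::PaddingFunctor no longer digs the device context out of
  // the ExecutionContext itself, so the caller fetches it up front.
  auto& dev_ctx = context.template device_context<DeviceContext>();
  phi::funcs::PaddingFunctor<DeviceContext, T>(src.dims().size(), dev_ctx,
                                               pads, pad_value, src, out);
}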
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pad_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
@@ -167,40 +167,3 @@ REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker,
REGISTER_OPERATOR(pad_grad, ops::PadOpGrad,
ops::PadOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::PadOpDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::PadKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::PadKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
ops::PadKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/padding.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class PadKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto pads = context.Attr<std::vector<int>>("paddings");
float pad_value = context.Attr<float>("pad_value");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
int rank = x->dims().size();
math::PaddingFunctor<DeviceContext, T>(rank, context, pads,
static_cast<T>(pad_value), *x, out);
}
};
template <typename DeviceContext, typename T>
class PadGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto pads = context.Attr<std::vector<int>>("paddings");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x == nullptr) {
return;
}
d_x->mutable_data<T>(context.GetPlace());
int rank = d_out->dims().size();
math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *d_out,
d_x);
}
};
} // namespace operators
} // namespace paddle
@@ -23,9 +23,9 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/padding.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "thrust/device_vector.h"
#endif
@@ -389,8 +389,9 @@ class FFTR2CGradKernel : public framework::OpKernel<T> {
std::vector<int> pads(rank * 2, 0);
pads[axes.back() * 2 + 1] = zero_length;
paddle::operators::math::PaddingFunctor<DeviceContext, C>(
rank, ctx, pads, static_cast<C>(0), *dy, &full_dy);
phi::funcs::PaddingFunctor<DeviceContext, C>(
rank, ctx.template device_context<DeviceContext>(), pads,
static_cast<C>(0), *dy, &full_dy);
fft_c2c_func(dev_ctx, &full_dy, &complex_dx, axes, normalization,
!forward);
}
......
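The pads vector built here follows the layout the phi helpers expect: 2 * rank entries, one (front, back) pair per dimension, so pads[axes.back() * 2 + 1] grows only the tail of the last transformed axis. A tiny worked illustration (hypothetical sizes):

// rank == 2, axes.back() == 1, zero_length == 3:
// pads = {0, 0, 0, 3}  ->  dim 0 unchanged, dim 1 gets 3 zeros appended,
// so a [4, 5] dy becomes a [4, 8] full_dy before the c2c transform.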
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pad_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h"
PD_REGISTER_KERNEL(pad_grad,
CPU,
ALL_LAYOUT,
phi::PadGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_kernel_impl.h"
PD_REGISTER_KERNEL(pad,
CPU,
ALL_LAYOUT,
phi::PadKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
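Beyond dispatch through the registry, the phi kernel is a plain function template and can be called directly; a minimal CPU sketch (assumes out has already been Resize()d to the padded shape, as the op's InferMeta would do, and that the registered instantiation above is linked in):

#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/pad_kernel.h"

void PadByOne(const phi::CPUContext& dev_ctx, const phi::DenseTensor& x,
              phi::DenseTensor* out) {
  // one (front, back) pair per dimension of a rank-2 input:
  // a single zero on every side
  std::vector<int> paddings = {1, 1, 1, 1};
  phi::PadKernel<float, phi::CPUContext>(dev_ctx, x, paddings,
                                         /*pad_value=*/0.0f, out);
}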
@@ -15,21 +15,26 @@ limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using EigenTensor = phi::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T, size_t D>
void PadFunction(const framework::ExecutionContext& context,
const std::vector<int>& pads, const framework::Tensor& src,
T pad_value, framework::Tensor* out) {
void PadFunction(const DeviceContext& context,
const std::vector<int>& pads,
const DenseTensor& src,
T pad_value,
DenseTensor* out) {
std::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
@@ -40,16 +45,16 @@ void PadFunction(const framework::ExecutionContext& context,
auto src_tensor = EigenTensor<T, D>::From(src);
auto out_tensor = EigenTensor<T, D>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& place = *(context.eigen_device());
EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place, out_tensor, src_tensor, paddings, pad_value);
}
template <typename DeviceContext, typename T, size_t D>
void PadGradFunction(const framework::ExecutionContext& context,
const std::vector<int>& pads, const framework::Tensor& src,
framework::Tensor* d_out) {
void PadGradFunction(const DeviceContext& context,
const std::vector<int>& pads,
const DenseTensor& src,
DenseTensor* d_out) {
std::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = -pads[i * 2];
@@ -58,16 +63,18 @@ void PadGradFunction(const framework::ExecutionContext& context,
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
auto src_tensor = EigenTensor<T, D>::From(src);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& place = *(context.eigen_device());
EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place, d_out_tensor, src_tensor, paddings, static_cast<T>(0));
}
template <typename DeviceContext, typename T>
void PaddingFunctor(int rank, const framework::ExecutionContext& context,
const std::vector<int>& pads, T pad_value,
const framework::Tensor& src, framework::Tensor* out) {
void PaddingFunctor(int rank,
const DeviceContext& context,
const std::vector<int>& pads,
T pad_value,
const DenseTensor& src,
DenseTensor* out) {
switch (rank) {
case 1:
PadFunction<DeviceContext, T, 1>(context, pads, src, pad_value, out);
@@ -88,16 +95,18 @@ void PaddingFunctor(int rank, const framework::ExecutionContext& context,
PadFunction<DeviceContext, T, 6>(context, pads, src, pad_value, out);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"PadOp only support tensors with no more"
" than 6 dimensions currently."));
PADDLE_THROW(
phi::errors::Unimplemented("PadOp only support tensors with no more"
" than 6 dimensions currently."));
}
}
template <typename DeviceContext, typename T>
void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
void PaddingGradFunctor(int rank,
const DeviceContext& context,
const std::vector<int>& pads,
const framework::Tensor& src, framework::Tensor* out) {
const DenseTensor& src,
DenseTensor* out) {
switch (rank) {
case 1:
PadGradFunction<DeviceContext, T, 1>(context, pads, src, out);
@@ -118,9 +127,9 @@ void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
PadGradFunction<DeviceContext, T, 6>(context, pads, src, out);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"PadOp only support tensors with no more"
" than 6 dimensions currently."));
PADDLE_THROW(
phi::errors::Unimplemented("PadOp only support tensors with no more"
" than 6 dimensions currently."));
}
}
@@ -137,6 +146,5 @@ inline bool IsSymmetricPadding(const std::vector<int>& pads,
}
return is_sys_pad;
}
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
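IsSymmetricPadding, whose tail is shown above, reports whether every dimension pads equally front and back; the cuDNN call sites earlier in this diff use it to decide whether cuDNN can absorb the padding or the input must first be materialized into a padded tensor. A small illustration (values hypothetical):

// data_dim == 2, pads laid out as (front, back) per spatial dimension:
// {1, 1, 2, 2} -> symmetric: is_sys_pad == true, no explicit pad pass needed
// {1, 0, 2, 2} -> asymmetric: is_sys_pad == false, PadFunction first builds
//                 the transformed (padded) input for the cuDNN call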
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pad_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h"
PD_REGISTER_KERNEL(pad_grad,
GPU,
ALL_LAYOUT,
phi::PadGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pad_kernel_impl.h"
#include "paddle/phi/kernels/pad_kernel.h"
PD_REGISTER_KERNEL(pad,
GPU,
ALL_LAYOUT,
phi::PadKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/padding.h"
namespace phi {
template <typename T, typename Context>
void PadGradKernel(const Context& dev_ctx,
const DenseTensor& d_out,
const std::vector<int>& paddings,
float pad_value,
DenseTensor* d_x) {
if (d_x == nullptr) {
return;
}
dev_ctx.template Alloc<T>(d_x);
int rank = d_out.dims().size();
phi::funcs::PaddingGradFunctor<Context, T>(
rank, dev_ctx, paddings, d_out, d_x);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/padding.h"
namespace phi {
template <typename T, typename Context>
void PadKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& paddings,
float pad_value,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
int rank = x.dims().size();
funcs::PaddingFunctor<Context, T>(
rank, dev_ctx, paddings, static_cast<T>(pad_value), x, out);
}
} // namespace phi
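PadKernel only allocates and fills; the padded shape is fixed beforehand by the op's shape inference. The relation it relies on is each dimension plus its pad pair; a hypothetical helper (not part of this patch) spelling that out:

inline phi::DDim PaddedDims(const phi::DDim& in, const std::vector<int>& pads) {
  std::vector<int64_t> out_dims(in.size());
  for (int i = 0; i < in.size(); ++i) {
    // each dimension grows by its front and back padding
    out_dims[i] = in[i] + pads[2 * i] + pads[2 * i + 1];
  }
  return phi::make_ddim(out_dims);
}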
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PadGradKernel(const Context& dev_ctx,
const DenseTensor& d_out,
const std::vector<int>& paddings,
float pad_value,
DenseTensor* d_x);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PadKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& paddings,
float pad_value,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PadGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("pad_grad",
{GradVarName("Out")},
{"paddings", "pad_value"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(pad_grad, phi::PadGradOpArgumentMapping);
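Only the gradient variables need renaming here, which is why just pad_grad gets a mapping. For comparison, an explicit forward mapping in the same style would be the identity (hypothetical sketch; whether the forward pad op needs one depends on phi's default mapping rules):

KernelSignature PadOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("pad", {"X"}, {"paddings", "pad_value"}, {"Out"});
}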