未验证 提交 8cc09552 编写于 作者: Y YuanRisheng 提交者: GitHub

refactor reshape grad kernel (#38833)

上级 be817719
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/kernels/reshape_grad_kernel.h"
#include "paddle/pten/kernels/reshape_kernel.h" #include "paddle/pten/kernels/reshape_kernel.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -467,13 +468,27 @@ class ReshapeGradKernel { ...@@ -467,13 +468,27 @@ class ReshapeGradKernel {
void operator()(const framework::ExecutionContext &ctx) const { void operator()(const framework::ExecutionContext &ctx) const {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto in_dims = d_x->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type()); d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopy(
*d_out, ctx.GetPlace(), auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
ctx.template device_context<platform::DeviceContext>(), d_x); auto pt_d_out = paddle::experimental::MakePtenDenseTensor(*d_out);
d_x->Resize(in_dims);
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get());
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get());
}
#endif
} }
}; };
...@@ -482,14 +497,27 @@ class ReshapeDoubleGradKernel { ...@@ -482,14 +497,27 @@ class ReshapeDoubleGradKernel {
void operator()(const framework::ExecutionContext &ctx) const { void operator()(const framework::ExecutionContext &ctx) const {
auto *dd_x = ctx.Input<framework::Tensor>("DDX"); auto *dd_x = ctx.Input<framework::Tensor>("DDX");
auto *dd_out = ctx.Output<framework::Tensor>("DDOut"); auto *dd_out = ctx.Output<framework::Tensor>("DDOut");
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
auto out_dims = dd_out->dims(); auto pt_dd_x = paddle::experimental::MakePtenDenseTensor(*dd_x);
auto pt_dd_out = paddle::experimental::MakePtenDenseTensor(*dd_out);
dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); if (platform::is_cpu_place(ctx.GetPlace())) {
framework::TensorCopy( auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
*dd_x, ctx.GetPlace(), pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get());
ctx.template device_context<platform::DeviceContext>(), dd_out); }
dd_out->Resize(out_dims); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get());
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get());
}
#endif
} }
}; };
...@@ -624,6 +652,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -624,6 +652,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("reshape_grad",
{framework::GradVarName("Out")}, {},
{framework::GradVarName("X")});
}
}; };
class Reshape2DoubleGradOp : public framework::OperatorWithKernel { class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
...@@ -660,6 +695,11 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { ...@@ -660,6 +695,11 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("reshape_double_grad", {"DDX"}, {},
{"DDOut"});
}
}; };
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
......
...@@ -35,6 +35,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = { ...@@ -35,6 +35,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"reduce_mean", "mean"}, {"reduce_mean", "mean"},
{"reduce_sum", "sum"}, {"reduce_sum", "sum"},
{"reshape2", "reshape"}, {"reshape2", "reshape"},
{"reshape2_grad", "reshape_grad"},
{"reshape2_grad_grad", "reshape_double_grad"},
// fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated
{"flatten", "deprecated"}, {"flatten", "deprecated"},
{"flatten_grad", "deprecated"}, {"flatten_grad", "deprecated"},
...@@ -43,6 +45,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = { ...@@ -43,6 +45,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"matmul_grad_grad", "deprecated"}, {"matmul_grad_grad", "deprecated"},
{"mean", "deprecated"}, {"mean", "deprecated"},
{"reshape", "deprecated"}, {"reshape", "deprecated"},
{"reshape_grad", "deprecated"},
{"sum", "deprecated"}}; {"sum", "deprecated"}};
} // namespace pten } // namespace pten
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/reshape_grad_kernel.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/copy_kernel.h"
namespace pten {
template <typename Context>
void ReshapeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto x_dims = x_grad->dims();
pten::Copy(dev_ctx, out_grad, false, x_grad);
x_grad->Resize(x_dims);
}
template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad) {
ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad);
}
} // namespace pten
// The reshape grad kernels are pure data movement (copy + re-shape), with no
// dtype- or layout-specific logic, so each backend registers a single
// "general" kernel covering ALL_LAYOUT / ALL_DTYPE instead of one
// instantiation per dtype.

// CPU registrations — always built.
PT_REGISTER_GENERAL_KERNEL(reshape_grad,
CPU,
ALL_LAYOUT,
pten::ReshapeGradKernel<pten::CPUContext>,
ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_double_grad,
CPU,
ALL_LAYOUT,
pten::ReshapeDoubleGradKernel<pten::CPUContext>,
ALL_DTYPE) {}

// GPU registrations — compiled for both CUDA and ROCm (HIP) builds.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_GENERAL_KERNEL(reshape_grad,
GPU,
ALL_LAYOUT,
pten::ReshapeGradKernel<pten::GPUContext>,
ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_double_grad,
GPU,
ALL_LAYOUT,
pten::ReshapeDoubleGradKernel<pten::GPUContext>,
ALL_DTYPE) {}
#endif

// XPU registrations — only in Baidu Kunlun (XPU) builds.
#ifdef PADDLE_WITH_XPU
PT_REGISTER_GENERAL_KERNEL(reshape_grad,
XPU,
ALL_LAYOUT,
pten::ReshapeGradKernel<pten::XPUContext>,
ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_double_grad,
XPU,
ALL_LAYOUT,
pten::ReshapeDoubleGradKernel<pten::XPUContext>,
ALL_DTYPE) {}
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {

// Computes the gradient of reshape: copies out_grad's data into x_grad and
// restores x_grad to its original (pre-copy) dims.
template <typename Context>
void ReshapeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad);

// Computes the double gradient of reshape: x_grad_grad is copied into
// out_grad_grad, which keeps its own dims (implemented via ReshapeGradKernel).
template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad);

}  // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册