未验证 提交 8de336f9 编写于 作者: H houj04 提交者: GitHub

[XPU] add tile_grad op (#48720)

上级 8fb829ba
...@@ -663,6 +663,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -663,6 +663,7 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"tile_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"transpose2_grad", {"transpose2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
......
...@@ -97,8 +97,8 @@ void TileGradKernel(const Context& dev_ctx, ...@@ -97,8 +97,8 @@ void TileGradKernel(const Context& dev_ctx,
PADDLE_ENFORCE_GE(dims, PADDLE_ENFORCE_GE(dims,
1, 1,
errors::InvalidArgument( errors::InvalidArgument(
"Th rank of the input 'Out@GRAD' for tile_grad op " "The rank of the input 'Out@GRAD' for tile_grad op "
" must be greater than or equal to 1, but " "must be greater than or equal to 1, but "
"the value received is %d.", "the value received is %d.",
dims)); dims));
PADDLE_ENFORCE_LE(dims, PADDLE_ENFORCE_LE(dims,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tile_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void TileGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& repeat_times,
DenseTensor* x_grad) {
auto x_dims = x.dims();
auto vec_x_dims = phi::vectorize<int>(x_dims);
auto repeat_times_data = repeat_times.GetData();
if (repeat_times_data.size() < vec_x_dims.size()) {
int diff = vec_x_dims.size() - repeat_times_data.size();
repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
} else {
int diff = repeat_times_data.size() - vec_x_dims.size();
vec_x_dims.insert(vec_x_dims.begin(), diff, 1);
}
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times_data.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times_data[i]);
reshape_dims_vec.push_back(vec_x_dims[i]);
}
dev_ctx.template Alloc<T>(x_grad);
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times_data.size(); i++) {
if (repeat_times_data[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
// TensorCopy may change the dims of dx
x_grad->Resize(x_dims);
} else {
PADDLE_ENFORCE_GE(dims,
1,
errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for tile_grad op "
"must be greater than or equal to 1, but "
"the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for tile_grad op "
"must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
dims));
using XPUType = typename XPUTypeTrait<T>::Type;
// int reduce_sum(Context* ctx, const T* x, T* y, const std::vector<int>&
// xshape, const std::vector<int>& rdims)
const auto* out_data = out_grad.data<XPUType>();
auto* x_grad_data = x_grad->data<XPUType>();
int r = xpu::reduce_sum<XPUType>(dev_ctx.x_context(),
out_data,
x_grad_data,
reshape_dims_vec,
reduce_dims_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
}
}
} // namespace phi
PD_REGISTER_KERNEL(tile_grad, XPU, ALL_LAYOUT, phi::TileGradKernel, float) {}
...@@ -59,6 +59,9 @@ class XPUTestTileOpRank1(XPUOpTestWrapper): ...@@ -59,6 +59,9 @@ class XPUTestTileOpRank1(XPUOpTestWrapper):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
# with dimension expanding # with dimension expanding
class TestTileOpRank2Expanding(TestTileOpRank1): class TestTileOpRank2Expanding(TestTileOpRank1):
def init_data(self): def init_data(self):
...@@ -126,6 +129,9 @@ class XPUTestTileOpRank1_tensor_attr(XPUOpTestWrapper): ...@@ -126,6 +129,9 @@ class XPUTestTileOpRank1_tensor_attr(XPUOpTestWrapper):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestTileOpRank2_Corner_tensor_attr(TestTileOpRank1_tensor_attr): class TestTileOpRank2_Corner_tensor_attr(TestTileOpRank1_tensor_attr):
def init_data(self): def init_data(self):
self.ori_shape = [12, 14] self.ori_shape = [12, 14]
...@@ -168,6 +174,9 @@ class XPUTestTileOpRank1_tensor(XPUOpTestWrapper): ...@@ -168,6 +174,9 @@ class XPUTestTileOpRank1_tensor(XPUOpTestWrapper):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestTileOpRank2_tensor(TestTileOpRank1_tensor): class TestTileOpRank2_tensor(TestTileOpRank1_tensor):
def init_data(self): def init_data(self):
self.ori_shape = [12, 14] self.ori_shape = [12, 14]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册