未验证 提交 88d42398 编写于 作者: R RuohengMa 提交者: GitHub

[PHI] Add reduce_min_grad xpu op and the corresponding unittest (#51431)

* [XPU] add reduce_min_grad XPU kernel

* add unittest for reduce_min xpu op
上级 e2cdd4a3
......@@ -527,6 +527,7 @@ XPUOpMap& get_kl2_ops() {
{"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_min_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/reduce.h"
namespace phi {
template <typename T, typename Context>
void ReduceMinGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
const IntArray& dims_arr,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
auto dims = dims_arr.GetData();
dev_ctx.template Alloc<T>(x_grad);
const T* x_data = x.data<T>();
const T* out_data = out.data<T>();
const T* out_grad_data = out_grad.data<T>();
auto* x_grad_data = x_grad->data<T>();
const auto& input_dim_size = x.dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
}
std::vector<int> ydims(input_dim_size);
std::vector<int> xdims((input_dim_size));
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
xdims[i] = x.dims()[i];
if (dims_set.find(i) != dims_set.end() || reduce_all) {
ydims[i] = 1;
} else {
ydims[i] = x.dims()[i];
}
}
T* brocast1 = nullptr;
T* brocast2 = nullptr;
bool* equal = nullptr;
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&brocast1), x.numel() * sizeof(T)),
XPU_SUCCESS,
errors::ResourceExhausted("XPU has no enough memory"));
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&equal), x.numel() * sizeof(bool)),
XPU_SUCCESS,
errors::ResourceExhausted("XPU has no enough memory"));
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&brocast2), x.numel() * sizeof(T)),
XPU_SUCCESS,
errors::ResourceExhausted("XPU has no enough memory"));
// use [1] to replace [], because xpu not support []
if (xdims.size() == 0) {
xdims = std::vector<int>({1});
}
if (ydims.size() == 0) {
ydims = std::vector<int>({1});
}
// step 1. brocast out and out_grad
int r =
xpu::broadcast<T>(dev_ctx.x_context(), out_data, brocast1, ydims, xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
r = xpu::broadcast<T>(
dev_ctx.x_context(), out_grad_data, brocast2, ydims, xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
// step 2. comparse out_brocast and x
r = xpu::equal<T>(dev_ctx.x_context(), x_data, brocast1, equal, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "equal");
// step 3. get x_grad
r = xpu::constant<T>(dev_ctx.x_context(), brocast1, x.numel(), 0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::select<T>(dev_ctx.x_context(),
equal,
brocast2,
brocast1,
x_grad_data,
xdims,
xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "select");
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
xpu_free(brocast1);
xpu_free(brocast2);
xpu_free(equal);
}
} // namespace phi
PD_REGISTER_KERNEL(min_grad, XPU, ALL_LAYOUT, phi::ReduceMinGradKernel, float) {
}
......@@ -47,6 +47,7 @@ class XPUTestReduceMaxOp(XPUOpTestWrapper):
'use_xpu': True,
'reduce_all': self.reduce_all,
'keep_dim': self.keep_dim,
'dim': self.axis,
}
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
if self.attrs['reduce_all']:
......
......@@ -47,6 +47,7 @@ class XPUTestReduceMinOp(XPUOpTestWrapper):
'use_xpu': True,
'reduce_all': self.reduce_all,
'keep_dim': self.keep_dim,
'dim': self.axis,
}
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
if self.attrs['reduce_all']:
......@@ -68,15 +69,50 @@ class XPUTestReduceMinOp(XPUOpTestWrapper):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
self.check_grad_with_place(self.place, ['X'], 'Out')
class XPUTestReduceMinCase1(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (0,)
self.reduce_all = False
self.keep_dim = False
class XPUTestReduceMinCase2(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (0,)
self.reduce_all = False
self.keep_dim = True
class XPUTestReduceMinCase3(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (0,)
self.reduce_all = True
self.keep_dim = False
class XPUTestReduceMinCase4(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (1,)
self.reduce_all = False
self.keep_dim = False
class XPUTestReduceMinCase5(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (1,)
self.reduce_all = False
self.keep_dim = True
class XPUTestReduceMinCase6(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (1,)
self.reduce_all = True
self.keep_dim = False
support_types = get_xpu_op_support_types('reduce_min')
for stype in support_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册