未验证 提交 69aae171 编写于 作者: C Ccc 提交者: GitHub

[Zero-dim] Zero-dim Tensor for XPU prelu, softmax and log_softmax (#50433)

上级 cea6a7c6
......@@ -33,6 +33,7 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
// For 0D Tensor
if (rank == 0) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
......
......@@ -32,6 +32,7 @@ void LogSoftmaxKernel(const Context& dev_ctx,
// For 0D Tensor
if (rank == 0) {
dev_ctx.template Alloc<T>(out);
phi::funcs::set_constant(dev_ctx, out, 0.0);
return;
}
......
......@@ -40,17 +40,25 @@ void PReluGradKernel(const Context& dev_ctx,
auto x_rank = x_dim.size();
std::vector<int> x_shape(x_rank);
if (x_rank == 0) {
x_shape = std::vector<int>({1});
} else {
for (int i = 0; i < x_rank; i++) {
x_shape[i] = x_dim[i];
}
}
auto alpha_dim = alpha.dims();
auto alpha_rank = alpha_dim.size();
std::vector<int> alpha_shape(alpha_rank);
if (alpha_rank == 0) {
alpha_shape = std::vector<int>({1});
} else {
for (int i = 0; i < x_rank; i++) {
alpha_shape[i] = alpha_dim[i];
}
}
// mode = 0: channel_nchw, slope_shape = {c}, default. meanwhile, xhsape = {n,
// c, h, w}
......
......@@ -34,21 +34,27 @@ void PReluKernel(const Context& dev_ctx,
auto x_dim = x.dims();
auto x_rank = x_dim.size();
std::vector<int> x_shape(x_rank);
if (x_rank == 0) {
x_shape = std::vector<int>({1});
} else {
for (int i = 0; i < x_rank; i++) {
x_shape[i] = x_dim[i];
}
}
auto alpha_dim = alpha.dims();
auto alpha_rank = alpha_dim.size();
std::vector<int> alpha_shape(x_rank, 1); // same size with x_shape
if (x_rank == 0) {
alpha_shape = std::vector<int>({1});
} else {
for (int i = 0; i < alpha_rank; i++) {
alpha_shape[i] = alpha_dim[i];
}
}
int r = xpu::prelu(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_ptr),
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -35,6 +36,12 @@ void SoftmaxGradKernel(const Context& dev_ctx,
return;
}
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
std::vector<int> x_dims;
for (int i = 0; i < rank; i++) {
x_dims.push_back(x_grad->dims()[i]);
......
......@@ -1628,24 +1628,29 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
def test_prelu(self):
x = paddle.full([], 1.0, 'float32')
x.stop_gradient = False
w1 = paddle.to_tensor([0.25], dtype='float32')
out1 = paddle.nn.functional.prelu(x, w1)
x1 = paddle.full([], 1.0, 'float32')
x1.stop_gradient = False
w1 = paddle.full([], 0.25, dtype='float32')
out1 = paddle.nn.functional.prelu(x1, w1)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out1.grad.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 1.0)
x2 = paddle.full([], -1.0, 'float32')
x2.stop_gradient = False
w2 = paddle.full([], 0.25, dtype='float32')
out2 = paddle.nn.functional.prelu(x, w2)
out2 = paddle.nn.functional.prelu(x2, w2)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(out2.numpy(), -0.25)
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x2.grad.numpy(), 0.25)
def test_while_loop(self):
def cond(i, x):
......
......@@ -84,6 +84,8 @@ unary_api_list = [
paddle.lgamma,
paddle.poisson,
paddle.bernoulli,
paddle.nn.functional.softmax,
paddle.nn.functional.log_softmax,
]
inplace_api_list = [
......@@ -1033,6 +1035,33 @@ class TestSundryAPI(unittest.TestCase):
out2.backward()
self.assertEqual(out2.shape, [1])
def test_prelu(self):
x1 = paddle.full([], 1.0, 'float32')
x1.stop_gradient = False
w1 = paddle.full([], 0.25, dtype='float32')
w1.stop_gradient = False
out1 = paddle.nn.functional.prelu(x1, w1)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out1.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 1.0)
x2 = paddle.full([], -1.0, 'float32')
x2.stop_gradient = False
w2 = paddle.full([], 0.25, dtype='float32')
w2.stop_gradient = False
out2 = paddle.nn.functional.prelu(x2, w2)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(out2.numpy(), -0.25)
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x2.grad.numpy(), 0.25)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册