未验证 提交 5db88d08 编写于 作者: C Ccc 提交者: GitHub

Several ops support zero dim on GPU and CPU (#49959)

* paddle.nn.functional.softmax
* paddle.nn.functional.log_softmax
* paddle.nn.functional.gumbel_softmax
* paddle.nn.functional.prelu
上级 2b4dd5b9
......@@ -77,10 +77,20 @@ class PreluOpGradFunctor {
for (size_t i = 0; i < input_dims.size(); ++i) {
numel *= input_dims[i];
}
size_t plane_size = numel / input_dims[0] / input_dims[1];
size_t spatial_size = numel / input_dims[0];
size_t channel =
mode == ChannelLast ? input_dims[input_dims.size() - 1] : input_dims[1];
size_t plane_size;
size_t spatial_size;
size_t channel;
if (mode == PRELU_Scalar) {
plane_size = 1;
spatial_size = 1;
channel = 1;
} else {
plane_size = numel / input_dims[0] / input_dims[1];
spatial_size = numel / input_dims[0];
channel = mode == ChannelLast ? input_dims[input_dims.size() - 1]
: input_dims[1];
}
PReluOpGradKernel<T>
<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
......@@ -120,7 +130,6 @@ void PReluGradKernel(const Context& dev_ctx,
int numel = x.numel();
auto dim = x.dims();
auto x_rank = dim.size();
std::vector<int> input_shape = phi::vectorize<int>(dim);
auto stream = dev_ctx.stream();
T* alpha_grad_tmp_ptr;
......
......@@ -84,6 +84,9 @@ unary_api_list = [
paddle.poisson,
paddle.bernoulli,
paddle.median,
paddle.nn.functional.softmax,
paddle.nn.functional.log_softmax,
paddle.nn.functional.gumbel_softmax,
]
inplace_api_list = [
......@@ -1501,6 +1504,26 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_prelu(self):
x = paddle.full([], 1.0, 'float32')
x.stop_gradient = False
w1 = paddle.to_tensor([0.25], dtype='float32')
out1 = paddle.nn.functional.prelu(x, w1)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(x.grad.shape, [])
w2 = paddle.full([], 0.25, dtype='float32')
out2 = paddle.nn.functional.prelu(x, w2)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x.grad.shape, [])
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -2403,6 +2426,38 @@ class TestSundryAPIStatic(unittest.TestCase):
res = self.exe.run(prog, feed={"x": x_tensor}, fetch_list=[out])
self.assertEqual(res[0].shape, (3, 4, 2))
def test_prelu(self):
x1 = paddle.full([], 1.0, 'float32')
x1.stop_gradient = False
w1 = paddle.to_tensor([0.25], dtype='float32')
out1 = paddle.nn.functional.prelu(x1, w1)
paddle.static.append_backward(out1.sum())
x2 = paddle.full([], 1.0, 'float32')
x2.stop_gradient = False
w2 = paddle.full([], 0.25, dtype='float32')
out2 = paddle.nn.functional.prelu(x2, w2)
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
out1,
out2,
x1.grad_name,
x2.grad_name,
out1.grad_name,
out2.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, ())
self.assertEqual(res[5].shape, ())
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -463,7 +463,7 @@ def prelu(x, weight, data_format="NCHW", name=None):
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
weight (Tensor): The learnable parameter with data type same as ``x``.
The weight shape is [1] or [in], where `in` is the input channel of ``x``.
The weight shape is [], [1] or [in], where `in` is the input channel of ``x``.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
......@@ -495,12 +495,11 @@ def prelu(x, weight, data_format="NCHW", name=None):
# [ 6. , 7. , 8. , 9. ]]]]
"""
assert (
len(weight.shape) == 1
), "The dim count of weight shape should be 1 in prelu()."
len(weight.shape) == 0 or len(weight.shape) == 1
), "The dim count of weight shape should be 0 or 1 in prelu()."
mode = 'all'
if weight.shape[0] > 1:
if len(weight.shape) == 1 and weight.shape[0] > 1:
true_data_format = [
'NC',
'NCL',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册