未验证 提交 2de034e1 编写于 作者: Q Qi Li 提交者: GitHub

fix prelu, test=develop (#26613)

上级 7af5cb9b
...@@ -41,6 +41,7 @@ from ...fluid import core ...@@ -41,6 +41,7 @@ from ...fluid import core
from ...fluid.framework import in_dygraph_mode from ...fluid.framework import in_dygraph_mode
from ...fluid.param_attr import ParamAttr from ...fluid.param_attr import ParamAttr
from ...fluid.initializer import Constant from ...fluid.initializer import Constant
from paddle.framework import get_default_dtype
from .. import functional as F from .. import functional as F
...@@ -423,7 +424,7 @@ class PReLU(layers.Layer): ...@@ -423,7 +424,7 @@ class PReLU(layers.Layer):
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
Shape: Shape:
- input: Tensor with any shape. - input: Tensor with any shape. Default dtype is float32.
- output: Tensor with the same shape as input. - output: Tensor with the same shape as input.
Examples: Examples:
...@@ -433,13 +434,14 @@ class PReLU(layers.Layer): ...@@ -433,13 +434,14 @@ class PReLU(layers.Layer):
import numpy as np import numpy as np
paddle.disable_static() paddle.disable_static()
paddle.set_default_dtype("float64")
data = np.array([[[[-2.0, 3.0, -4.0, 5.0], data = np.array([[[[-2.0, 3.0, -4.0, 5.0],
[ 3.0, -4.0, 5.0, -6.0], [ 3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]], [-7.0, -8.0, 8.0, 9.0]],
[[ 1.0, -2.0, -3.0, 4.0], [[ 1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0], [-5.0, 6.0, 7.0, -8.0],
[ 6.0, 7.0, 8.0, 9.0]]]], 'float32') [ 6.0, 7.0, 8.0, 9.0]]]], 'float64')
x = paddle.to_tensor(data) x = paddle.to_tensor(data)
m = paddle.nn.PReLU(1, 0.25) m = paddle.nn.PReLU(1, 0.25)
out = m(x) out = m(x)
...@@ -461,10 +463,10 @@ class PReLU(layers.Layer): ...@@ -461,10 +463,10 @@ class PReLU(layers.Layer):
self._weight = self.create_parameter( self._weight = self.create_parameter(
attr=self._weight_attr, attr=self._weight_attr,
shape=[num_parameters], shape=[self._num_parameters],
dtype='float32', dtype=get_default_dtype(),
is_bias=False, is_bias=False,
default_initializer=Constant(init)) default_initializer=Constant(self._init))
def forward(self, x):
    """Apply the PReLU activation to ``x``.

    Args:
        x: Input tensor of any shape (per the class docstring above).

    Returns:
        A tensor with the same shape as ``x``, computed by
        ``F.prelu`` using this layer's learned ``self._weight``
        parameter as the negative-slope coefficient.
    """
    # NOTE(review): reconstructed from a side-by-side diff scrape that had
    # fused the old/new columns onto single lines; behavior is unchanged.
    return F.prelu(x, self._weight)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册