Unverified commit d39f3fb6, authored by 傅剑寒, committed by GitHub

(fluid API clear) remove fluid.layers.brelu in nn.py under fluid (#47898)

* remove brelu in nn.py under fluid

* add brelu op test case
Parent 5664306b
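For code that still calls the removed API, the replacement used throughout this PR's test changes is `paddle.nn.functional.hardtanh`, with `t_min`/`t_max` renamed to `min`/`max`. A minimal migration sketch (the default values of `hardtanh` are stated here as an assumption worth checking: it clips into `[-1.0, 1.0]` by default, whereas the removed `brelu` used `[0.0, 24.0]`):

```python
import paddle
import paddle.nn.functional as F

x = paddle.uniform([4, 8], min=-1.0, max=30.0)

# Before (removed by this PR):
#     y = paddle.fluid.layers.brelu(x, t_min=0.0, t_max=24.0)
# After -- pass the old brelu defaults explicitly, since hardtanh
# is assumed to default to min=-1.0, max=1.0 rather than 0.0/24.0:
y = F.hardtanh(x, min=0.0, max=24.0)
```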
@@ -115,7 +115,6 @@ __all__ = [
     'log',
     'crop_tensor',
     'prelu',
-    'brelu',
     'flatten',
     'pad2d',
     'unique',
@@ -7831,52 +7830,6 @@ def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
     return out
 
-
-@templatedoc()
-def brelu(x, t_min=0.0, t_max=24.0, name=None):
-    """
-    ${comment}
-    Args:
-        x(${x_type}): ${x_comment}
-        t_min(${t_min_type}|0.0): ${t_min_comment}
-        t_max(${t_max_type}|24.0): ${t_max_comment}
-        name(str|None): The default value is None. Normally there is no need for user to set this property.
-            For more information, please refer to :ref:`api_guide_Name`.
-    Returns:
-        ${out_type}: ${out_comment}
-    Examples:
-        .. code-block:: python
-
-            import paddle.fluid as fluid
-            import paddle
-            import numpy as np
-            paddle.enable_static()
-
-            input_brelu = np.array([[-1, 6], [1, 15.6]])
-            with fluid.dygraph.guard():
-                x = fluid.dygraph.to_variable(input_brelu)
-                y = fluid.layers.brelu(x, t_min=1.0, t_max=10.0)
-                print(y.numpy())
-                # [[ 1.  6.]
-                #  [ 1. 10.]]
-    """
-    if _non_static_mode():
-        return _legacy_C_ops.brelu(x, 't_min', t_min, 't_max', t_max)
-
-    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'brelu')
-
-    helper = LayerHelper('brelu', **locals())
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
-    helper.append_op(
-        type='brelu',
-        inputs={'X': x},
-        outputs={'Out': out},
-        attrs={'t_min': t_min, 't_max': t_max},
-    )
-    return out
-
 def flatten(x, axis=1, name=None):
     r"""
     **Flatten op**
......
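The example removed from the docstring above ports directly; a sketch reproducing its input and output with the replacement API:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(np.array([[-1, 6], [1, 15.6]], dtype='float32'))
y = F.hardtanh(x, min=1.0, max=10.0)  # clip element-wise into [1.0, 10.0]
print(y.numpy())
# [[ 1.  6.]
#  [ 1. 10.]]
```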
@@ -63,7 +63,7 @@ class TestBase(IPUOpTest):
         self.check()
 
-class TestBReluCase0(TestBase):
+class TestHardTanhCase0(TestBase):
     def set_data_feed(self):
         data = np.random.uniform(size=[1, 3, 10, 10]) * 30
         self.feed_fp32 = {'in_0': data.astype(np.float32)}
@@ -71,14 +71,14 @@ class TestBReluCase0(TestBase):
         self.feed_list = list(self.feed_fp32.keys())
 
     def set_test_op(self):
-        self.op = paddle.fluid.layers.brelu
+        self.op = paddle.nn.functional.hardtanh
         self.op_attrs = {}
 
-class TestBReluCase1(TestBReluCase0):
+class TestHardTanhCase1(TestHardTanhCase0):
     def set_test_op(self):
-        self.op = paddle.fluid.layers.brelu
-        self.op_attrs = {"t_min": 0.1, 't_max': 10.0}
+        self.op = paddle.nn.functional.hardtanh
+        self.op_attrs = {"min": 0.1, 'max': 10.0}
 
 class TestEluCase1(TestBase):
......
@@ -1891,51 +1891,6 @@ class TestBRelu(TestActivation):
         self.check_grad(['X'], 'Out')
 
-
-class TestBreluAPI(unittest.TestCase):
-    # test paddle.fluid.layers.brelu
-    def setUp(self):
-        np.random.seed(1024)
-        self.t_min = 0.0
-        self.t_max = 24.0
-        self.x_np = np.random.uniform(-1, 30, [10, 12]).astype('float32')
-        self.out_ref = np.copy(self.x_np)
-        self.out_ref[self.out_ref < self.t_min] = self.t_min
-        self.out_ref[self.out_ref > self.t_max] = self.t_max
-        self.out_ref = self.out_ref.astype('float32')
-        self.place = (
-            paddle.CUDAPlace(0)
-            if paddle.is_compiled_with_cuda()
-            else paddle.CPUPlace()
-        )
-
-    def test_fluid_api(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data('X', [10, 12])
-            out = paddle.fluid.layers.brelu(x)
-            exe = paddle.static.Executor(self.place)
-            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
-            np.testing.assert_allclose(self.out_ref, res[0], rtol=1e-05)
-
-        paddle.disable_static(self.place)
-        x = paddle.to_tensor(self.x_np)
-        out = paddle.fluid.layers.brelu(x)
-        np.testing.assert_allclose(self.out_ref, out.numpy(), rtol=1e-05)
-        paddle.enable_static()
-
-    def test_errors(self):
-        with program_guard(Program()):
-            # The input type must be Variable.
-            self.assertRaises(TypeError, fluid.layers.brelu, 1)
-            # The input dtype must be float16, float32 or float64.
-            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
-            self.assertRaises(TypeError, fluid.layers.brelu, x_int32)
-            # float16 input is supported.
-            x_fp16 = fluid.layers.data(
-                name='x_fp16', shape=[12, 10], dtype='float16'
-            )
-            fluid.layers.brelu(x_fp16)
-
 def ref_relu6(x, threshold=6.0):
     out = np.copy(x)
     out[np.abs(x - threshold) < 0.005] = threshold + 0.02
......
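The deleted `TestBreluAPI` built its reference output by clipping the input into `[t_min, t_max]` in NumPy. The same check can be expressed against the replacement API; a standalone sketch (not part of this PR) using `np.clip`, which is equivalent to the deleted masking code:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

np.random.seed(1024)
x_np = np.random.uniform(-1, 30, [10, 12]).astype('float32')

# Reference output, as in the deleted setUp(): clip into [t_min, t_max].
t_min, t_max = 0.0, 24.0
out_ref = np.clip(x_np, t_min, t_max)

# The replacement API should reproduce the clipped reference exactly.
out = F.hardtanh(paddle.to_tensor(x_np), min=t_min, max=t_max)
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
```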