未验证 提交 cbeff5fc 编写于 作者: C Charles-hit 提交者: GitHub

support activation prim op bf16 dtype (#54193)

* support activation prim op bf16 dtype

* remove useless code
上级 2db64d08
...@@ -3083,11 +3083,15 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> { ...@@ -3083,11 +3083,15 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> { struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
T minus_one_half = static_cast<T>(-0.5f); using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType minus_one_half = static_cast<MPType>(-0.5f);
// dx = -0.5 * dout * out^3 // dx = -0.5 * dout * out^3
__device__ __forceinline__ T operator()(const T dout, const T out) const { __device__ __forceinline__ T operator()(const T arg_dout,
return minus_one_half * dout * out * out * out; const T arg_out) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
return static_cast<T>(minus_one_half * dout * out * out * out);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { static constexpr ActBwdOpFwdDeps FwdDeps() {
......
...@@ -578,45 +578,45 @@ class PrimForwardChecker: ...@@ -578,45 +578,45 @@ class PrimForwardChecker:
# forward comp only for comp op # forward comp only for comp op
if self.prim_op_type == "prim": if self.prim_op_type == "prim":
return return
paddle.enable_static() with paddle.fluid.framework._static_guard():
core._set_prim_forward_enabled(self.enable_fw_comp) core._set_prim_forward_enabled(self.enable_fw_comp)
startup_program, main_program = ( startup_program, main_program = (
paddle.static.Program(), paddle.static.Program(),
paddle.static.Program(), paddle.static.Program(),
)
with paddle.static.program_guard(main_program, startup_program):
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=True
)
args = OpTestUtils.prepare_python_api_arguments(
self.public_python_api,
static_inputs,
attrs,
self.kernel_sig,
)
inputs_sig, _, _ = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret = flatten(_as_list(self.public_python_api(*args)))
primapi.to_prim(main_program.blocks)
# ensure the operator not in program if check_prim is True
forward_ops = [op.type for op in main_program.blocks[0].ops]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = paddle.utils.map_structure(
lambda x: convert_uint16_to_float(x), ret
) )
with paddle.static.program_guard(main_program, startup_program):
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=True
)
args = OpTestUtils.prepare_python_api_arguments(
self.public_python_api,
static_inputs,
attrs,
self.kernel_sig,
)
inputs_sig, _, _ = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret = flatten(_as_list(self.public_python_api(*args)))
primapi.to_prim(main_program.blocks)
# ensure the operator not in program if check_prim is True
forward_ops = [op.type for op in main_program.blocks[0].ops]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = paddle.utils.map_structure(
lambda x: convert_uint16_to_float(x), ret
)
# check static forward # check static forward
if len(ret) != len(self.eager_desire): if len(ret) != len(self.eager_desire):
msg = ( msg = (
...@@ -1024,7 +1024,6 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1024,7 +1024,6 @@ class PrimGradChecker(PrimForwardChecker):
core.set_prim_eager_enabled(False) core.set_prim_eager_enabled(False)
def check_static_comp(self): def check_static_comp(self):
paddle.enable_static()
if self.prim_op_type == "prim": if self.prim_op_type == "prim":
core._set_prim_backward_enabled(self.enable_rev_comp) core._set_prim_backward_enabled(self.enable_rev_comp)
else: else:
...@@ -1032,67 +1031,70 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1032,67 +1031,70 @@ class PrimGradChecker(PrimForwardChecker):
core._set_prim_backward_enabled(self.enable_rev_comp) core._set_prim_backward_enabled(self.enable_rev_comp)
atol = self.rev_comp_atol if self.enable_rev_comp else self.fw_comp_atol atol = self.rev_comp_atol if self.enable_rev_comp else self.fw_comp_atol
rtol = self.rev_comp_rtol if self.enable_rev_comp else self.fw_comp_rtol rtol = self.rev_comp_rtol if self.enable_rev_comp else self.fw_comp_rtol
startup_program, main_program = ( with paddle.fluid.framework._static_guard():
paddle.static.Program(), startup_program, main_program = (
paddle.static.Program(), paddle.static.Program(),
) paddle.static.Program(),
with paddle.static.program_guard(main_program, startup_program):
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=False
)
args = OpTestUtils.prepare_python_api_arguments(
self.public_python_api,
static_inputs,
attrs,
self.kernel_sig,
)
inputs_sig, _, outputs_sig = self.kernel_sig
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
fw_outs = _as_list(self.public_python_api(*args))
outputs_dict = self.get_output_dict(
self.outputs, fw_outs, outputs_sig
)
primapi.to_prim(main_program.blocks)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
ys.append(outputs_dict[output_name])
else:
ys.append(outputs_dict[self.output_names])
xs = []
if isinstance(self.inputs_to_check, list):
for input_name in self.inputs_to_check:
xs.append(inputs_dict[input_name])
else:
xs.append(inputs_dict[self.inputs_to_check])
vs, vs_feed = self.gen_static_grad_outputs_and_feed()
feed.update(vs_feed)
no_grad_vars = self.gen_no_grad_set(
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars)
# check the backward operator not in program when check_prim is True
ops = [op.type for op in main_program.blocks[0].ops]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
actual_ret = paddle.utils.map_structure(
lambda x: convert_uint16_to_float(x), actual_ret
) )
with paddle.static.program_guard(main_program, startup_program):
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=False
)
args = OpTestUtils.prepare_python_api_arguments(
self.public_python_api,
static_inputs,
attrs,
self.kernel_sig,
)
inputs_sig, _, outputs_sig = self.kernel_sig
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
fw_outs = _as_list(self.public_python_api(*args))
outputs_dict = self.get_output_dict(
self.outputs, fw_outs, outputs_sig
)
primapi.to_prim(main_program.blocks)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
ys.append(outputs_dict[output_name])
else:
ys.append(outputs_dict[self.output_names])
xs = []
if isinstance(self.inputs_to_check, list):
for input_name in self.inputs_to_check:
xs.append(inputs_dict[input_name])
else:
xs.append(inputs_dict[self.inputs_to_check])
vs, vs_feed = self.gen_static_grad_outputs_and_feed()
feed.update(vs_feed)
no_grad_vars = self.gen_no_grad_set(
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.static.gradients(
ys, xs, vs, no_grad_set=no_grad_vars
)
# check the backward operator not in program when check_prim is True
ops = [op.type for op in main_program.blocks[0].ops]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
actual_ret = paddle.utils.map_structure(
lambda x: convert_uint16_to_float(x), actual_ret
)
# check static grad out # check static grad out
if len(actual_ret) != len(self.eager_desire): if len(actual_ret) != len(self.eager_desire):
msg = ( msg = (
......
...@@ -17,7 +17,7 @@ import unittest ...@@ -17,7 +17,7 @@ import unittest
import warnings import warnings
import numpy as np import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard from eager_op_test import OpTest, convert_float_to_uint16
from scipy.special import erf, expit from scipy.special import erf, expit
import paddle import paddle
...@@ -29,7 +29,7 @@ from paddle.fluid.layer_helper import LayerHelper ...@@ -29,7 +29,7 @@ from paddle.fluid.layer_helper import LayerHelper
class TestSqrtOpError(unittest.TestCase): class TestSqrtOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
# The input type of sqrt op must be Variable or numpy.ndarray. # The input type of sqrt op must be Variable or numpy.ndarray.
in1 = 1 in1 = 1
...@@ -49,10 +49,10 @@ class TestSqrtOpError(unittest.TestCase): ...@@ -49,10 +49,10 @@ class TestSqrtOpError(unittest.TestCase):
class TestActivation(OpTest): class TestActivation(OpTest):
def setUp(self): def setUp(self):
self.op_type = "exp" self.op_type = "exp"
self.prim_op_type = "prim"
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.init_kernel_type() self.init_kernel_type()
self.if_enable_cinn()
self.python_api = paddle.exp self.python_api = paddle.exp
self.public_python_api = paddle.exp self.public_python_api = paddle.exp
...@@ -88,6 +88,9 @@ class TestActivation(OpTest): ...@@ -88,6 +88,9 @@ class TestActivation(OpTest):
def convert_input_output(self): def convert_input_output(self):
pass pass
def if_enable_cinn(self):
pass
class TestActivation_ZeroDim(TestActivation): class TestActivation_ZeroDim(TestActivation):
def init_shape(self): def init_shape(self):
...@@ -124,7 +127,7 @@ class TestExpFp32_Prim(OpTest): ...@@ -124,7 +127,7 @@ class TestExpFp32_Prim(OpTest):
self.shape = [12, 17] self.shape = [12, 17]
def if_enable_cinn(self): def if_enable_cinn(self):
self.enable_cinn = True pass
class TestExpFp64_Prim(TestExpFp32_Prim): class TestExpFp64_Prim(TestExpFp32_Prim):
...@@ -183,7 +186,7 @@ class TestExpm1API(unittest.TestCase): ...@@ -183,7 +186,7 @@ class TestExpm1API(unittest.TestCase):
def test_static_api(self): def test_static_api(self):
def run(place): def run(place):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
X = paddle.static.data('X', self.shape, dtype=self.dtype) X = paddle.static.data('X', self.shape, dtype=self.dtype)
out = paddle.expm1(X) out = paddle.expm1(X)
...@@ -205,7 +208,7 @@ class TestExpm1API(unittest.TestCase): ...@@ -205,7 +208,7 @@ class TestExpm1API(unittest.TestCase):
run(place) run(place)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
X = paddle.static.data('X', self.shape, dtype='int32') X = paddle.static.data('X', self.shape, dtype='int32')
self.assertRaises(TypeError, paddle.expm1, X) self.assertRaises(TypeError, paddle.expm1, X)
...@@ -214,7 +217,7 @@ class TestExpm1API(unittest.TestCase): ...@@ -214,7 +217,7 @@ class TestExpm1API(unittest.TestCase):
class TestParameter: class TestParameter:
def test_out_name(self): def test_out_name(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
np_x = np.array([0.1]).astype('float32').reshape((-1, 1)) np_x = np.array([0.1]).astype('float32').reshape((-1, 1))
data = paddle.static.data( data = paddle.static.data(
...@@ -240,12 +243,11 @@ class TestSigmoid(TestActivation): ...@@ -240,12 +243,11 @@ class TestSigmoid(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "sigmoid" self.op_type = "sigmoid"
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.enable_cinn = False
self.python_api = paddle.nn.functional.sigmoid self.python_api = paddle.nn.functional.sigmoid
self.public_python_api = paddle.nn.functional.sigmoid self.public_python_api = paddle.nn.functional.sigmoid
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = 1 / (1 + np.exp(-x)) out = 1 / (1 + np.exp(-x))
...@@ -258,6 +260,9 @@ class TestSigmoid(TestActivation): ...@@ -258,6 +260,9 @@ class TestSigmoid(TestActivation):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def if_enable_cinn(self):
pass
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
...@@ -268,6 +273,9 @@ class TestSigmoid_ZeroDim(TestSigmoid): ...@@ -268,6 +273,9 @@ class TestSigmoid_ZeroDim(TestSigmoid):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
...@@ -276,11 +284,11 @@ class TestSigmoidBF16(OpTest): ...@@ -276,11 +284,11 @@ class TestSigmoidBF16(OpTest):
def setUp(self): def setUp(self):
self.op_type = "sigmoid" self.op_type = "sigmoid"
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.enable_cinn = False
self.python_api = paddle.nn.functional.sigmoid self.python_api = paddle.nn.functional.sigmoid
self.public_python_api = paddle.nn.functional.sigmoid self.public_python_api = paddle.nn.functional.sigmoid
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(np.float32) x = np.random.uniform(-1, 1, self.shape).astype(np.float32)
out = 1 / (1 + np.exp(-x)) out = 1 / (1 + np.exp(-x))
...@@ -296,14 +304,17 @@ class TestSigmoidBF16(OpTest): ...@@ -296,14 +304,17 @@ class TestSigmoidBF16(OpTest):
def init_shape(self): def init_shape(self):
self.shape = [11, 17] self.shape = [11, 17]
def if_enable_cinn(self):
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
# elementwise_pow doesn't support bfloat16, skip check_prim here. # elementwise_pow doesn't support bfloat16, skip check_prim here.
self.check_output_with_place(place) self.check_output_with_place(place, check_prim=True)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out') self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
''' '''
...@@ -318,7 +329,6 @@ class TestSilu(TestActivation): ...@@ -318,7 +329,6 @@ class TestSilu(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "silu" self.op_type = "silu"
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.enable_cinn = True
self.python_api = paddle.nn.functional.silu self.python_api = paddle.nn.functional.silu
self.public_python_api = paddle.nn.functional.silu self.public_python_api = paddle.nn.functional.silu
self.init_dtype() self.init_dtype()
...@@ -362,7 +372,7 @@ class TestSiluAPI(unittest.TestCase): ...@@ -362,7 +372,7 @@ class TestSiluAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [11, 17]) x = paddle.static.data('X', [11, 17])
out1 = F.silu(x) out1 = F.silu(x)
...@@ -384,7 +394,7 @@ class TestSiluAPI(unittest.TestCase): ...@@ -384,7 +394,7 @@ class TestSiluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.silu, 1) self.assertRaises(TypeError, F.silu, 1)
...@@ -438,7 +448,7 @@ class TestLogSigmoidAPI(unittest.TestCase): ...@@ -438,7 +448,7 @@ class TestLogSigmoidAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [11, 17]) x = paddle.static.data('X', [11, 17])
out1 = F.log_sigmoid(x) out1 = F.log_sigmoid(x)
...@@ -460,7 +470,7 @@ class TestLogSigmoidAPI(unittest.TestCase): ...@@ -460,7 +470,7 @@ class TestLogSigmoidAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.log_sigmoid, 1) self.assertRaises(TypeError, F.log_sigmoid, 1)
...@@ -533,7 +543,7 @@ class TestTanhAPI(unittest.TestCase): ...@@ -533,7 +543,7 @@ class TestTanhAPI(unittest.TestCase):
self.tanh = F.tanh self.tanh = F.tanh
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], self.dtype) x = paddle.static.data('X', [10, 12], self.dtype)
out1 = self.tanh(x) out1 = self.tanh(x)
...@@ -556,7 +566,7 @@ class TestTanhAPI(unittest.TestCase): ...@@ -556,7 +566,7 @@ class TestTanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, self.tanh, 1) self.assertRaises(TypeError, self.tanh, 1)
...@@ -599,7 +609,7 @@ class TestAtan(TestActivation, TestParameter): ...@@ -599,7 +609,7 @@ class TestAtan(TestActivation, TestParameter):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def test_out_name(self): def test_out_name(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
np_x = np.array([0.1]).astype('float32').reshape((-1, 1)) np_x = np.array([0.1]).astype('float32').reshape((-1, 1))
data = paddle.static.data( data = paddle.static.data(
...@@ -662,7 +672,7 @@ class TestSinhAPI(unittest.TestCase): ...@@ -662,7 +672,7 @@ class TestSinhAPI(unittest.TestCase):
np.testing.assert_allclose(z, z_expected, rtol=1e-05) np.testing.assert_allclose(z, z_expected, rtol=1e-05)
def test_api(self): def test_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
test_data_shape = [11, 17] test_data_shape = [11, 17]
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = np.random.uniform(0.1, 1, test_data_shape).astype( input_x = np.random.uniform(0.1, 1, test_data_shape).astype(
...@@ -702,7 +712,7 @@ class TestSinhAPI(unittest.TestCase): ...@@ -702,7 +712,7 @@ class TestSinhAPI(unittest.TestCase):
class TestSinhOpError(unittest.TestCase): class TestSinhOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with program_guard(Program()): with program_guard(Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, paddle.sinh, 1) self.assertRaises(TypeError, paddle.sinh, 1)
...@@ -754,7 +764,7 @@ class TestCoshAPI(unittest.TestCase): ...@@ -754,7 +764,7 @@ class TestCoshAPI(unittest.TestCase):
np.testing.assert_allclose(z, z_expected, rtol=1e-05) np.testing.assert_allclose(z, z_expected, rtol=1e-05)
def test_api(self): def test_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
test_data_shape = [11, 17] test_data_shape = [11, 17]
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = np.random.uniform(0.1, 1, test_data_shape).astype( input_x = np.random.uniform(0.1, 1, test_data_shape).astype(
...@@ -794,7 +804,7 @@ class TestCoshAPI(unittest.TestCase): ...@@ -794,7 +804,7 @@ class TestCoshAPI(unittest.TestCase):
class TestCoshOpError(unittest.TestCase): class TestCoshOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with program_guard(Program()): with program_guard(Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, paddle.cosh, 1) self.assertRaises(TypeError, paddle.cosh, 1)
...@@ -853,7 +863,7 @@ class TestTanhshrinkAPI(unittest.TestCase): ...@@ -853,7 +863,7 @@ class TestTanhshrinkAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.tanhshrink(x) out1 = F.tanhshrink(x)
...@@ -875,7 +885,7 @@ class TestTanhshrinkAPI(unittest.TestCase): ...@@ -875,7 +885,7 @@ class TestTanhshrinkAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.tanhshrink, 1) self.assertRaises(TypeError, F.tanhshrink, 1)
...@@ -953,7 +963,7 @@ class TestHardShrinkAPI(unittest.TestCase): ...@@ -953,7 +963,7 @@ class TestHardShrinkAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out1 = F.hardshrink(x) out1 = F.hardshrink(x)
...@@ -982,7 +992,7 @@ class TestHardShrinkAPI(unittest.TestCase): ...@@ -982,7 +992,7 @@ class TestHardShrinkAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.hardshrink, 1) self.assertRaises(TypeError, F.hardshrink, 1)
...@@ -1018,7 +1028,7 @@ class TestHardtanhAPI(unittest.TestCase): ...@@ -1018,7 +1028,7 @@ class TestHardtanhAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out1 = F.hardtanh(x) out1 = F.hardtanh(x)
...@@ -1047,7 +1057,7 @@ class TestHardtanhAPI(unittest.TestCase): ...@@ -1047,7 +1057,7 @@ class TestHardtanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.hardtanh, 1) self.assertRaises(TypeError, F.hardtanh, 1)
...@@ -1113,7 +1123,7 @@ class TestSoftshrinkAPI(unittest.TestCase): ...@@ -1113,7 +1123,7 @@ class TestSoftshrinkAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.softshrink(x, self.threshold) out1 = F.softshrink(x, self.threshold)
...@@ -1135,7 +1145,7 @@ class TestSoftshrinkAPI(unittest.TestCase): ...@@ -1135,7 +1145,7 @@ class TestSoftshrinkAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.softshrink, 1) self.assertRaises(TypeError, F.softshrink, 1)
...@@ -1165,6 +1175,7 @@ class TestSqrt(TestActivation, TestParameter): ...@@ -1165,6 +1175,7 @@ class TestSqrt(TestActivation, TestParameter):
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1023) np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
...@@ -1173,13 +1184,14 @@ class TestSqrt(TestActivation, TestParameter): ...@@ -1173,13 +1184,14 @@ class TestSqrt(TestActivation, TestParameter):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.convert_input_output() self.convert_input_output()
self.enable_cinn = False
# TODO(wanghao107) add prim test def if_enable_cinn(self):
pass
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -1193,13 +1205,13 @@ class TestSqrtPrimFp32(TestActivation): ...@@ -1193,13 +1205,13 @@ class TestSqrtPrimFp32(TestActivation):
self.public_python_api = paddle.sqrt self.public_python_api = paddle.sqrt
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1023) np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.sqrt(x) out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.enable_cinn = True
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
...@@ -1212,26 +1224,17 @@ class TestSqrtPrimFp32(TestActivation): ...@@ -1212,26 +1224,17 @@ class TestSqrtPrimFp32(TestActivation):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def if_enable_cinn(self):
pass
class TestSqrt_ZeroDim(TestSqrt): class TestSqrt_ZeroDim(TestSqrt):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
self.enable_cinn = False
class TestSqrtPrim_ZeroDim(TestSqrt): def if_enable_cinn(self):
def init_shape(self):
self.shape = []
self.enable_cinn = False self.enable_cinn = False
def init_dtype(self):
self.dtype = np.float32
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
...@@ -1244,6 +1247,7 @@ class TestSqrtBF16(OpTest): ...@@ -1244,6 +1247,7 @@ class TestSqrtBF16(OpTest):
self.public_python_api = paddle.sqrt self.public_python_api = paddle.sqrt
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1023) np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(np.float32) x = np.random.uniform(0.1, 1, self.shape).astype(np.float32)
...@@ -1253,7 +1257,6 @@ class TestSqrtBF16(OpTest): ...@@ -1253,7 +1257,6 @@ class TestSqrtBF16(OpTest):
'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x)) 'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
} }
self.outputs = {'Out': convert_float_to_uint16(out)} self.outputs = {'Out': convert_float_to_uint16(out)}
self.enable_cinn = False
def init_dtype(self): def init_dtype(self):
self.dtype = np.uint16 self.dtype = np.uint16
...@@ -1261,6 +1264,9 @@ class TestSqrtBF16(OpTest): ...@@ -1261,6 +1264,9 @@ class TestSqrtBF16(OpTest):
def init_shape(self): def init_shape(self):
self.shape = [11, 17] self.shape = [11, 17]
def if_enable_cinn(self):
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place) self.check_output_with_place(place)
...@@ -1278,6 +1284,7 @@ class TestSqrtComp(TestActivation, TestParameter): ...@@ -1278,6 +1284,7 @@ class TestSqrtComp(TestActivation, TestParameter):
self.public_python_api = paddle.sqrt self.public_python_api = paddle.sqrt
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1023) np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
...@@ -1286,7 +1293,9 @@ class TestSqrtComp(TestActivation, TestParameter): ...@@ -1286,7 +1293,9 @@ class TestSqrtComp(TestActivation, TestParameter):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.convert_input_output() self.convert_input_output()
self.enable_cinn = True
def if_enable_cinn(self):
pass
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
...@@ -1305,13 +1314,16 @@ class TestSqrtCompFp32(TestActivation): ...@@ -1305,13 +1314,16 @@ class TestSqrtCompFp32(TestActivation):
self.public_python_api = paddle.sqrt self.public_python_api = paddle.sqrt
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1023) np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.sqrt(x) out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.enable_cinn = True
def if_enable_cinn(self):
pass
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
...@@ -1333,19 +1345,22 @@ class TestRsqrt(TestActivation): ...@@ -1333,19 +1345,22 @@ class TestRsqrt(TestActivation):
self.public_python_api = paddle.rsqrt self.public_python_api = paddle.rsqrt
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) * 10 x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = 1.0 / np.sqrt(x) out = 1.0 / np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.convert_input_output() self.convert_input_output()
self.enable_cinn = True
def init_shape(self): def init_shape(self):
self.shape = [10, 12] self.shape = [10, 12]
def if_enable_cinn(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output(check_prim=True) self.check_output(check_prim=True)
...@@ -1360,12 +1375,12 @@ class TestRsqrt(TestActivation): ...@@ -1360,12 +1375,12 @@ class TestRsqrt(TestActivation):
) )
'''
class TestRsqrt_ZeroDim(TestRsqrt): class TestRsqrt_ZeroDim(TestRsqrt):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
'''
def if_enable_cinn(self):
self.enable_cinn = False
class TestAbs(TestActivation): class TestAbs(TestActivation):
...@@ -1374,9 +1389,9 @@ class TestAbs(TestActivation): ...@@ -1374,9 +1389,9 @@ class TestAbs(TestActivation):
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.python_api = paddle.abs self.python_api = paddle.abs
self.public_python_api = paddle.abs self.public_python_api = paddle.abs
self.enable_cinn = False
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
...@@ -1394,6 +1409,9 @@ class TestAbs(TestActivation): ...@@ -1394,6 +1409,9 @@ class TestAbs(TestActivation):
def init_shape(self): def init_shape(self):
self.shape = [4, 25] self.shape = [4, 25]
def if_enable_cinn(self):
pass
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
...@@ -1404,6 +1422,9 @@ class TestAbs_ZeroDim(TestAbs): ...@@ -1404,6 +1422,9 @@ class TestAbs_ZeroDim(TestAbs):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestCeil(TestActivation): class TestCeil(TestActivation):
def setUp(self): def setUp(self):
...@@ -1441,6 +1462,7 @@ class TestFloor(TestActivation): ...@@ -1441,6 +1462,7 @@ class TestFloor(TestActivation):
self.public_python_api = paddle.floor self.public_python_api = paddle.floor
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
...@@ -1453,57 +1475,36 @@ class TestFloor(TestActivation): ...@@ -1453,57 +1475,36 @@ class TestFloor(TestActivation):
def init_shape(self): def init_shape(self):
self.shape = [10, 12] self.shape = [10, 12]
def if_enable_cinn(self):
pass
# the gradient on floor, ceil, round is undefined. # the gradient on floor, ceil, round is undefined.
# we return zero as gradient, but the numpy return nan # we return zero as gradient, but the numpy return nan
# The same reason with TestFloor # The same reason with TestFloor
def test_check_grad(self): def test_check_grad(self):
pass pass
def test_check_grad_for_prim(self):
class TestFloor_ZeroDim(TestFloor):
def init_shape(self):
self.shape = []
class TestFloor_Prim(TestActivation):
def setUp(self):
self.op_type = "floor"
self.prim_op_type = "prim"
self.python_api = paddle.floor
self.public_python_api = paddle.floor
self.init_dtype()
self.init_shape()
if len(self.shape) == 0:
# for 0-D tensor, skip cinn testing
self.enable_cinn = False
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = np.floor(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def init_shape(self):
self.shape = [10, 12]
def test_check_grad(self):
# the gradient on floor, ceil, round is undefined. # the gradient on floor, ceil, round is undefined.
# we return zero as gradient, but the numpy return nan. # we return zero as gradient, but the numpy return nan.
# for prim, we compare result with eager python api, # for prim, we compare result with eager python api,
# so, we use only_prim flag to express we only test prim. # so, we use only_prim flag to express we only test prim.
self.check_grad(['X'], 'Out', check_prim=True, only_check_prim=True) if core.is_compiled_with_cuda():
self.check_grad_with_place(
paddle.CUDAPlace(0),
['X'],
'Out',
check_prim=True,
only_check_prim=True,
)
class TestFloor_ZeroDim_Prim(TestFloor_Prim): class TestFloor_ZeroDim(TestFloor):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def if_enable_cinn(self):
class TestFloorFp16_Prim(TestFloor_Prim): self.enable_cinn = False
def init_dtype(self):
self.dtype = np.float16
class TestCos(TestActivation): class TestCos(TestActivation):
...@@ -1592,7 +1593,7 @@ class TestTanAPI(unittest.TestCase): ...@@ -1592,7 +1593,7 @@ class TestTanAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, out_test.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, out_test.numpy(), rtol=1e-05)
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [11, 17], self.dtype) x = paddle.static.data('X', [11, 17], self.dtype)
out = paddle.tan(x) out = paddle.tan(x)
...@@ -1827,7 +1828,7 @@ class TestRelu(TestActivation): ...@@ -1827,7 +1828,7 @@ class TestRelu(TestActivation):
self.public_python_api = paddle.nn.functional.relu self.public_python_api = paddle.nn.functional.relu
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.skip_cinn() self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
...@@ -1847,15 +1848,15 @@ class TestRelu(TestActivation): ...@@ -1847,15 +1848,15 @@ class TestRelu(TestActivation):
def test_check_output(self): def test_check_output(self):
self.check_output(check_prim=True) self.check_output(check_prim=True)
def skip_cinn(self): def if_enable_cinn(self):
self.enable_cinn = False pass
class TestRelu_ZeroDim(TestRelu): class TestRelu_ZeroDim(TestRelu):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def skip_cinn(self): def if_enable_cinn(self):
self.enable_cinn = False self.enable_cinn = False
...@@ -1875,7 +1876,7 @@ class TestReluAPI(unittest.TestCase): ...@@ -1875,7 +1876,7 @@ class TestReluAPI(unittest.TestCase):
self.relu = F.relu self.relu = F.relu
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out1 = self.relu(x) out1 = self.relu(x)
...@@ -1897,8 +1898,8 @@ class TestReluAPI(unittest.TestCase): ...@@ -1897,8 +1898,8 @@ class TestReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, self.relu, 1) self.assertRaises(TypeError, self.relu, 1)
...@@ -1937,6 +1938,7 @@ class TestLeakyRelu(TestActivation): ...@@ -1937,6 +1938,7 @@ class TestLeakyRelu(TestActivation):
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
alpha = self.get_alpha() alpha = self.get_alpha()
np.random.seed(1024) np.random.seed(1024)
...@@ -1950,6 +1952,9 @@ class TestLeakyRelu(TestActivation): ...@@ -1950,6 +1952,9 @@ class TestLeakyRelu(TestActivation):
self.attrs = {'alpha': alpha} self.attrs = {'alpha': alpha}
self.convert_input_output() self.convert_input_output()
def if_enable_cinn(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output(check_prim=True) self.check_output(check_prim=True)
...@@ -1978,25 +1983,8 @@ class TestLeakyRelu_ZeroDim(TestLeakyRelu): ...@@ -1978,25 +1983,8 @@ class TestLeakyRelu_ZeroDim(TestLeakyRelu):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def setUp(self): def if_enable_cinn(self):
self.op_type = "leaky_relu"
self.prim_op_type = "comp"
self.enable_cinn = False self.enable_cinn = False
self.python_api = paddle.nn.functional.leaky_relu
self.public_python_api = paddle.nn.functional.relu
self.init_dtype()
self.init_shape()
alpha = self.get_alpha()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.05
out = ref_leaky_relu(x, alpha)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'alpha': alpha}
class TestLeakyReluAPI(unittest.TestCase): class TestLeakyReluAPI(unittest.TestCase):
...@@ -2011,7 +1999,7 @@ class TestLeakyReluAPI(unittest.TestCase): ...@@ -2011,7 +1999,7 @@ class TestLeakyReluAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out1 = F.leaky_relu(x) out1 = F.leaky_relu(x)
...@@ -2040,7 +2028,7 @@ class TestLeakyReluAPI(unittest.TestCase): ...@@ -2040,7 +2028,7 @@ class TestLeakyReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.leaky_relu, 1) self.assertRaises(TypeError, F.leaky_relu, 1)
...@@ -2169,7 +2157,7 @@ class TestGELUAPI(unittest.TestCase): ...@@ -2169,7 +2157,7 @@ class TestGELUAPI(unittest.TestCase):
self.rev_comp_atol = 1e-8 self.rev_comp_atol = 1e-8
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [11, 17], dtype="float32") x = paddle.static.data('X', [11, 17], dtype="float32")
out1 = F.gelu(x) out1 = F.gelu(x)
...@@ -2198,7 +2186,7 @@ class TestGELUAPI(unittest.TestCase): ...@@ -2198,7 +2186,7 @@ class TestGELUAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.gelu, 1) self.assertRaises(TypeError, F.gelu, 1)
...@@ -2294,7 +2282,7 @@ class TestRelu6API(unittest.TestCase): ...@@ -2294,7 +2282,7 @@ class TestRelu6API(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.relu6(x) out1 = F.relu6(x)
...@@ -2316,7 +2304,7 @@ class TestRelu6API(unittest.TestCase): ...@@ -2316,7 +2304,7 @@ class TestRelu6API(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_fluid_api(self): def test_fluid_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out = paddle.nn.functional.relu6(x) out = paddle.nn.functional.relu6(x)
...@@ -2326,7 +2314,7 @@ class TestRelu6API(unittest.TestCase): ...@@ -2326,7 +2314,7 @@ class TestRelu6API(unittest.TestCase):
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.relu6, 1) self.assertRaises(TypeError, F.relu6, 1)
...@@ -2344,7 +2332,7 @@ class TestRelu6API(unittest.TestCase): ...@@ -2344,7 +2332,7 @@ class TestRelu6API(unittest.TestCase):
class TestRelu6APIWarnings(unittest.TestCase): class TestRelu6APIWarnings(unittest.TestCase):
def test_warnings(self): def test_warnings(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with warnings.catch_warnings(record=True) as context: with warnings.catch_warnings(record=True) as context:
warnings.simplefilter("always") warnings.simplefilter("always")
...@@ -2442,7 +2430,7 @@ class TestHardswishAPI(unittest.TestCase): ...@@ -2442,7 +2430,7 @@ class TestHardswishAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.hardswish(x) out1 = F.hardswish(x)
...@@ -2464,7 +2452,7 @@ class TestHardswishAPI(unittest.TestCase): ...@@ -2464,7 +2452,7 @@ class TestHardswishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_fluid_api(self): def test_fluid_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out = paddle.nn.functional.hardswish(x) out = paddle.nn.functional.hardswish(x)
...@@ -2478,7 +2466,7 @@ class TestHardswishAPI(unittest.TestCase): ...@@ -2478,7 +2466,7 @@ class TestHardswishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.hardswish, 1) self.assertRaises(TypeError, F.hardswish, 1)
...@@ -2588,7 +2576,7 @@ class TestELUAPI(unittest.TestCase): ...@@ -2588,7 +2576,7 @@ class TestELUAPI(unittest.TestCase):
self.elu = F.elu self.elu = F.elu
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out1 = self.elu(x) out1 = self.elu(x)
...@@ -2619,7 +2607,7 @@ class TestELUAPI(unittest.TestCase): ...@@ -2619,7 +2607,7 @@ class TestELUAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, self.elu, 1) self.assertRaises(TypeError, self.elu, 1)
...@@ -2697,7 +2685,7 @@ class TestCELUAPI(unittest.TestCase): ...@@ -2697,7 +2685,7 @@ class TestCELUAPI(unittest.TestCase):
self.celu = F.celu self.celu = F.celu
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out1 = self.celu(x, 1.5) out1 = self.celu(x, 1.5)
...@@ -2728,7 +2716,7 @@ class TestCELUAPI(unittest.TestCase): ...@@ -2728,7 +2716,7 @@ class TestCELUAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, self.celu, 1) self.assertRaises(TypeError, self.celu, 1)
...@@ -2786,10 +2774,7 @@ class TestLog(TestActivation): ...@@ -2786,10 +2774,7 @@ class TestLog(TestActivation):
self.public_python_api = paddle.log self.public_python_api = paddle.log
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
if len(self.shape) == 0:
# for 0-D tensor, skip cinn testing
self.enable_cinn = False
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
...@@ -2799,14 +2784,17 @@ class TestLog(TestActivation): ...@@ -2799,14 +2784,17 @@ class TestLog(TestActivation):
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.convert_input_output() self.convert_input_output()
def if_enable_cinn(self):
pass
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out', check_prim=True) self.check_grad(['X'], 'Out', check_prim=True)
def test_error(self): def test_error(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
in1 = paddle.static.data( in1 = paddle.static.data(
name="in1", shape=[11, 17], dtype="int32" name="in1", shape=[11, 17], dtype="int32"
) )
...@@ -2820,7 +2808,7 @@ class TestLog(TestActivation): ...@@ -2820,7 +2808,7 @@ class TestLog(TestActivation):
class Test_Log_Op_Fp16(unittest.TestCase): class Test_Log_Op_Fp16(unittest.TestCase):
def test_api_fp16(self): def test_api_fp16(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with static.program_guard( with static.program_guard(
paddle.static.Program(), paddle.static.Program() paddle.static.Program(), paddle.static.Program()
): ):
...@@ -2837,6 +2825,9 @@ class TestLog_ZeroDim(TestLog): ...@@ -2837,6 +2825,9 @@ class TestLog_ZeroDim(TestLog):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestLog2(TestActivation): class TestLog2(TestActivation):
def setUp(self): def setUp(self):
...@@ -2858,7 +2849,7 @@ class TestLog2(TestActivation): ...@@ -2858,7 +2849,7 @@ class TestLog2(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def test_error(self): def test_error(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64") in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64")
...@@ -2866,7 +2857,7 @@ class TestLog2(TestActivation): ...@@ -2866,7 +2857,7 @@ class TestLog2(TestActivation):
self.assertRaises(TypeError, paddle.log2, in2) self.assertRaises(TypeError, paddle.log2, in2)
def test_api(self): def test_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard( with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program() paddle.static.Program(), paddle.static.Program()
): ):
...@@ -2928,7 +2919,7 @@ class TestLog10_ZeroDim(TestLog10): ...@@ -2928,7 +2919,7 @@ class TestLog10_ZeroDim(TestLog10):
class TestLog10API(unittest.TestCase): class TestLog10API(unittest.TestCase):
def test_error(self): def test_error(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64") in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64")
...@@ -2936,7 +2927,7 @@ class TestLog10API(unittest.TestCase): ...@@ -2936,7 +2927,7 @@ class TestLog10API(unittest.TestCase):
self.assertRaises(TypeError, paddle.log10, in2) self.assertRaises(TypeError, paddle.log10, in2)
def test_api(self): def test_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard( with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program() paddle.static.Program(), paddle.static.Program()
): ):
...@@ -2989,7 +2980,7 @@ class TestLog1p(TestActivation): ...@@ -2989,7 +2980,7 @@ class TestLog1p(TestActivation):
class Test_Log1p_Op_Fp16(unittest.TestCase): class Test_Log1p_Op_Fp16(unittest.TestCase):
def test_api_fp16(self): def test_api_fp16(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with static.program_guard( with static.program_guard(
paddle.static.Program(), paddle.static.Program() paddle.static.Program(), paddle.static.Program()
): ):
...@@ -3009,7 +3000,7 @@ class TestLog1p_ZeroDim(TestLog1p): ...@@ -3009,7 +3000,7 @@ class TestLog1p_ZeroDim(TestLog1p):
class TestLog1pAPI(unittest.TestCase): class TestLog1pAPI(unittest.TestCase):
def test_api(self): def test_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64") input_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64")
data_x = paddle.static.data( data_x = paddle.static.data(
...@@ -3106,6 +3097,7 @@ class TestPow(TestActivation): ...@@ -3106,6 +3097,7 @@ class TestPow(TestActivation):
self.public_python_api = paddle.pow self.public_python_api = paddle.pow
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(1, 2, self.shape).astype(self.dtype) x = np.random.uniform(1, 2, self.shape).astype(self.dtype)
...@@ -3116,6 +3108,9 @@ class TestPow(TestActivation): ...@@ -3116,6 +3108,9 @@ class TestPow(TestActivation):
self.attrs = {'factor': 3.0} self.attrs = {'factor': 3.0}
self.convert_input_output() self.convert_input_output()
def if_enable_cinn(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output(check_prim=True) self.check_output(check_prim=True)
...@@ -3129,8 +3124,7 @@ class TestPow_ZeroDim(TestPow): ...@@ -3129,8 +3124,7 @@ class TestPow_ZeroDim(TestPow):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def setUp(self): def if_enable_cinn(self):
super().setUp()
self.enable_cinn = False self.enable_cinn = False
...@@ -3162,7 +3156,7 @@ class TestPow_factor_tensor(TestActivation): ...@@ -3162,7 +3156,7 @@ class TestPow_factor_tensor(TestActivation):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def test_api(self): def test_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
input = np.random.uniform(1, 2, [11, 17]).astype("float32") input = np.random.uniform(1, 2, [11, 17]).astype("float32")
x = paddle.static.data(name="x", shape=[11, 17], dtype="float32") x = paddle.static.data(name="x", shape=[11, 17], dtype="float32")
res = paddle.static.data( res = paddle.static.data(
...@@ -3261,7 +3255,7 @@ class TestSTanhAPI(unittest.TestCase): ...@@ -3261,7 +3255,7 @@ class TestSTanhAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', [10, 12]) x = paddle.static.data('X', [10, 12])
out = paddle.stanh(x, self.scale_a, self.scale_b) out = paddle.stanh(x, self.scale_a, self.scale_b)
...@@ -3279,7 +3273,7 @@ class TestSTanhAPI(unittest.TestCase): ...@@ -3279,7 +3273,7 @@ class TestSTanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_fluid_api(self): def test_fluid_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = paddle.static.data('X', [10, 12], dtype="float32") x = paddle.static.data('X', [10, 12], dtype="float32")
out = paddle.stanh(x, self.scale_a, self.scale_b) out = paddle.stanh(x, self.scale_a, self.scale_b)
...@@ -3289,7 +3283,7 @@ class TestSTanhAPI(unittest.TestCase): ...@@ -3289,7 +3283,7 @@ class TestSTanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, paddle.stanh, 1) self.assertRaises(TypeError, paddle.stanh, 1)
...@@ -3400,7 +3394,7 @@ class TestSoftplusAPI(unittest.TestCase): ...@@ -3400,7 +3394,7 @@ class TestSoftplusAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.softplus(x, self.beta, self.threshold) out1 = F.softplus(x, self.beta, self.threshold)
...@@ -3422,7 +3416,7 @@ class TestSoftplusAPI(unittest.TestCase): ...@@ -3422,7 +3416,7 @@ class TestSoftplusAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.softplus, 1) self.assertRaises(TypeError, F.softplus, 1)
...@@ -3485,7 +3479,7 @@ class TestSoftsignAPI(unittest.TestCase): ...@@ -3485,7 +3479,7 @@ class TestSoftsignAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.softsign(x) out1 = F.softsign(x)
...@@ -3507,7 +3501,7 @@ class TestSoftsignAPI(unittest.TestCase): ...@@ -3507,7 +3501,7 @@ class TestSoftsignAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.softsign, 1) self.assertRaises(TypeError, F.softsign, 1)
...@@ -3575,7 +3569,7 @@ class TestThresholdedReluAPI(unittest.TestCase): ...@@ -3575,7 +3569,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.thresholded_relu(x, self.threshold) out1 = F.thresholded_relu(x, self.threshold)
...@@ -3588,6 +3582,7 @@ class TestThresholdedReluAPI(unittest.TestCase): ...@@ -3588,6 +3582,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05) np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self): def test_dygraph_api(self):
paddle.disable_static()
x = paddle.to_tensor(self.x_np) x = paddle.to_tensor(self.x_np)
out1 = F.thresholded_relu(x, self.threshold) out1 = F.thresholded_relu(x, self.threshold)
thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold) thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
...@@ -3597,7 +3592,7 @@ class TestThresholdedReluAPI(unittest.TestCase): ...@@ -3597,7 +3592,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.thresholded_relu, 1) self.assertRaises(TypeError, F.thresholded_relu, 1)
...@@ -3678,7 +3673,7 @@ class TestHardsigmoidAPI(unittest.TestCase): ...@@ -3678,7 +3673,7 @@ class TestHardsigmoidAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.hardsigmoid(x) out1 = F.hardsigmoid(x)
...@@ -3700,7 +3695,7 @@ class TestHardsigmoidAPI(unittest.TestCase): ...@@ -3700,7 +3695,7 @@ class TestHardsigmoidAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_fluid_api(self): def test_fluid_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out = paddle.nn.functional.hardsigmoid(x, slope=0.2) out = paddle.nn.functional.hardsigmoid(x, slope=0.2)
...@@ -3715,7 +3710,7 @@ class TestHardsigmoidAPI(unittest.TestCase): ...@@ -3715,7 +3710,7 @@ class TestHardsigmoidAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.hardsigmoid, 1) self.assertRaises(TypeError, F.hardsigmoid, 1)
...@@ -3781,7 +3776,7 @@ class TestSwishAPI(unittest.TestCase): ...@@ -3781,7 +3776,7 @@ class TestSwishAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.swish(x) out1 = F.swish(x)
...@@ -3803,7 +3798,7 @@ class TestSwishAPI(unittest.TestCase): ...@@ -3803,7 +3798,7 @@ class TestSwishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_fluid_api(self): def test_fluid_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out = paddle.nn.functional.swish(x) out = paddle.nn.functional.swish(x)
...@@ -3813,7 +3808,7 @@ class TestSwishAPI(unittest.TestCase): ...@@ -3813,7 +3808,7 @@ class TestSwishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.swish, 1) self.assertRaises(TypeError, F.swish, 1)
...@@ -3880,7 +3875,7 @@ class TestMishAPI(unittest.TestCase): ...@@ -3880,7 +3875,7 @@ class TestMishAPI(unittest.TestCase):
) )
def test_static_api(self): def test_static_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.mish(x) out1 = F.mish(x)
...@@ -3902,7 +3897,7 @@ class TestMishAPI(unittest.TestCase): ...@@ -3902,7 +3897,7 @@ class TestMishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
def test_fluid_api(self): def test_fluid_api(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out = paddle.nn.functional.mish(x) out = paddle.nn.functional.mish(x)
...@@ -3912,7 +3907,7 @@ class TestMishAPI(unittest.TestCase): ...@@ -3912,7 +3907,7 @@ class TestMishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
def test_errors(self): def test_errors(self):
with paddle_static_guard(): with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable. # The input type must be Variable.
self.assertRaises(TypeError, F.mish, 1) self.assertRaises(TypeError, F.mish, 1)
...@@ -3955,7 +3950,7 @@ def create_test_act_fp16_class( ...@@ -3955,7 +3950,7 @@ def create_test_act_fp16_class(
grad_check=True, grad_check=True,
check_dygraph=True, check_dygraph=True,
check_prim=False, check_prim=False,
enable_cinn=True, enable_cinn=False,
grad_atol=1e-2, grad_atol=1e-2,
**kwargs **kwargs
): ):
...@@ -4003,20 +3998,22 @@ def create_test_act_fp16_class( ...@@ -4003,20 +3998,22 @@ def create_test_act_fp16_class(
globals()[cls_name] = TestActFp16 globals()[cls_name] = TestActFp16
create_test_act_fp16_class(TestActivation, check_prim=True) create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class(TestExpm1) create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(TestSigmoid, check_prim=True) create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestSilu, check_prim=True) create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestLogSigmoid) create_test_act_fp16_class(TestLogSigmoid)
create_test_act_fp16_class(TestTanh) create_test_act_fp16_class(TestTanh, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestTanhshrink) create_test_act_fp16_class(TestTanhshrink)
create_test_act_fp16_class(TestHardShrink) create_test_act_fp16_class(TestHardShrink)
create_test_act_fp16_class(TestSoftshrink) create_test_act_fp16_class(TestSoftshrink)
create_test_act_fp16_class(TestSqrt, check_prim=True) create_test_act_fp16_class(TestSqrt, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestSqrtComp, check_prim=True) create_test_act_fp16_class(TestSqrtComp, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestAbs, check_prim=True) create_test_act_fp16_class(TestAbs, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestCeil, grad_check=False) create_test_act_fp16_class(TestCeil, grad_check=False)
create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False) create_test_act_fp16_class(
TestFloor, check_prim=True, grad_check=False, enable_cinn=True
)
create_test_act_fp16_class(TestCos) create_test_act_fp16_class(TestCos)
create_test_act_fp16_class(TestTan) create_test_act_fp16_class(TestTan)
create_test_act_fp16_class(TestCosh) create_test_act_fp16_class(TestCosh)
...@@ -4029,7 +4026,7 @@ create_test_act_fp16_class(TestAcosh) ...@@ -4029,7 +4026,7 @@ create_test_act_fp16_class(TestAcosh)
create_test_act_fp16_class(TestAsinh) create_test_act_fp16_class(TestAsinh)
create_test_act_fp16_class(TestAtanh) create_test_act_fp16_class(TestAtanh)
create_test_act_fp16_class(TestRound, grad_check=False) create_test_act_fp16_class(TestRound, grad_check=False)
create_test_act_fp16_class(TestRelu, check_prim=True) create_test_act_fp16_class(TestRelu, check_prim=True, enable_cinn=True)
create_test_act_fp16_class( create_test_act_fp16_class(
TestGelu, TestGelu,
check_prim=True, check_prim=True,
...@@ -4063,14 +4060,18 @@ create_test_act_fp16_class(TestHardSigmoid) ...@@ -4063,14 +4060,18 @@ create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class(TestSwish) create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish, check_prim=True) create_test_act_fp16_class(TestHardSwish, check_prim=True)
create_test_act_fp16_class(TestMish) create_test_act_fp16_class(TestMish)
create_test_act_fp16_class(TestLeakyRelu, check_prim=True) create_test_act_fp16_class(TestLeakyRelu, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestLeakyReluAlpha1, check_prim=True)
create_test_act_fp16_class(TestLeakyReluAlpha2, check_prim=True)
create_test_act_fp16_class(TestLeakyReluAlpha3, check_prim=True)
create_test_act_fp16_class( create_test_act_fp16_class(
TestLeakyRelu_ZeroDim, check_prim=True, enable_cinn=False TestLeakyReluAlpha1, check_prim=True, enable_cinn=True
) )
create_test_act_fp16_class(TestRsqrt) create_test_act_fp16_class(
TestLeakyReluAlpha2, check_prim=True, enable_cinn=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha3, check_prim=True, enable_cinn=True
)
create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True)
create_test_act_fp16_class(TestRsqrt, check_prim=True, enable_cinn=True)
def create_test_act_bf16_class( def create_test_act_bf16_class(
...@@ -4079,7 +4080,7 @@ def create_test_act_bf16_class( ...@@ -4079,7 +4080,7 @@ def create_test_act_bf16_class(
grad_check=True, grad_check=True,
check_dygraph=True, check_dygraph=True,
check_prim=False, check_prim=False,
enable_cinn=True, enable_cinn=False,
grad_atol=1e-2, grad_atol=1e-2,
**kwargs **kwargs
): ):
...@@ -4097,6 +4098,9 @@ def create_test_act_bf16_class( ...@@ -4097,6 +4098,9 @@ def create_test_act_bf16_class(
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def if_enable_cinn(self):
self.enable_cinn = enable_cinn
def convert_input_output(self): def convert_input_output(self):
self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])} self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])}
self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])} self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
...@@ -4104,13 +4108,19 @@ def create_test_act_bf16_class( ...@@ -4104,13 +4108,19 @@ def create_test_act_bf16_class(
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol) self.check_output_with_place(
place, atol=atol, check_prim=check_prim
)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if grad_check: if grad_check:
self.check_grad_with_place( self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=grad_atol place,
['X'],
'Out',
max_relative_error=grad_atol,
check_prim=check_prim,
) )
cls_name = "{}_{}".format(parent.__name__, "BF16OP") cls_name = "{}_{}".format(parent.__name__, "BF16OP")
...@@ -4118,12 +4128,12 @@ def create_test_act_bf16_class( ...@@ -4118,12 +4128,12 @@ def create_test_act_bf16_class(
globals()[cls_name] = TestActBF16 globals()[cls_name] = TestActBF16
create_test_act_bf16_class(TestActivation, check_prim=True) create_test_act_bf16_class(TestActivation)
create_test_act_bf16_class(TestExpm1) create_test_act_bf16_class(TestExpm1)
create_test_act_bf16_class(TestSigmoid, check_prim=True) create_test_act_bf16_class(TestSigmoid, check_prim=True)
create_test_act_bf16_class(TestSilu, check_prim=True) create_test_act_bf16_class(TestSilu, check_prim=True)
create_test_act_bf16_class(TestLogSigmoid) create_test_act_bf16_class(TestLogSigmoid)
create_test_act_bf16_class(TestTanh) create_test_act_bf16_class(TestTanh, check_prim=True)
create_test_act_bf16_class(TestTanhshrink) create_test_act_bf16_class(TestTanhshrink)
create_test_act_bf16_class(TestHardShrink) create_test_act_bf16_class(TestHardShrink)
create_test_act_bf16_class(TestSoftshrink) create_test_act_bf16_class(TestSoftshrink)
...@@ -4148,7 +4158,6 @@ create_test_act_bf16_class(TestRelu, check_prim=True) ...@@ -4148,7 +4158,6 @@ create_test_act_bf16_class(TestRelu, check_prim=True)
create_test_act_bf16_class( create_test_act_bf16_class(
TestGelu, TestGelu,
check_prim=True, check_prim=True,
enable_cinn=True,
rev_comp_rtol=1e-2, rev_comp_rtol=1e-2,
rev_comp_atol=1e-2, rev_comp_atol=1e-2,
cinn_rtol=1e-2, cinn_rtol=1e-2,
...@@ -4178,20 +4187,12 @@ create_test_act_bf16_class(TestHardSigmoid) ...@@ -4178,20 +4187,12 @@ create_test_act_bf16_class(TestHardSigmoid)
create_test_act_bf16_class(TestSwish) create_test_act_bf16_class(TestSwish)
create_test_act_bf16_class(TestHardSwish, check_prim=True) create_test_act_bf16_class(TestHardSwish, check_prim=True)
create_test_act_bf16_class(TestMish) create_test_act_bf16_class(TestMish)
create_test_act_bf16_class(TestLeakyRelu, check_prim=True, enable_cinn=False) create_test_act_bf16_class(TestLeakyRelu, check_prim=True)
create_test_act_bf16_class( create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True)
TestLeakyReluAlpha1, check_prim=True, enable_cinn=False create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True)
) create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True)
create_test_act_bf16_class( create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True)
TestLeakyReluAlpha2, check_prim=True, enable_cinn=False create_test_act_bf16_class(TestRsqrt, check_prim=True)
)
create_test_act_bf16_class(
TestLeakyReluAlpha3, check_prim=True, enable_cinn=False
)
create_test_act_bf16_class(
TestLeakyRelu_ZeroDim, check_prim=True, enable_cinn=False
)
create_test_act_bf16_class(TestRsqrt)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -631,12 +631,13 @@ def rsqrt_composite(x): ...@@ -631,12 +631,13 @@ def rsqrt_composite(x):
is_amp = False is_amp = False
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16": dtype = convert_dtype(x.dtype)
if dtype == "float16" or dtype == "uint16":
is_amp = True is_amp = True
x = cast(x, "float32") x = cast(x, "float32")
y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
res = pow(x, y) res = pow(x, y)
return res if not is_amp else cast(res, "float16") return res if not is_amp else cast(res, dtype)
@REGISTER_COMPOSITE('group_norm') @REGISTER_COMPOSITE('group_norm')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册