未验证 提交 4c6ad5c0 编写于 作者: F FlyingQianMM 提交者: GitHub

[AMP_OP&Test] improve FP16 and BF16 OpTest for maximum, minimum and multiply op (#52256)

* [AMP_OP&Test] add FP16 and BF16 OpTest for minimum; add FP16 OpTest for multiply

* [AMP_OP&Test] reset atol and max_relative_error for multiply fp16 and bf16 optest

* [AMP_OP&Test] delete manually atol and max_relative_error setting
上级 6efeb227
...@@ -104,6 +104,17 @@ class TestElementwiseFP16Op(TestElementwiseOp): ...@@ -104,6 +104,17 @@ class TestElementwiseFP16Op(TestElementwiseOp):
np.float16 np.float16
) )
def setUp(self):
self.init_data()
self.op_type = "elementwise_max"
self.prim_op_type = "prim"
self.enable_cinn = False
self.python_api = paddle.maximum
self.dtype = np.float16
self.public_python_api = paddle.maximum
self.inputs = {'X': self.x, 'Y': self.y}
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp): class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp):
def init_data(self): def init_data(self):
...@@ -111,10 +122,10 @@ class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp): ...@@ -111,10 +122,10 @@ class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp):
self.y = np.random.uniform(0.1, 1, []).astype("float64") self.y = np.random.uniform(0.1, 1, []).astype("float64")
class TestElementwiseMaxFP16Op_ZeroDim1(TestElementwiseOp): class TestElementwiseMaxFP16Op_ZeroDim1(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float16") self.x = np.random.uniform(0.1, 1, []).astype(np.float16)
self.y = np.random.uniform(0.1, 1, []).astype("float16") self.y = np.random.uniform(0.1, 1, []).astype(np.float16)
class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp): class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp):
...@@ -123,10 +134,10 @@ class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp): ...@@ -123,10 +134,10 @@ class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp):
self.y = np.random.uniform(0.1, 1, []).astype("float64") self.y = np.random.uniform(0.1, 1, []).astype("float64")
class TestElementwiseMaxFP16Op_ZeroDim2(TestElementwiseOp): class TestElementwiseMaxFP16Op_ZeroDim2(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype("float16") self.x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float16)
self.y = np.random.uniform(0.1, 1, []).astype("float16") self.y = np.random.uniform(0.1, 1, []).astype(np.float16)
class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp): class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp):
...@@ -135,10 +146,10 @@ class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp): ...@@ -135,10 +146,10 @@ class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp):
self.y = np.random.uniform(0.1, 1, [13, 17]).astype("float64") self.y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
class TestElementwiseMaxFP16Op_ZeroDim3(TestElementwiseOp): class TestElementwiseMaxFP16Op_ZeroDim3(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float16") self.x = np.random.uniform(0.1, 1, []).astype(np.float16)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype("float16") self.y = np.random.uniform(0.1, 1, [13, 17]).astype(np.float16)
@unittest.skipIf( @unittest.skipIf(
...@@ -180,29 +191,11 @@ class TestElementwiseBF16Op(OpTest): ...@@ -180,29 +191,11 @@ class TestElementwiseBF16Op(OpTest):
if hasattr(self, 'attrs'): if hasattr(self, 'attrs'):
self.check_output(check_dygraph=False) self.check_output(check_dygraph=False)
else: else:
self.check_output() self.check_output(check_dygraph=True)
def test_check_grad_normal(self): def test_check_grad_normal(self):
if hasattr(self, 'attrs'): if hasattr(self, 'attrs'):
# check_prim=False, bfloat16 is not supported in `less_equal` # check_prim=False, bfloat16 is not supported in `less_equal`
self.check_grad(['X', 'Y'], 'Out', check_dygraph=False)
else:
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestElementwiseMaxBF16Op_ZeroDim1(TestElementwiseBF16Op):
def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float32")
self.y = np.random.uniform(0.1, 1, []).astype("float32")
def test_check_grad_normal(self):
if hasattr(self, 'attrs'):
self.check_grad( self.check_grad(
['X', 'Y'], 'Out', numeric_grad_delta=0.05, check_dygraph=False ['X', 'Y'], 'Out', numeric_grad_delta=0.05, check_dygraph=False
) )
...@@ -220,6 +213,12 @@ class TestElementwiseMaxBF16Op_ZeroDim1(TestElementwiseBF16Op): ...@@ -220,6 +213,12 @@ class TestElementwiseMaxBF16Op_ZeroDim1(TestElementwiseBF16Op):
) )
class TestElementwiseMaxBF16Op_ZeroDim1(TestElementwiseBF16Op):
def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float32")
self.y = np.random.uniform(0.1, 1, []).astype("float32")
class TestElementwiseMaxBF16Op_scalar(TestElementwiseBF16Op): class TestElementwiseMaxBF16Op_scalar(TestElementwiseBF16Op):
def init_data(self): def init_data(self):
self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float32") self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float32")
...@@ -236,10 +235,13 @@ class TestElementwiseMaxOp_scalar(TestElementwiseOp): ...@@ -236,10 +235,13 @@ class TestElementwiseMaxOp_scalar(TestElementwiseOp):
self.y = np.array([0.5]).astype("float64") self.y = np.array([0.5]).astype("float64")
class TestElementwiseMaxFP16Op_scalar(TestElementwiseMaxOp_scalar): @skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast."
)
class TestElementwiseMaxFP16Op_scalar(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float16") self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype(np.float16)
self.y = np.array([0.5]).astype("float16") self.y = np.array([0.5]).astype(np.float16)
class TestElementwiseMaxOp_Vector(TestElementwiseOp): class TestElementwiseMaxOp_Vector(TestElementwiseOp):
...@@ -251,12 +253,12 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp): ...@@ -251,12 +253,12 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
) )
class TestElementwiseMaxFP16Op_Vector(TestElementwiseOp): class TestElementwiseMaxFP16Op_Vector(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.random((100,)).astype("float16") self.x = np.random.random((100,)).astype(np.float16)
sgn = np.random.choice([-1, 1], (100,)).astype("float16") sgn = np.random.choice([-1, 1], (100,)).astype(np.float16)
self.y = self.x + sgn * np.random.uniform(0.1, 1, (100,)).astype( self.y = self.x + sgn * np.random.uniform(0.1, 1, (100,)).astype(
"float16" np.float16
) )
...@@ -289,12 +291,13 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp): ...@@ -289,12 +291,13 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
} }
class TestElementwiseMaxFP16Op_broadcast_2(TestElementwiseOp): class TestElementwiseMaxFP16Op_broadcast_2(TestElementwiseFP16Op):
def setUp(self): def setUp(self):
self.op_type = "elementwise_max" self.op_type = "elementwise_max"
self.python_api = paddle.maximum self.python_api = paddle.maximum
self.public_python_api = paddle.maximum self.public_python_api = paddle.maximum
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.dtype = np.float16
x = np.random.uniform(0.5, 1, (1, 3, 100)).astype(np.float16) x = np.random.uniform(0.5, 1, (1, 3, 100)).astype(np.float16)
sgn = np.random.choice([-1, 1], (100,)).astype(np.float16) sgn = np.random.choice([-1, 1], (100,)).astype(np.float16)
y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype( y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
...@@ -323,12 +326,13 @@ class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp): ...@@ -323,12 +326,13 @@ class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp):
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseFP16Op_broadcast_4(TestElementwiseOp): class TestElementwiseFP16Op_broadcast_4(TestElementwiseFP16Op):
def setUp(self): def setUp(self):
self.op_type = "elementwise_max" self.op_type = "elementwise_max"
self.python_api = paddle.maximum self.python_api = paddle.maximum
self.public_python_api = paddle.maximum self.public_python_api = paddle.maximum
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.dtype = np.float16
x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float16) x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float16)
sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float16) sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float16)
y = x + sgn * np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float16) y = x + sgn * np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float16)
......
...@@ -15,10 +15,9 @@ ...@@ -15,10 +15,9 @@
import unittest import unittest
import numpy as np import numpy as np
from eager_op_test import OpTest, skip_check_grad_ci from eager_op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
import paddle import paddle
from paddle import _legacy_C_ops, fluid
from paddle.fluid import core from paddle.fluid import core
paddle.enable_static() paddle.enable_static()
...@@ -61,6 +60,33 @@ class TestElementwiseOp(OpTest): ...@@ -61,6 +60,33 @@ class TestElementwiseOp(OpTest):
) )
class TestElementwiseFP16Op(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_min"
self.python_api = paddle.minimum
self.dtype = np.float16
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float16)
sgn = np.random.choice([-1, 1], [13, 17]).astype(np.float16)
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype(np.float16)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp): class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
...@@ -71,6 +97,12 @@ class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp): ...@@ -71,6 +97,12 @@ class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinFP16Op_ZeroDim1(TestElementwiseFP16Op):
def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype(np.float16)
self.y = np.random.uniform(0.1, 1, []).astype(np.float16)
class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp): class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
...@@ -81,6 +113,12 @@ class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp): ...@@ -81,6 +113,12 @@ class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinFP16Op_ZeroDim2(TestElementwiseFP16Op):
def init_data(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
self.y = np.random.uniform(0.1, 1, []).astype("float16")
class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp): class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
...@@ -91,6 +129,12 @@ class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp): ...@@ -91,6 +129,12 @@ class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinFP16Op_ZeroDim3(TestElementwiseFP16Op):
def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float16")
self.y = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
@skip_check_grad_ci( @skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast." reason="[skip shape check] Use y_shape(1) to test broadcast."
) )
...@@ -104,6 +148,19 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp): ...@@ -104,6 +148,19 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast."
)
class TestElementwiseMinFP16Op_scalar(TestElementwiseFP16Op):
def setUp(self):
self.op_type = "elementwise_min"
self.python_api = paddle.minimum
x = np.random.random_integers(-5, 5, [10, 3, 4]).astype(np.float16)
y = np.array([0.5]).astype(np.float16)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinOp_Vector(TestElementwiseOp): class TestElementwiseMinOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
...@@ -115,6 +172,17 @@ class TestElementwiseMinOp_Vector(TestElementwiseOp): ...@@ -115,6 +172,17 @@ class TestElementwiseMinOp_Vector(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinFP16Op_Vector(TestElementwiseFP16Op):
def setUp(self):
self.op_type = "elementwise_min"
self.python_api = paddle.minimum
x = np.random.random((100,)).astype(np.float16)
sgn = np.random.choice([-1, 1], (100,)).astype(np.float16)
y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype(np.float16)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinOp_broadcast_2(TestElementwiseOp): class TestElementwiseMinOp_broadcast_2(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
...@@ -133,6 +201,24 @@ class TestElementwiseMinOp_broadcast_2(TestElementwiseOp): ...@@ -133,6 +201,24 @@ class TestElementwiseMinOp_broadcast_2(TestElementwiseOp):
} }
class TestElementwiseMinFP16Op_broadcast_2(TestElementwiseFP16Op):
def setUp(self):
self.op_type = "elementwise_min"
self.python_api = broadcast_wrapper(shape=[1, 1, 100])
x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(np.float16)
sgn = np.random.choice([-1, 1], (100,)).astype(np.float16)
y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
np.float16
)
self.inputs = {'X': x, 'Y': y}
self.outputs = {
'Out': np.minimum(
self.inputs['X'], self.inputs['Y'].reshape(1, 1, 100)
)
}
class TestElementwiseMinOp_broadcast_4(TestElementwiseOp): class TestElementwiseMinOp_broadcast_4(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
...@@ -145,52 +231,90 @@ class TestElementwiseMinOp_broadcast_4(TestElementwiseOp): ...@@ -145,52 +231,90 @@ class TestElementwiseMinOp_broadcast_4(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinOpFP16(unittest.TestCase): class TestElementwiseMinFP16Op_broadcast_4(TestElementwiseFP16Op):
def get_out_and_grad(self, x_np, y_np, axis, place, use_fp32=False): def setUp(self):
assert x_np.dtype == np.float16 self.op_type = "elementwise_min"
assert y_np.dtype == np.float16 self.python_api = paddle.minimum
if use_fp32: x = np.random.uniform(0.5, 1, (2, 10, 2, 5)).astype(np.float16)
x_np = x_np.astype(np.float32) sgn = np.random.choice([-1, 1], (2, 10, 1, 5)).astype(np.float16)
y_np = y_np.astype(np.float32) y = x + sgn * np.random.uniform(1, 2, (2, 10, 1, 5)).astype(np.float16)
dtype = np.float16 self.inputs = {'X': x, 'Y': y}
with fluid.dygraph.guard(place): self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
x.stop_gradient = False @unittest.skipIf(
y.stop_gradient = False core.is_compiled_with_cuda()
z = _legacy_C_ops.elementwise_min(x, y, 'axis', axis) and (
x_g, y_g = paddle.grad([z], [x, y]) core.cudnn_version() < 8100
return ( or paddle.device.cuda.get_device_capability()[0] < 8
z.numpy().astype(dtype), ),
x_g.numpy().astype(dtype), "run test when gpu is availble and the minimum cudnn version is 8.1.0 and gpu's compute capability is at least 8.0.",
y_g.numpy().astype(dtype), )
class TestElementwiseBF16Op(OpTest):
def init_data(self):
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
sgn = np.random.choice([-1, 1], [13, 17]).astype(np.float32)
self.y = self.x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype(
np.float32
) )
def check_main(self, x_shape, y_shape, axis=-1): def setUp(self):
if not paddle.is_compiled_with_cuda(): self.init_data()
return self.op_type = "elementwise_min"
place = paddle.CUDAPlace(0) self.python_api = paddle.minimum
if not core.is_float16_supported(place): self.public_python_api = paddle.minimum
return self.prim_op_type = "prim"
self.enable_cinn = False
self.dtype = np.uint16
self.inputs = {
'X': convert_float_to_uint16(self.x),
'Y': convert_float_to_uint16(self.y),
}
self.outputs = {
'Out': convert_float_to_uint16(np.minimum(self.x, self.y))
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', numeric_grad_delta=0.05)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', numeric_grad_delta=0.05, no_grad_set=set("X")
)
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', numeric_grad_delta=0.05, no_grad_set=set('Y')
)
class TestElementwiseMinBF16Op_ZeroDim1(TestElementwiseBF16Op):
def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float32")
self.y = np.random.uniform(0.1, 1, []).astype("float32")
class TestElementwiseMinBF16Op_scalar(TestElementwiseBF16Op):
def init_data(self):
self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float32")
self.y = np.array([0.5]).astype("float32")
self.__class__.no_need_check_grad = True
x_np = np.random.random(size=x_shape).astype(np.float16)
y_np = np.random.random(size=y_shape).astype(np.float16)
z_1, x_g_1, y_g_1 = self.get_out_and_grad( class TestElementwiseMinBF16Op_Vector(TestElementwiseBF16Op):
x_np, y_np, axis, place, False def init_data(self):
self.x = np.random.random((100,)).astype("float32")
sgn = np.random.choice([-1, 1], (100,)).astype("float32")
self.y = self.x + sgn * np.random.uniform(0.1, 1, (100,)).astype(
"float32"
) )
z_2, x_g_2, y_g_2 = self.get_out_and_grad(x_np, y_np, axis, place, True)
np.testing.assert_array_equal(z_1, z_2)
np.testing.assert_array_equal(x_g_1, x_g_2)
np.testing.assert_array_equal(y_g_1, y_g_2)
def test_main(self):
self.check_main((13, 17), (13, 17))
self.check_main((10, 3, 4), (1,))
self.check_main((100,), (100,))
self.check_main((2, 3, 100), (100,))
self.check_main((2, 10, 2, 5), (2, 10, 1, 5))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -156,10 +156,20 @@ class TestBF16ElementwiseMulOp(OpTest): ...@@ -156,10 +156,20 @@ class TestBF16ElementwiseMulOp(OpTest):
self.check_grad(['X', 'Y'], 'Out', check_prim=True) self.check_grad(['X', 'Y'], 'Out', check_prim=True)
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_prim=True) self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
check_prim=True,
)
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_prim=True) self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
check_prim=True,
)
def if_enable_cinn(self): def if_enable_cinn(self):
self.enable_cinn = False self.enable_cinn = False
...@@ -357,6 +367,39 @@ class TestElementwiseMulOpFp16(ElementwiseMulOp): ...@@ -357,6 +367,39 @@ class TestElementwiseMulOpFp16(ElementwiseMulOp):
def if_enable_cinn(self): def if_enable_cinn(self):
pass pass
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(not self.use_mkldnn))
def test_check_grad_normal(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X', 'Y'],
'Out',
check_dygraph=(not self.use_mkldnn),
check_prim=True,
)
def test_check_grad_ingore_x(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
check_dygraph=(not self.use_mkldnn),
check_prim=True,
)
def test_check_grad_ingore_y(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
check_dygraph=(not self.use_mkldnn),
check_prim=True,
)
class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp): class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册