Unverified commit acec26a1 authored by T taixiurong, committed by GitHub

xpu add dropout&cast unitest (#41120)

Parent 3b686b18
......@@ -42,7 +42,13 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
if (!context.Attr<bool>("is_test")) {
int seed_data = 0;
if (seed) {
seed_data = *(seed->data<int>());
if (platform::is_xpu_place(seed->place())) {
memory::Copy(platform::CPUPlace(), &seed_data, seed->place(),
seed->data<int>(), sizeof(int));
} else {
seed_data = *(seed->data<int>());
}
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
......
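A note on the C++ hunk above: when the optional seed tensor lives on the XPU, its value is first copied device-to-host with memory::Copy before it is dereferenced; otherwise the kernel reads it directly, or falls back to the fix_seed/seed attributes. A rough Python equivalent of that selection logic (resolve_seed and its arguments are hypothetical and purely illustrative, not part of Paddle's API):

def resolve_seed(seed_value, has_seed_tensor, fix_seed, seed_attr):
    # seed_value: the host-side int read from the optional seed tensor input.
    # On XPU the kernel obtains it via a device-to-host memory::Copy first;
    # on CPU it can dereference the tensor's data pointer directly.
    if has_seed_tensor:
        return seed_value
    return seed_attr if fix_seed else 0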
......@@ -54,13 +54,11 @@ class XPUOpTest(OpTest):
"""Restore random seeds"""
def is_empty_grad_op(op_type):
all_op_kernels = core._get_all_register_op_kernels()
grad_op = op_type + '_grad'
if grad_op in all_op_kernels.keys():
grad_op_kernels = all_op_kernels[grad_op]
for grad_op_kernel in grad_op_kernels:
if 'XPU' in grad_op_kernel:
return False
xpu_version = core.get_xpu_device_version(0)
xpu_op_list = core.get_xpu_device_op_list(xpu_version)
if grad_op in xpu_op_list.keys():
return False
return True
if cls.dtype == np.float16:
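The rewritten is_empty_grad_op no longer scans every registered kernel string; it asks the per-device XPU op list whether the grad op is present. A small standalone sketch of that query, assuming a Paddle build with XPU support and at least one visible XPU device:

import paddle.fluid.core as core

# Look up the op list for the XPU generation of device 0, then check whether
# a grad kernel is registered, mirroring the is_empty_grad_op check above.
xpu_version = core.get_xpu_device_version(0)
xpu_op_list = core.get_xpu_device_op_list(xpu_version)
print('dropout_grad registered for this XPU:', 'dropout_grad' in xpu_op_list)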
......@@ -70,9 +68,20 @@ class XPUOpTest(OpTest):
super().tearDownClass()
def _get_places(self):
places = [fluid.XPUPlace(0)]
places = [paddle.XPUPlace(0)]
return places
def check_output(self,
atol=0.001,
no_check_set=None,
equal_nan=False,
check_dygraph=True,
inplace_atol=None,
check_eager=False):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol, no_check_set, equal_nan,
check_dygraph, inplace_atol, check_eager)
def check_output_with_place(self,
place,
atol=0.001,
......@@ -82,20 +91,37 @@ class XPUOpTest(OpTest):
inplace_atol=None,
check_eager=False):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
# xpu does not support float64
if self.dtype == np.float64:
return
if place == None:
place = paddle.XPUPlace(0)
if self.dtype == np.float16:
if core.is_float16_supported(place) == False:
return
if self.dtype == np.float16:
atol = 0.1
return super().check_output_with_place(
place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol)
def check_grad(self,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True,
numeric_place=None,
check_eager=False):
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta, in_place,
max_relative_error, user_defined_grads,
user_defined_grad_outputs, check_dygraph,
numeric_place, check_eager)
def check_grad_with_place(self,
place,
inputs_to_check,
......@@ -116,9 +142,6 @@ class XPUOpTest(OpTest):
self._check_grad_helper()
return
if place == None:
place = paddle.XPUPlace(0)
if self.dtype == np.float64:
return
......
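With check_output and check_grad now defaulting to paddle.XPUPlace(0) (and skipping float64, relaxing atol for float16), individual tests no longer need to construct the place themselves. A minimal, hypothetical subclass showing the intended usage; the op choice and shapes are placeholders, not part of this commit:

import numpy as np
from op_test_xpu import XPUOpTest

class TestReluXPU(XPUOpTest):  # illustrative only
    def setUp(self):
        self.op_type = "relu"  # any op with a registered XPU kernel
        x = np.random.random((4, 8)).astype("float32")
        self.inputs = {'X': x}
        self.outputs = {'Out': np.maximum(x, 0)}

    def test_check_output(self):
        # No place argument is needed any more; XPUPlace(0) is the default.
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')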
......@@ -23,6 +23,9 @@ import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
typeid_dict = {
'int32': int(core.VarDesc.VarType.INT32),
......@@ -33,10 +36,27 @@ typeid_dict = {
}
def create_test_class(in_typename, out_typename):
class Cls(op_test.OpTest):
class XPUTestCastOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'cast'
self.use_dynamic_create_class = True
def dynamic_create_class(self):
base_class = self.TestCastOp
classes = []
for out_type in {'float16', 'float32', 'int32', 'int64'}:
class_name = 'XPUTestCastOp_outtype_' + out_type
attr_dict = {'out_typename': out_type}
classes.append([class_name, attr_dict])
return base_class, classes
class TestCastOp(XPUOpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
in_typename = self.in_type_str
out_typename = 'float32' if not hasattr(
self, 'out_typename') else self.out_typename
self.inputs = {'X': ipt.astype(in_typename)}
self.outputs = {'Out': ipt.astype(in_typename).astype(out_typename)}
self.attrs = {
......@@ -47,18 +67,12 @@ def create_test_class(in_typename, out_typename):
self.__class__.no_need_check_grad = True
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
cls_name = "cast_{0}_{1}".format(in_typename, out_typename)
Cls.__name__ = cls_name
globals()[cls_name] = Cls
self.check_output()
for in_type in {'float16', 'float32', 'int32', 'int64', 'bool'}:
for out_type in {'float16', 'float32', 'int32', 'int64'}:
create_test_class(in_type, out_type)
support_types = get_xpu_op_support_types('cast')
for stype in support_types:
create_test_class(globals(), XPUTestCastOp, stype)
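The cast test now declares one base class plus a list of [class_name, attr_dict] pairs, and lets create_test_class from xpu.get_test_cover_info materialize a concrete test class per supported input type. The helper's internals are not shown in this diff, but the underlying Python mechanism is ordinary dynamic class creation with type(); a simplified, standalone illustration (not the real helper's code):

def make_classes(scope, base_class, name_attr_pairs, extra_attrs):
    # Build one subclass of base_class per (name, attrs) pair and register it
    # in the given namespace, roughly what the wrapper machinery amounts to.
    for class_name, attr_dict in name_attr_pairs:
        attrs = dict(attr_dict, **extra_attrs)
        scope[class_name] = type(class_name, (base_class,), attrs)

class Base:
    out_typename = 'float32'

make_classes(globals(), Base,
             [['Case_outtype_int32', {'out_typename': 'int32'}]],
             {'in_type_str': 'float32'})
print(Case_outtype_int32.out_typename, Case_outtype_int32.in_type_str)  # int32 float32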
class TestCastOpError(unittest.TestCase):
......
......@@ -25,90 +25,196 @@ from paddle.fluid import Program, program_guard
from op_test_xpu import XPUOpTest
paddle.enable_static()
class TestDropoutOp(XPUOpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('uint8')
}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad_normal(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
class TestDropoutOpInput1d(XPUOpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((2000, )).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((2000)).astype('uint8')
}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad_normal(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
class TestDropoutOp2(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False}
self.outputs = {
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('uint8')
}
class TestDropoutOp3(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('uint8')
}
class TestDropoutOp6(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('uint8')
}
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
class XPUTestDropoutOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'dropout'
self.use_dynamic_create_class = False
class TestDropoutOp(XPUOpTest):
def setUp(self):
self.init_inputs_shape()
self.init_attrs()
self.dtype = self.in_type
self.op_type = 'dropout'
self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)}
self.attrs = {
'dropout_prob': self.dropout_prob,
'fix_seed': self.fix_seed,
'is_test': self.is_test,
'dropout_implementation': self.dropout_implementation
}
out = self.inputs['X'] * (1.0 - self.dropout_prob)
if self.is_test == False:
mask = None
if self.dropout_prob == 0.0:
mask = np.ones(self.shape).astype(self.dtype)
elif self.dropout_prob == 1.0:
mask = np.zeros(self.shape).astype(self.dtype)
self.outputs = {'Out': out, 'Mask': mask}
else:
self.outputs = {'Out': out}
def init_inputs_shape(self):
self.shape = [32, 64]
def init_attrs(self):
self.__class__.no_need_check_grad = False
self.dropout_prob = 0.0
self.fix_seed = True
self.is_test = False
self.dropout_implementation = "upscale_in_train"
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
if hasattr(self.__class__, "no_need_check_grad"
) and self.__class__.no_need_check_grad == True:
return
self.check_grad(['X'], 'Out')
class TestDropoutOpInput1d(TestDropoutOp):
def init_inputs_shape(self):
self.shape = [2000]
class TestDropoutOp2(TestDropoutOp):
def init_inputs_shape(self):
self.shape = [32, 64]
def init_attrs(self):
self.dropout_prob = 1.0
self.fix_seed = True
self.is_test = False
self.dropout_implementation = "upscale_in_train"
class TestDropoutOp3(TestDropoutOp):
def init_inputs_shape(self):
self.shape = [32, 64, 2]
class TestDropoutOp4(TestDropoutOp):
def init_attrs(self):
self.__class__.no_need_check_grad = True
self.dropout_prob = 0.35
self.fix_seed = True
self.is_test = True
self.dropout_implementation = "downgrade_in_infer"
class TestDropoutOp5(TestDropoutOp):
def init_inputs_shape(self):
self.shape = [32, 64, 3]
def init_attrs(self):
self.__class__.no_need_check_grad = True
self.dropout_prob = 0.75
self.fix_seed = True
self.is_test = True
self.dropout_implementation = "downgrade_in_infer"
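A note on the reference outputs in setUp above: they are exact only for the cases these subclasses actually exercise. During training with upscale_in_train, dropout_prob 0 keeps everything (Mask all ones) and dropout_prob 1 drops everything (Mask all zeros); with is_test=True and downgrade_in_infer, Out = X * (1 - dropout_prob) and no Mask is produced, which is also why TestDropoutOp4 and TestDropoutOp5 set no_need_check_grad. A tiny numpy restatement of those three cases (illustrative, independent of Paddle):

import numpy as np

x = np.random.random((32, 64)).astype("float32")

# Training, upscale_in_train, p = 0: nothing is dropped, Out == X.
assert np.array_equal(x * (1.0 - 0.0), x)

# Training, upscale_in_train, p = 1: everything is dropped, Out == 0.
assert np.array_equal(x * (1.0 - 1.0), np.zeros_like(x))

# Inference, downgrade_in_infer, any p: Out = X * (1 - p), no Mask output.
p = 0.35
expected = x * (1.0 - p)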
class TestDropoutOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
# the input of dropout must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]],
fluid.CPUPlace())
fluid.layers.dropout(x1, dropout_prob=0.5)
self.assertRaises(TypeError, test_Variable)
def test_dtype():
# the input dtype of dropout must be float16 or float32 or float64
# float16 only can be set on GPU place
x2 = fluid.layers.data(
name='x2', shape=[3, 4, 5, 6], dtype="int32")
fluid.layers.dropout(x2, dropout_prob=0.5)
self.assertRaises(TypeError, test_dtype)
class TestDropoutCAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
self.places.append(fluid.XPUPlace(0))
def test_dygraph(self):
for place in self.places:
with fluid.dygraph.guard(place):
input_np = np.random.random([40, 40]).astype(self.in_type)
result_np = input_np
input = fluid.dygraph.to_variable(input_np)
m = paddle.nn.Dropout(p=0.)
m.eval()
result = m(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
class TestDropoutBackward(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
self.places.append(fluid.XPUPlace(0))
def cal_grad_upscale_train(self, mask, prob):
return mask.astype(self.in_type) / (1 - prob)
def cal_grad_downscale_in_infer(self, mask):
return mask.astype(self.in_type)
def test_backward_downscale_in_infer(self):
for place in self.places:
with fluid.dygraph.guard(place):
input = paddle.uniform([40, 40], dtype=self.in_type)
input.stop_gradient = False
out, mask = core.ops.dropout(input, 'dropout_prob', 0.5)
out.backward()
self.assertTrue(
np.array_equal(input.gradient(
), self.cal_grad_downscale_in_infer(mask.numpy())))
def test_backward_upscale_train(self):
for place in self.places:
with fluid.dygraph.guard(place):
prob = 0.5
input = paddle.uniform([40, 40], dtype=self.in_type)
input.stop_gradient = False
out, mask = core.ops.dropout(input, 'dropout_prob', prob,
"dropout_implementation",
"upscale_in_train")
out.backward()
self.assertTrue(
np.allclose(input.gradient(
), self.cal_grad_upscale_train(mask.numpy(), prob)))
def test_backward_upscale_train_2(self):
for place in self.places:
with fluid.dygraph.guard(place):
prob = 0.3
input = paddle.uniform([40, 40], dtype=self.in_type)
input.stop_gradient = False
out, mask = core.ops.dropout(input, 'dropout_prob', prob,
"dropout_implementation",
"upscale_in_train")
out.backward()
self.assertTrue(
np.allclose(input.gradient(
), self.cal_grad_upscale_train(mask.numpy(), prob)))
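The two helpers above encode dropout's backward rule for a gradient of ones: with the default downgrade_in_infer implementation the training-time forward is Out = X * Mask, so dX = Mask; with upscale_in_train it is Out = X * Mask / (1 - p), so dX = Mask / (1 - p). A short numpy restatement (illustrative only; the mask here is synthetic):

import numpy as np

mask = (np.random.random((40, 40)) > 0.5).astype("float32")
upstream = np.ones_like(mask)  # out.backward() seeds the gradient with ones
p = 0.5

grad_downscale = upstream * mask             # matches cal_grad_downscale_in_infer
grad_upscale = upstream * mask / (1.0 - p)   # matches cal_grad_upscale_train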
support_types = get_xpu_op_support_types('dropout')
for stype in support_types:
create_test_class(globals(), XPUTestDropoutOp, stype)
if __name__ == '__main__':
unittest.main()