Unverified commit 6838a187, authored by taixiurong and committed by GitHub

add fp16 unittests for kl2 (#36583)

Parent 8c1c72af
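Before the per-file hunks, a minimal sketch of the test pattern this commit introduces (illustrative only: the class name and tolerance below are assumptions, not part of the diff; XPUOpTest, core.is_float16_supported, and the fp16 scale registration come from the changes that follow):

# Illustrative sketch -- mirrors the fp16 XPU test pattern added in this commit.
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test_xpu import XPUOpTest

paddle.enable_static()


class TestScaleFp16Sketch(XPUOpTest):  # hypothetical class name
    def setUp(self):
        self.op_type = "scale"
        self.dtype = np.float16
        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
        self.attrs = {'scale': -2.3, 'use_xpu': True}
        self.outputs = {'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])}

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            # is_float16_supported(XPUPlace) is one of the bindings added below
            if core.is_float16_supported(place):
                self.check_output_with_place(place, atol=1e-3)


if __name__ == "__main__":
    unittest.main()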
@@ -22,12 +22,14 @@ namespace paddle {
 namespace operators {
 template <typename DeviceContext, typename T>
 class ScaleXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
     auto* in_var = ctx.InputVar("X");
     auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
-    auto scale = static_cast<T>(ctx.Attr<float>("scale"));
-    auto bias = static_cast<T>(ctx.Attr<float>("bias"));
+    auto scale = static_cast<float>(ctx.Attr<float>("scale"));
+    auto bias = static_cast<float>(ctx.Attr<float>("bias"));
     auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
     auto* out_var = ctx.OutputVar("Out");
     if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
@@ -46,9 +48,10 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
                           in->dims().to_str().c_str(),
                           out->dims().to_str().c_str()));
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r =
-        xpu::scale(dev_ctx.x_context(), in->data<float>(), out->data<float>(),
-                   in->numel(), bias_after_scale, scale, bias);
+    int r = xpu::scale(dev_ctx.x_context(),
+                       reinterpret_cast<const XPUType*>(in->data<T>()),
+                       reinterpret_cast<XPUType*>(out->data<T>()), in->numel(),
+                       bias_after_scale, scale, bias);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU scale kernel return wrong value[%d %s]",
@@ -60,7 +63,11 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    scale, ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    scale, ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext,
+                        paddle::platform::float16>,
+    ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, int64_t>);
 #endif
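With the float16 and int64 ScaleXPUKernel registrations above, paddle.scale can dispatch to the XPU fp16 kernel. A hedged dygraph sketch (assumes an XPU build and device; not part of the diff):

import numpy as np
import paddle

if paddle.is_compiled_with_xpu():
    paddle.disable_static()
    paddle.set_device("xpu")
    x = paddle.to_tensor(np.random.rand(2, 3).astype("float16"))
    # expected to hit the newly registered float16 ScaleXPUKernel
    y = paddle.scale(x, scale=2.0, bias=3.0)
    print(y.dtype, y.shape)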
@@ -184,6 +184,9 @@ XPUOpMap& get_kl2_ops() {
                    pOpKernelType(vartype::INT8, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace()),
                    pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                            pOpKernelType(vartype::FP16, XPUPlace()),
+                            pOpKernelType(vartype::INT64, XPUPlace())})}
     // AddMore
 };
......
@@ -1709,6 +1709,14 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
   m.def("get_xpu_device_version",
         [](int device_id) { return platform::get_xpu_version(device_id); });
+  m.def("is_float16_supported", [](const platform::XPUPlace &place) -> bool {
+    // XPUs with compute capability >= xpu2 (KL2) support float16 and bfloat16
+    return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
+  });
+  m.def("is_bfloat16_supported", [](const platform::XPUPlace &place) -> bool {
+    // XPUs with compute capability >= xpu2 (KL2) support float16 and bfloat16
+    return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
+  });
 #endif
 py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
......
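A hedged usage sketch for the two bindings added above; they are exposed through paddle.fluid.core and are what the reworked XPUOpTest helpers later in this diff call:

import paddle
import paddle.fluid.core as core

if paddle.is_compiled_with_xpu():
    place = paddle.XPUPlace(0)
    # True on xpu2 (KL2) and newer devices, False on xpu1
    print(core.is_float16_supported(place))
    print(core.is_bfloat16_supported(place))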
@@ -44,86 +44,33 @@ class XPUOpTest(OpTest):
     @classmethod
     def setUpClass(cls):
         '''Fix random seeds to remove randomness from tests'''
-        cls._np_rand_state = np.random.get_state()
-        cls._py_rand_state = random.getstate()
-        cls.call_once = False
-        cls.dtype = np.float32
-        cls.outputs = {}
-        cls.input_shape_is_large = True
-        np.random.seed(123)
-        random.seed(124)
-        cls._use_system_allocator = _set_use_system_allocator(True)
+        cls.use_xpu = True
+        cls.use_mkldnn = False
+        super().setUpClass()

     @classmethod
     def tearDownClass(cls):
         """Restore random seeds"""
-        np.random.set_state(cls._np_rand_state)
-        random.setstate(cls._py_rand_state)
-        _set_use_system_allocator(cls._use_system_allocator)

         def is_empty_grad_op(op_type):
             all_op_kernels = core._get_all_register_op_kernels()
             grad_op = op_type + '_grad'
             if grad_op in all_op_kernels.keys():
-                if is_mkldnn_op_test():
-                    grad_op_kernels = all_op_kernels[grad_op]
-                    for grad_op_kernel in grad_op_kernels:
-                        if 'MKLDNN' in grad_op_kernel:
-                            return False
-                else:
-                    return False
+                grad_op_kernels = all_op_kernels[grad_op]
+                for grad_op_kernel in grad_op_kernels:
+                    if 'XPU' in grad_op_kernel:
+                        return False
             return True

-        def is_xpu_op_test():
-            return True
-
-        def is_mkldnn_op_test():
-            return False
-
-        if not hasattr(cls, "op_type"):
-            raise AssertionError(
-                "This test do not have op_type in class attrs, "
-                "please set self.__class__.op_type=the_real_op_type manually.")
-
-        # case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
-        if not hasattr(cls, "no_need_check_grad") \
-            and not is_empty_grad_op(cls.op_type):
-            if cls.dtype is None or \
-                (cls.dtype == np.float16 \
-                 and cls.op_type not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST \
-                 and not hasattr(cls, "exist_check_grad")):
-                raise AssertionError("This test of %s op needs check_grad." %
-                                     cls.op_type)
-
-            # check for op test with fp64 precision, but not check mkldnn op test for now
-            if cls.dtype in [np.float32, np.float64] \
-                and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
-                and not hasattr(cls, 'exist_fp64_check_grad') \
-                and not is_xpu_op_test() \
-                and not is_mkldnn_op_test() \
-                and not is_rocm_op_test() \
-                and not is_npu_op_test():
-                raise AssertionError(
-                    "This test of %s op needs check_grad with fp64 precision." %
-                    cls.op_type)
-
-            if not cls.input_shape_is_large \
-                and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
-                raise AssertionError(
-                    "Input's shape should be large than or equal to 100 for " +
-                    cls.op_type + " Op.")
-
-    def try_call_once(self, data_type):
-        if not self.call_once:
-            self.call_once = True
-            if data_type is not None and \
-                data_type != np.float32:
-                raise AssertionError("Unsupport data type %s in xpu" %
-                                     data_type)
-            self.dtype = data_type
+        if cls.dtype == np.float16:
+            place = paddle.XPUPlace(0)
+            if core.is_float16_supported(place) == False:
+                return
+        super().tearDownClass()
+
+    def _get_places(self):
+        places = [fluid.XPUPlace(0)]
+        return places

     def check_output_with_place(self,
                                 place,
@@ -133,166 +80,17 @@ class XPUOpTest(OpTest):
                                 check_dygraph=True,
                                 inplace_atol=None):
         self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        #xpu not support float64
+        if self.dtype == np.float64:
+            return
+        if place == None:
+            place = paddle.XPUPlace(0)
+        if self.dtype == np.float16:
+            if core.is_float16_supported(place) == False:
+                return
+        return super().check_output_with_place(
+            place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol)
-        if self.dtype == np.float64 and \
-            self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
-            atol = 0
-        if self.is_bfloat16_op():
-            check_dygraph = False
-            if hasattr(self, 'force_fp32_output') and getattr(
-                    self, 'force_fp32_output'):
-                atol = 1e-2
-            else:
-                atol = 2
if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
raise AssertionError(
"no_check_set of op %s must be set to None." % self.op_type)
if check_dygraph:
dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
if no_check_set is not None and out_name in no_check_set:
continue
def find_imperative_actual(target_name, dygraph_outs, place):
with fluid.dygraph.base.guard(place=place):
for name in dygraph_outs:
if name == target_name:
return dygraph_outs[name][0]
var_list = dygraph_outs[name]
for i, var in enumerate(var_list):
if var.name == target_name:
return dygraph_outs[name][i]
self.assertTrue(False, "Found failed {} {}".format(
dygraph_outs.keys(), target_name))
def find_actual(target_name, fetch_list):
found = [
i for i, var_name in enumerate(fetch_list)
if var_name == target_name
]
self.assertTrue(
len(found) == 1, "Found {} {}".format(
len(found), target_name))
return found[0]
if out_dup:
sub_out = self.outputs[out_name]
if not isinstance(sub_out, list):
raise AssertionError("sub_out type %s is not list",
type(sub_out))
for item in sub_out:
sub_out_name, expect = item[0], item[1]
if check_dygraph:
imperative_actual = find_imperative_actual(
sub_out_name, dygraph_outs, place)
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
expect_t = expect[0] \
if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place))
if check_dygraph:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in dygraph mode")
if isinstance(expect, tuple):
self.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + sub_out_name +
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place) +
" in dygraph mode")
else:
if check_dygraph:
imperative_actual = find_imperative_actual(
out_name, dygraph_outs, place)
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__ + " "
+ str(atol) + " " + str(expect_t - actual_t))
if check_dygraph:
if six.moves.reduce(
lambda x, y: x * y, imperative_actual_t.shape,
1) == 0 and six.moves.reduce(
lambda x, y: x * y, expect_t.shape, 1) == 0:
pass
else:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " +
str(place) + "\nExpect " + str(expect_t) + "\n" +
"But Got" + str(imperative_actual_t) + " in class "
+ self.__class__.__name__)
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in dygraph mode")
# Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
# computational consistency.
# For example, group_norm uses AtomicAdd on CUDAPlace, which do not ensure
# computation order when multiple threads write the same address. So the
# result of group_norm is non-deterministic when datatype is float.
# When inplace_atol is not None, the inplace check uses numpy.allclose
# to check inplace result instead of numpy.array_equal.
if inplace_atol is not None:
warnings.warn(
"inplace_atol should only be set when op doesn't ensure computational consistency, please check it!"
)
# Check inplace for given op, its grad op, its grad_grad op, etc.
# No effect on original OpTest
# Currently not support ParallelExecutor on XPUPlace.
if not paddle.is_compiled_with_xpu():
self.check_inplace_output_with_place(
place, no_check_set=no_check_set, inplace_atol=inplace_atol)
if check_dygraph:
return outs
else:
return outs
     def check_grad_with_place(self,
                               place,
@@ -303,8 +101,25 @@ class XPUOpTest(OpTest):
                               in_place=False,
                               max_relative_error=0.005,
                               user_defined_grads=None,
-                              check_dygraph=True):
-        place = paddle.XPUPlace(0)
+                              user_defined_grad_outputs=None,
+                              check_dygraph=True,
+                              numeric_place=None):
+        if place == None:
+            place = paddle.XPUPlace(0)
+        if self.dtype == np.float64:
+            return
+        if self.dtype == np.float16:
+            if core.is_float16_supported(place) == False:
+                return
+        if self.dtype == np.float16:
+            return super().check_grad_with_place(
+                place, inputs_to_check, output_names, no_grad_set,
+                numeric_grad_delta, in_place, max_relative_error,
+                user_defined_grads, user_defined_grad_outputs, check_dygraph)
         a1 = self.get_grad_with_place(
             place, inputs_to_check, output_names, no_grad_set=no_grad_set)
         a2 = self.get_grad_with_place(
......
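A short sketch of how a test case leans on the slimmed-down XPUOpTest above: float64 cases are skipped, float16 cases run only when the device reports fp16 support, and a None place falls back to paddle.XPUPlace(0). The op name, shapes, and tolerance here are placeholders, not from the diff:

import numpy as np
from op_test_xpu import XPUOpTest


class TestMyOpFp16(XPUOpTest):  # hypothetical test case
    def setUp(self):
        self.op_type = "my_op"  # placeholder op name
        self.dtype = np.float16
        x = np.random.random((4, 4)).astype(self.dtype)
        self.inputs = {'X': x}
        self.outputs = {'Out': x}

    def test_check_output(self):
        # place=None is resolved to paddle.XPUPlace(0) by XPUOpTest; the check
        # is silently skipped if the device does not support float16
        self.check_output_with_place(None, atol=1e-2)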
@@ -28,17 +28,12 @@ paddle.enable_static()
 @unittest.skipIf(not paddle.is_compiled_with_xpu(),
                  "core is not compiled with XPU")
 class TestElementwiseAddOp(XPUOpTest):
-    def init_kernel_type(self):
-        self.use_mkldnn = False
-
     def setUp(self):
         self.op_type = "elementwise_add"
         self.init_dtype()
         self.init_input_output()
-        self.init_kernel_type()
         self.init_axis()
-        self.use_xpu = True
+        self.init_max_relative_error()
         self.inputs = {
             'X': OpTest.np_dtype_to_fluid_dtype(self.x),
             'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
@@ -55,7 +50,9 @@ class TestElementwiseAddOp(XPUOpTest):
         if paddle.is_compiled_with_xpu():
             place = paddle.XPUPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Y'], 'Out', max_relative_error=0.006)
+                place, ['X', 'Y'],
+                'Out',
+                max_relative_error=self.max_relative_error)

     def test_check_grad_ingore_x(self):
         if paddle.is_compiled_with_xpu():
@@ -64,7 +61,7 @@ class TestElementwiseAddOp(XPUOpTest):
                 place, ['Y'],
                 'Out',
                 no_grad_set=set("X"),
-                max_relative_error=0.006)
+                max_relative_error=self.max_relative_error)

     def test_check_grad_ingore_y(self):
         if paddle.is_compiled_with_xpu():
@@ -73,7 +70,7 @@ class TestElementwiseAddOp(XPUOpTest):
                 place, ['X'],
                 'Out',
                 no_grad_set=set("Y"),
-                max_relative_error=0.006)
+                max_relative_error=self.max_relative_error)

     def init_input_output(self):
         self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
@@ -86,6 +83,9 @@ class TestElementwiseAddOp(XPUOpTest):
     def init_axis(self):
         self.axis = -1

+    def init_max_relative_error(self):
+        self.max_relative_error = 0.006
+
 @unittest.skipIf(not paddle.is_compiled_with_xpu(),
                  "core is not compiled with XPU")
@@ -337,5 +337,170 @@ class TestAddOp(unittest.TestCase):
         self.assertEqual((np_z == z_expected).all(), True)
######## fp16 test
class TestElementwiseAddFP16Op(TestElementwiseAddOp):
def init_dtype(self):
self.dtype = np.float16
def init_max_relative_error(self):
self.max_relative_error = 0.01
class TestElementwiseAddOp_scalarFP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_scalar2FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_VectorFP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.random((100, )).astype(self.dtype)
self.y = np.random.random((100, )).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestElementwiseAddOp_broadcast_0FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1)
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_broadcast_1FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 100, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 100, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_broadcast_2FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1, 100)
class TestElementwiseAddOp_broadcast_3FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_broadcast_4FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3, 4).astype(self.dtype)
self.y = np.random.rand(100, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1, 1)
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_broadcast_5FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(10, 3, 12).astype(self.dtype)
self.y = np.random.rand(10, 1, 12).astype(self.dtype)
self.out = self.x + self.y
def init_dtype(self):
self.dtype = np.float16
class TestElementwiseAddOp_broadcast_6FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_7FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype)
self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_dtype(self):
self.dtype = np.float16
class TestElementwiseAddOp_rowwise_add_0FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_rowwise_add_1FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(100, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_channelwise_addFP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_commonuse_add1FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(1, 1, 100).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_commonuse_add2FP16(TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype)
self.y = np.random.rand(10, 1, 12, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_xsize_lessthan_ysize_addFP16(
TestElementwiseAddFP16Op):
def init_input_output(self):
self.x = np.random.rand(10, 12).astype(self.dtype)
self.y = np.random.rand(2, 3, 10, 12).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 2
 if __name__ == '__main__':
     unittest.main()
@@ -127,45 +127,23 @@ class Generator(object):
         self.outputs = {'Out': Out}

     def test_check_output(self):
-        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
-                self.inputs['Y'].shape) and self.inputs['X'].shape[
-                    0] == self.inputs['Y'].shape[0]:
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+        place = paddle.XPUPlace(0)
+        self.check_output_with_place(place, atol=1e-3)

     def test_check_grad_normal(self):
-        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
-                self.inputs['Y'].shape) and self.inputs['X'].shape[
-                    0] == self.inputs['Y'].shape[0]:
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['X', 'Y'], 'Out', max_relative_error=5e-2)
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X', 'Y'], 'Out', max_relative_error=5e-2)

     def test_check_grad_ignore_x(self):
-        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
-                self.inputs['Y'].shape) and self.inputs['X'].shape[
-                    0] == self.inputs['Y'].shape[0]:
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['Y'],
-                'Out',
-                max_relative_error=5e-2,
-                no_grad_set=set("X"))
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['Y'], 'Out', max_relative_error=5e-2, no_grad_set=set("X"))

     def test_check_grad_ignore_y(self):
-        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
-                self.inputs['Y'].shape) and self.inputs['X'].shape[
-                    0] == self.inputs['Y'].shape[0]:
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['X'],
-                'Out',
-                max_relative_error=5e-2,
-                no_grad_set=set('Y'))
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', max_relative_error=5e-2, no_grad_set=set('Y'))

 class TestMatmulOpError(unittest.TestCase):
......
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 import sys
 sys.path.append("..")
+from op_test_xpu import XPUOpTest
 from op_test import OpTest
 import paddle
 import paddle.fluid.core as core
@@ -27,22 +28,27 @@ from paddle.fluid import Program, program_guard
 np.random.seed(10)

-class TestMeanOp(OpTest):
+class TestMeanOp(XPUOpTest):
     def setUp(self):
         self.op_type = "mean"
-        self.dtype = np.float64
         self.init_dtype_type()
         self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
-        self.outputs = {'Out': np.mean(self.inputs["X"])}
+        self.outputs = {'Out': np.mean(self.inputs["X"]).astype(np.float16)}

     def init_dtype_type(self):
-        pass
+        self.dtype = np.float32

     def test_check_output(self):
-        self.check_output()
+        if paddle.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place, atol=2e-3)

     def test_checkout_grad(self):
-        self.check_grad(['X'], 'Out')
+        if paddle.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(place, ['X'], 'Out')

 class TestMeanOpError(unittest.TestCase):
@@ -77,5 +83,23 @@ class TestXPUMeanOp(TestMeanOp):
         self.check_grad_with_place(place, ['X'], 'Out')
class TestXPUMeanOpFp16(TestMeanOp):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_checkout_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=1.e1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
@@ -18,27 +18,27 @@ import unittest
 import numpy as np
 import sys
 sys.path.append("..")
-from op_test import OpTest
+from op_test_xpu import XPUOpTest
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 import paddle
+from paddle.static import Program, program_guard

 paddle.enable_static()

-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestXPUScaleOp(OpTest):
+class TestXPUScaleOp(XPUOpTest):
     def setUp(self):
         self.op_type = "scale"
-        self.dtype = np.float32
+        self.init_type()
         self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
         self.attrs = {'scale': -2.3, 'use_xpu': True}
         self.outputs = {
             'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])
         }

+    def init_type(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         if paddle.is_compiled_with_xpu():
             place = paddle.XPUPlace(0)
@@ -50,5 +50,63 @@ class TestXPUScaleOp(OpTest):
         self.check_grad_with_place(place, ['X'], 'Out')
# class TestXPUScaleOpInt64(TestXPUScaleOp):
# def init_type(self):
# self.dtype = np.int64
class TestScaleFp16Op(TestXPUScaleOp):
def init_type(self):
self.dtype = np.float16
def test_check_output(self):
place = core.XPUPlace(0)
self.check_output_with_place(place, atol=0.002)
def test_check_grad(self):
place = core.XPUPlace(0)
self.check_grad_with_place(place, ["X"], "Out", max_relative_error=0.05)
class TestScaleApiStatic(unittest.TestCase):
def _executed_api(self, x, scale=1.0, bias=0.0):
return paddle.scale(x, scale, bias)
def test_api(self):
paddle.enable_static()
input = np.random.random([2, 25]).astype("float32")
main_prog = Program()
with program_guard(main_prog, Program()):
x = paddle.static.data(name="x", shape=[2, 25], dtype="float32")
out = self._executed_api(x, scale=2.0, bias=3.0)
exe = paddle.static.Executor(place=paddle.CPUPlace())
out = exe.run(main_prog, feed={"x": input}, fetch_list=[out])
self.assertEqual(np.array_equal(out[0], input * 2.0 + 3.0), True)
class TestScaleInplaceApiStatic(TestScaleApiStatic):
def _executed_api(self, x, scale=1.0, bias=0.0):
return x.scale_(scale, bias)
class TestScaleApiDygraph(unittest.TestCase):
def _executed_api(self, x, scale=1.0, bias=0.0):
return paddle.scale(x, scale, bias)
def test_api(self):
paddle.disable_static()
input = np.random.random([2, 25]).astype("float32")
x = paddle.to_tensor(input)
out = self._executed_api(x, scale=2.0, bias=3.0)
self.assertEqual(np.array_equal(out.numpy(), input * 2.0 + 3.0), True)
paddle.enable_static()
class TestScaleInplaceApiDygraph(TestScaleApiDygraph):
def _executed_api(self, x, scale=1.0, bias=0.0):
return x.scale_(scale, bias)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
@@ -13,27 +13,26 @@
 # limitations under the License.

 from __future__ import print_function

-import unittest
-import numpy as np
 import sys
 sys.path.append("..")
-from op_test import OpTest
+import unittest
+import numpy as np
+from op_test_xpu import XPUOpTest
 import paddle
+from paddle import enable_static
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
-import paddle
+from paddle.fluid.tests.unittests.op_test import (
+    OpTest, convert_float_to_uint16, convert_uint16_to_float)

 paddle.enable_static()

-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestXPUSumOp(OpTest):
+class TestSumOp(XPUOpTest):
     def setUp(self):
         self.op_type = "sum"
-        self.use_mkldnn = False
+        self.init_kernel_type()
         self.init_kernel_type()
         x0 = np.random.random((3, 40)).astype(self.dtype)
         x1 = np.random.random((3, 40)).astype(self.dtype)
@@ -41,21 +40,147 @@ class TestXPUSumOp(OpTest):
         self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
         y = x0 + x1 + x2
         self.outputs = {'Out': y}
-        self.attrs = {'use_mkldnn': self.use_mkldnn, 'use_xpu': True}

     def init_kernel_type(self):
         self.dtype = np.float32

     def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+        self.check_output()

     def test_check_grad(self):
-        if paddle.is_compiled_with_xpu():
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place, ['x0'], 'Out')
+        self.check_grad(['x0'], 'Out')
#----------- test fp16 -----------
class TestFP16SumOp(TestSumOp):
def init_kernel_type(self):
self.dtype = np.float16
def test_check_output(self):
place = core.XPUPlace(0)
# if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
# FIXME: Because of the precision fp16, max_relative_error
# should be 0.15 here.
def test_check_grad(self):
place = core.XPUPlace(0)
# if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['x0'], 'Out', max_relative_error=0.15)
def create_test_sum_fp16_class(parent):
class TestSumFp16Case(parent):
def init_kernel_type(self):
self.dtype = np.float16
def test_w_is_selected_rows(self):
place = core.XPUPlace(0)
# if core.is_float16_supported(place):
for inplace in [True, False]:
self.check_with_place(place, inplace)
cls_name = "{0}_{1}".format(parent.__name__, "SumFp16Test")
TestSumFp16Case.__name__ = cls_name
globals()[cls_name] = TestSumFp16Case
class API_Test_Add_n(unittest.TestCase):
def test_api(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input0 = fluid.layers.fill_constant(
shape=[2, 3], dtype='int64', value=5)
input1 = fluid.layers.fill_constant(
shape=[2, 3], dtype='int64', value=3)
expected_result = np.empty((2, 3))
expected_result.fill(8)
sum_value = paddle.add_n([input0, input1])
exe = fluid.Executor(fluid.XPUPlace(0))
result = exe.run(fetch_list=[sum_value])
self.assertEqual((result == expected_result).all(), True)
with fluid.dygraph.guard():
input0 = paddle.ones(shape=[2, 3], dtype='float32')
expected_result = np.empty((2, 3))
expected_result.fill(2)
sum_value = paddle.add_n([input0, input0])
self.assertEqual((sum_value.numpy() == expected_result).all(), True)
class TestRaiseSumError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.sum([11, 22])
self.assertRaises(TypeError, test_type)
def test_dtype():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
data2 = fluid.data(name="input2", shape=[10], dtype="int8")
fluid.layers.sum([data1, data2])
self.assertRaises(TypeError, test_dtype)
def test_dtype1():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
fluid.layers.sum(data1)
self.assertRaises(TypeError, test_dtype1)
class TestRaiseSumsError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.sums([11, 22])
self.assertRaises(TypeError, test_type)
def test_dtype():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
data2 = fluid.data(name="input2", shape=[10], dtype="int8")
fluid.layers.sums([data1, data2])
self.assertRaises(TypeError, test_dtype)
def test_dtype1():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
fluid.layers.sums(data1)
self.assertRaises(TypeError, test_dtype1)
def test_out_type():
data1 = fluid.data(name="input1", shape=[10], dtype="flaot32")
data2 = fluid.data(name="input2", shape=[10], dtype="float32")
fluid.layers.sums([data1, data2], out=[10])
self.assertRaises(TypeError, test_out_type)
def test_out_dtype():
data1 = fluid.data(name="input1", shape=[10], dtype="flaot32")
data2 = fluid.data(name="input2", shape=[10], dtype="float32")
out = fluid.data(name="out", shape=[10], dtype="int8")
fluid.layers.sums([data1, data2], out=out)
self.assertRaises(TypeError, test_out_dtype)
class TestSumOpError(unittest.TestCase):
def test_errors(self):
def test_empty_list_input():
with fluid.dygraph.guard():
fluid.core.ops.sum([])
def test_list_of_none_input():
with fluid.dygraph.guard():
fluid.core.ops.sum([None])
self.assertRaises(Exception, test_empty_list_input)
self.assertRaises(Exception, test_list_of_none_input)
if __name__ == "__main__": if __name__ == "__main__":
enable_static()
unittest.main() unittest.main()