diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 947b0c1d30561c91910b33fe044f326efab97101..58614e38e990d37fa17d810564c59274545fc3d5 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -378,7 +378,6 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::FLOAT16,
                     phi::DataType::FLOAT32,
                     phi::DataType::BOOL})},
-
     {"index_select",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
@@ -419,7 +418,8 @@ XPUOpMap& get_kl2_ops() {
     {"logical_or", XPUKernelSet({phi::DataType::BOOL})},
     {"logical_xor", XPUKernelSet({phi::DataType::BOOL})},
     {"lookup_table_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
-    {"lookup_table_v2", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"lookup_table_v2",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"masked_select",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
@@ -577,7 +577,10 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32})},
     {"sigmoid_cross_entropy_with_logits",
      XPUKernelSet({phi::DataType::FLOAT32})},
-    {"shape", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64})},
+    {"shape",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::INT64,
+                   phi::DataType::FLOAT16})},
     {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
     {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"sign", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index 9069db5adbe7e2e36b1fde86d7e21d0141425eda..d102b4e88284811be2e7c3dc50fed45f8b404f4d 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -179,6 +179,7 @@ PD_REGISTER_KERNEL(assign_value,
                    bool,
                    int,
                    float,
+                   double,
                    int64_t,
                    phi::dtype::float16) {}
 #endif
diff --git a/paddle/phi/kernels/shape_kernel.cc b/paddle/phi/kernels/shape_kernel.cc
index 2c2b41e3c66fc7d192d28f06633ac72bf02c35b2..b8aa8718a5598fec74ecfbe4f73d631653177386 100644
--- a/paddle/phi/kernels/shape_kernel.cc
+++ b/paddle/phi/kernels/shape_kernel.cc
@@ -77,7 +77,8 @@ PD_REGISTER_KERNEL(shape,
                    int,
                    int64_t,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
 #endif
diff --git a/paddle/phi/kernels/xpu/embedding_kernel.cc b/paddle/phi/kernels/xpu/embedding_kernel.cc
index ace2116cdc963209efa67a4c905cedfc27621f31..99faf8b58196618fbebad35aac3d08246801645c 100644
--- a/paddle/phi/kernels/xpu/embedding_kernel.cc
+++ b/paddle/phi/kernels/xpu/embedding_kernel.cc
@@ -25,6 +25,8 @@ void EmbeddingKernel(const Context &ctx,
                      const DenseTensor &weight,
                      int64_t padding_idx,
                      DenseTensor *out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
   auto *ids_t = &inputx;    // int
   auto *output_t = out;     // float
   PADDLE_ENFORCE_EQ(
@@ -66,18 +68,23 @@ void EmbeddingKernel(const Context &ctx,
   size_t xm = table_t->dims()[0];
   size_t n = table_t->dims()[1];
 
-  int r = xpu::embedding(dev_ctx.x_context(),
-                         table,
-                         ids,
-                         output,
-                         xm,
-                         n,
-                         ym,
-                         static_cast<int64_t>(padding_idx));
+  int r = xpu::embedding(dev_ctx.x_context(),
+                         reinterpret_cast<const XPUType *>(table),
+                         ids,
+                         reinterpret_cast<XPUType *>(output),
+                         xm,
+                         n,
+                         ym,
+                         padding_idx);
 
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(embedding, XPU, ALL_LAYOUT, phi::EmbeddingKernel, float) {}
+PD_REGISTER_KERNEL(embedding,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::EmbeddingKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py
index fcbf724ccbc4254f73eaafb92c879aa259296407..8cb36afb2e49045e6180e48ffdb66cb3b0fe07e1 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_lookup_table_v2_op_xpu.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,204 +18,94 @@ import unittest
 import numpy as np
 
 sys.path.append("..")
-from op_test import OpTest, skip_check_grad_ci
+
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import (
+    XPUOpTestWrapper,
+    create_test_class,
+    get_xpu_op_support_types,
+)
 
 import paddle
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid import Program, program_guard
-from paddle.fluid.op import Operator
 
 paddle.enable_static()
 
 
-class TestLookupTableOp(OpTest):
-    def setUp(self):
-        self.op_type = "lookup_table_v2"
-        table = np.random.random((17, 31)).astype("float64")
-        ids = np.random.randint(0, 17, 4).astype("int64")
-        self.inputs = {'W': table, 'Ids': ids}
-        self.outputs = {'Out': table[ids]}
-
-    def test_check_output_with_place(self):
-        self.check_output_with_place(place=paddle.XPUPlace(0))
-
-    def test_check_grad(self):
-
-        self.check_grad_with_place(
-            inputs_to_check=['W'],
-            output_names='Out',
-            no_grad_set=set('Ids'),
-            place=paddle.XPUPlace(0),
-            in_place=True,
-        )
-
-
-class TestLookupTableOpWithTensorIds(OpTest):
-    def setUp(self):
-        self.op_type = "lookup_table_v2"
-        table = np.random.random((17, 31)).astype("float64")
-        ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype("int32")
-        self.inputs = {'W': table, 'Ids': ids}
-        self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}
-
-    def test_check_output(self):
-        self.check_output_with_place(place=paddle.XPUPlace(0))
-
-    def test_check_grad(self):
-        self.check_grad_with_place(
-            inputs_to_check=['W'],
-            output_names='Out',
-            no_grad_set=set('Ids'),
-            place=paddle.XPUPlace(0),
-            in_place=True,
-        )
-
-
-@skip_check_grad_ci(
-    reason="Since paddings are not trainable and fixed in forward,"
-    "the gradient of paddings makes no sense and we don't "
-    "test the gradient here."
-)
-class TestLookupTableOpWithPadding(TestLookupTableOp):
-    def test_check_output(self):
-        ids = np.squeeze(self.inputs['Ids'])
-        padding_idx = np.random.choice(ids, 1)[0]
-        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
-        self.attrs = {'padding_idx': int(padding_idx)}
-        self.check_output_with_place(place=paddle.XPUPlace(0))
-
-
-@skip_check_grad_ci(
-    reason="Since paddings are not trainable and fixed in forward,"
-    "the gradient of paddings makes no sense and we don't "
-    "test the gradient here."
-)
-class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
-    def test_check_output(self):
-        ids = self.inputs['Ids']
-        flatten_idx = ids.flatten()
-        padding_idx = np.random.choice(flatten_idx, 1)[0]
-        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
-        self.attrs = {'padding_idx': padding_idx}
-        self.check_output_with_place(place=paddle.XPUPlace(0))
-
-
-class TestLookupTableWIsSelectedRows(unittest.TestCase):
-    def prepare_ids(self, scope, place):
-        ids_tensor = scope.var('Ids').get_tensor()
-        ids_array = np.array([0, 4, 3, 5]).astype("int64")
-        ids_tensor.set(ids_array, place)
-        return ids_array
-
-    def prepare_w(self, scope, place):
-        rows = [0, 1, 2, 3, 4, 5, 6]
-        row_numel = 12
-        w_selected_rows = scope.var('W')
-        w_array = np.ones((len(rows), row_numel)).astype("float32")
-        for i in range(len(rows)):
-            w_array[i] *= i
-        w_tensor = w_selected_rows.get_tensor()
-        w_tensor.set(w_array, place)
-
-    def create_out_tensor(self, scope, place):
-        return scope.var('Out').get_tensor()
-
-    def check_result(self, ids_array, result_array):
-        # all(): return True if all elements of the iterable are true (or if the iterable is empty)
-        for idx, row in enumerate(ids_array):
-            assert (row == result_array[idx]).all()
-
-    def check_with_place(self, place):
-        scope = core.Scope()
-        ids_array = self.prepare_ids(scope, place)
-
-        self.prepare_w(scope, place)
-
-        out_tensor = self.create_out_tensor(scope, place)
-
-        # create and run lookup_table_v2 operator
-        lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
-        lookup_table.run(scope, place)
-
-        # get result from Out
-        result_array = np.array(out_tensor)
-
-        self.check_result(ids_array, result_array)
-
-    def test_w_is_selected_rows(self):
-        places = [paddle.XPUPlace(0)]
-        for place in places:
-            self.check_with_place(place)
-
-
-class TestLookupTableWithTensorIdsWIsSelectedRows(
-    TestLookupTableWIsSelectedRows
-):
-    def prepare_ids(self, scope, place):
-        ids_tensor = scope.var('Ids').get_tensor()
-        ids_array = np.random.randint(low=0, high=6, size=(2, 4, 3)).astype(
-            "int64"
-        )
-        ids_tensor.set(ids_array, place)
-        return ids_array
-
-    def check_result(self, ids_array, result_array):
-        for idx, row in np.ndenumerate(ids_array):
-            assert (row == result_array[idx]).all()
-
-
-class TestLookupTableApi(unittest.TestCase):
-    def test_api(self):
-        x = paddle.static.data(name='x', shape=[-1, 20], dtype='int64')
-        emb = paddle.static.nn.embedding(input=x, size=[128, 64])
-
-        place = paddle.XPUPlace(0)
-        x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
-
-        exe = fluid.Executor(place)
-        exe.run(fluid.default_startup_program())
-        ret = exe.run(
-            feed={
-                'x': x_data,
-            },
-            fetch_list=[emb],
-            return_numpy=False,
-        )
-
-
-class TestEmbedOpError(unittest.TestCase):
-    def test_errors(self):
-        with program_guard(Program(), Program()):
-            input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
-
-            def test_Variable():
-                # the input type must be Variable
-                paddle.static.nn.embedding(input=input_data, size=(10, 64))
-
-            self.assertRaises(TypeError, test_Variable)
-
-            def test_input_dtype():
-                # the input dtype must be int64
-                input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
-                paddle.static.nn.embedding(input=input, size=(10, 64))
-
-            self.assertRaises(TypeError, test_input_dtype)
-
-            def test_param_dtype():
-                # dtype must be float32 or float64
-                input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
-                paddle.static.nn.embedding(
-                    input=input2, size=(10, 64), dtype='int64'
-                )
-
-            self.assertRaises(TypeError, test_param_dtype)
-            input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
-            paddle.static.nn.embedding(
-                input=input3, size=(10, 64), dtype='float16'
-            )
+class XPUTestLookupTableOP(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'lookup_table_v2'
+        self.use_dynamic_create_class = False
+
+    class TestLookupTableOPBase(XPUOpTest):
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.init_dtype()
+            self.op_type = 'lookup_table_v2'
+            self.init_config()
+            self.set_case()
+
+        def set_case(self):
+            table = np.random.random(self.input_shape).astype(self.dtype)
+            ids = np.random.randint(0, self.id_range, self.id_count).astype(
+                self.id_dtype
+            )
+            self.inputs = {'W': table, 'Ids': ids}
+            self.outputs = {'Out': table[ids]}
+
+        def init_dtype(self):
+            self.dtype = self.in_type
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, ['W'], 'Out')
+
+        def init_config(self):
+            self.input_shape = (17, 31)
+            self.id_range = 17
+            self.id_count = 4
+            self.id_dtype = "int32"
+
+    class XPUTestLookupTable1(TestLookupTableOPBase):
+        def init_config(self):
+            self.input_shape = (25, 52)
+            self.id_range = 25
+            self.id_count = 14
+            self.id_dtype = "int64"
+
+    class TestLookupTableOpWithTensorIds(TestLookupTableOPBase):
+        def set_case(self):
+            table = np.random.random((17, 31)).astype(self.dtype)
+            ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype(
+                self.id_dtype
+            )
+            self.inputs = {'W': table, 'Ids': ids}
+            self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}
+
+    class TestLookupTableOpWithPadding(TestLookupTableOPBase):
+        def test_check_output(self):
+            ids = np.squeeze(self.inputs['Ids'])
+            padding_idx = np.random.choice(ids, 1)[0]
+            self.outputs['Out'][ids == padding_idx] = np.zeros(31)
+            self.attrs = {'padding_idx': int(padding_idx)}
+            self.check_output_with_place(self.place)
+
+    class TestLookupTableOpWithTensorIdsAndPadding(
+        TestLookupTableOpWithTensorIds
+    ):
+        def test_check_output(self):
+            ids = self.inputs['Ids']
+            flatten_idx = ids.flatten()
+            padding_idx = np.random.choice(flatten_idx, 1)[0]
+            self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
+            self.attrs = {'padding_idx': padding_idx}
+            self.check_output_with_place(self.place)
+
+
+support_types = get_xpu_op_support_types('lookup_table_v2')
+for stype in support_types:
+    create_test_class(globals(), XPUTestLookupTableOP, stype)
 
 if __name__ == "__main__":
-    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_shape_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_shape_op_xpu.py
index 8da888e1a41273b9804b773e5086a78748aca39c..1c25661c138ade58b94b0d11db4b7baa60271f04 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_shape_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_shape_op_xpu.py
@@ -79,12 +79,7 @@ class XPUTestShapeOp(XPUOpTestWrapper):
             self.dtype = self.in_type
 
         def get_places(self):
-            places = [core.CPUPlace()]
-            if core.is_compiled_with_cuda():
-                places.append(core.CUDAPlace(0))
-            if core.is_compiled_with_xpu():
-                places.append(core.XPUPlace(0))
-            return places
+            return [core.CPUPlace(), core.XPUPlace(0)]
 
         def check_with_place(self, place):
             scope = core.Scope()
@@ -110,7 +105,14 @@ class XPUTestShapeOp(XPUOpTestWrapper):
 
         def test_check_output(self):
             for place in self.get_places():
-                self.check_with_place(place)
+                if (
+                    type(place) is paddle.fluid.libpaddle.CPUPlace
+                    and self.dtype == np.float16
+                ):
+                    # fp16 not available on cpu
+                    pass
+                else:
+                    self.check_with_place(place)
 
 
 support_types = get_xpu_op_support_types("shape")
diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh
index a5867d1c0cdbc6b0a8a2433eb0ea61af46a49def..37f07940e9b0b36b7ea3ee75915c7de6c75356de 100644
--- a/tools/check_file_diff_approvals.sh
+++ b/tools/check_file_diff_approvals.sh
@@ -463,7 +463,7 @@ if [ "${UNITTEST_FILE_CHANGED}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
     if [ "${ERROR_LINES}" != "" ]; then
         ERROR_LINES=${ERROR_LINES//+/'\n+\t'}
         echo_line="It is an Op accuracy problem, please take care of it. You must have one RD (zhangting2020 (Recommend), luotao1 or phlrain, qili93, QingshuChen) approval for the usage (either add or delete) of @skip_check_grad_ci. For more information, please refer to: https://github.com/PaddlePaddle/Paddle/wiki/Gradient-Check-Is-Required-for-Op-Test. The corresponding lines are as follows:\n${ERROR_LINES}\n"
-        check_approval 1 26615455 6836917 43953930 16605440
+        check_approval 1 26615455 6836917 43953930 16605440 2002279
     fi
 fi
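
Usage sketch (reviewer aid, not part of the patch): the snippet below is a minimal illustration of what the new fp16 registrations enable at the Python level. It assumes a Paddle build with XPU support and an attached XPU device; paddle.set_device, paddle.nn.functional.embedding, and paddle.shape are standard Paddle APIs chosen for illustration, not code taken from this diff.

import numpy as np
import paddle

paddle.set_device("xpu")  # requires an XPU device; sketch only

# fp16 weight table: served by the new lookup_table_v2 / embedding fp16 kernels
table = paddle.to_tensor(np.random.random((17, 31)).astype("float16"))
ids = paddle.to_tensor(np.random.randint(0, 17, 4).astype("int64"))

out = paddle.nn.functional.embedding(ids, table)  # fp16 in, fp16 out
print(out.dtype)          # paddle.float16
print(paddle.shape(out))  # shape kernel is now registered for fp16 inputs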