diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cc b/paddle/fluid/operators/dequantize_abs_max_op.cc
index c8bca25b6b0f0e34fbfad5c2192faac24bd22ffa..aee468e05e18263dec8951bfe4e2357176f387ad 100644
--- a/paddle/fluid/operators/dequantize_abs_max_op.cc
+++ b/paddle/fluid/operators/dequantize_abs_max_op.cc
@@ -50,6 +50,7 @@ struct DequantizeFunctor {
 };
 
 template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;
+template struct DequantizeFunctor<platform::CPUDeviceContext, int16_t>;
 
 class DequantizeMaxAbsOp : public framework::OperatorWithKernel {
  public:
@@ -79,7 +80,7 @@ class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(int8 Tensor) The input with int8 type is the "
+             "(Int Tensor) The input with int8/16 type is the "
              "low precision tensor.");
     AddInput("Scale", "(float) The scale in quantization stage.");
     AddOutput("Out",
@@ -108,4 +109,5 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(dequantize_abs_max,
-                       ops::DequantizeMaxAbsKernel<CPU, int8_t>);
+                       ops::DequantizeMaxAbsKernel<CPU, int8_t>,
+                       ops::DequantizeMaxAbsKernel<CPU, int16_t>);
diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cu b/paddle/fluid/operators/dequantize_abs_max_op.cu
index 6554d4545ad312b55deca18e62068348e420e4c9..e96835a1ea51cd22b11609f8238601c1df4e00bf 100644
--- a/paddle/fluid/operators/dequantize_abs_max_op.cu
+++ b/paddle/fluid/operators/dequantize_abs_max_op.cu
@@ -45,6 +45,7 @@ struct DequantizeFunctor {
 };
 
 template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
+template struct DequantizeFunctor<platform::CUDADeviceContext, int16_t>;
 
 }  // namespace operators
 }  // namespace paddle
@@ -52,4 +53,5 @@ template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(dequantize_abs_max,
-                        ops::DequantizeMaxAbsKernel<CUDA, int8_t>);
+                        ops::DequantizeMaxAbsKernel<CUDA, int8_t>,
+                        ops::DequantizeMaxAbsKernel<CUDA, int16_t>);
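For context (not part of the patch): the hunks above register an int16 instantiation of the existing abs-max dequantize kernel alongside int8. Below is a minimal NumPy sketch of the computation the op and its new int16 unit test assume, namely out = scale * x / max_range with max_range = 2^(num_bits - 1) - 1; the function name is illustrative, not a Paddle API.

```python
import numpy as np

def dequantize_abs_max(x, scale, num_bits):
    """Reference abs-max dequantization: map integer codes back to floats.

    x        : integer codes (int8, or int16 after this patch)
    scale    : float max-abs value recorded at quantization time
    num_bits : bit width used for quantization
    """
    max_range = (1 << (num_bits - 1)) - 1  # 127 for int8, 32767 for int16
    return scale * x.astype(np.float32) / max_range

# int16 example mirroring the settings of the new test case below
codes = np.random.randint(-32768, 32767, size=(4, 8)).astype("int16")
print(dequantize_abs_max(codes, scale=0.5, num_bits=16).dtype)  # float32
```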
diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 9a0ce3900acf1c104233aeffb2746c8b4e6f8595..2f3217e628dd0e7314daccc092b64a9eb4a402c1 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -229,6 +229,7 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
 REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
                        ops::LookupTableKernel<double>,
                        ops::LookupTableKernel<int8_t>,
+                       ops::LookupTableKernel<int16_t>,
                        ops::LookupTableKernel<paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
                        ops::LookupTableGradKernel<double>,
diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu
index 6985b9167571733a3116e2485cf81b3a217f536c..3edea025b2a0440f4fe4f64be7959df3332f3b99 100644
--- a/paddle/fluid/operators/lookup_table_op.cu
+++ b/paddle/fluid/operators/lookup_table_op.cu
@@ -227,7 +227,8 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
                         ops::LookupTableCUDAKernel<double>,
                         ops::LookupTableCUDAKernel<plat::float16>,
-                        ops::LookupTableCUDAKernel<int8_t>);
+                        ops::LookupTableCUDAKernel<int8_t>,
+                        ops::LookupTableCUDAKernel<int16_t>);
 REGISTER_OP_CUDA_KERNEL(lookup_table_grad,
                         ops::LookupTableGradCUDAKernel<float>,
                         ops::LookupTableGradCUDAKernel<double>,
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index e385d72d1f43fd024158582afe08e704744f744a..74e26626bd5285c3e53b191b714965844f9f30cf 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -103,6 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
         if (id_index != -1) {
           if (input_data_type == framework::proto::VarType::INT8 ||
+              input_data_type == framework::proto::VarType::INT16 ||
               input_data_type == framework::proto::VarType::BF16) {
             memcpy(output + i * row_width, table + id_index * row_width,
                    row_width * sizeof(T));
@@ -130,6 +131,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
                 id_index));
 
           if (input_data_type == framework::proto::VarType::INT8 ||
+              input_data_type == framework::proto::VarType::INT16 ||
               input_data_type == framework::proto::VarType::BF16) {
             memcpy(output + i * row_width, table + id_index * row_width,
                    row_width * sizeof(T));
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index eab513e24bc8090d30a42cd1149c6bf65d690839..55151c5483a38bdd44ac6260b2534684e0e3a2fd 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -54,6 +54,15 @@ struct CBlas {
   }
 };
 
+template <>
+struct CBlas<int16_t> {
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Blas VCOPY do not supported on CPU, please check your code"));
+  }
+};
+
 template <>
 struct CBlas<platform::bfloat16> {
   template <typename... ARGS>
diff --git a/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py b/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py
index 8a66bdb8d152d014bd441acc544866d4114cddbf..696a60787b754e821985c6bf81242022c914f0c9 100644
--- a/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py
@@ -62,5 +62,12 @@ class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp):
         self.data_type = "int8"
 
 
+class TestDequantizeMaxAbsOpInt16(TestDequantizeMaxAbsOp):
+    def set_args(self):
+        self.num_bits = 16
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "int16"
+
+
 if __name__ == "__main__":
     unittest.main()
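For context (not part of the patch): with the lookup_table_op.h change, int16 tables take the same raw-memcpy row copy as int8/bf16, and the CBlas<int16_t> stub above presumably only exists so the templated blas.VCOPY fallback still compiles for T = int16_t; it is not expected to be called. The new test cases below check plain gather semantics, Out = W[Ids], with rows matching padding_idx zeroed. A minimal NumPy sketch of that reference behaviour follows; the helper name is illustrative, not the operator implementation.

```python
import numpy as np

def lookup_table_ref(table, ids, padding_idx=None):
    """Reference embedding lookup: gather rows of `table` by `ids`.

    table : (vocab, width) array, here an int16 table
    ids   : int64 indices with a trailing dimension of 1
    """
    out = table[ids.flatten()].reshape(ids.shape[:-1] + (table.shape[1],))
    if padding_idx is not None:
        # rows whose id equals padding_idx are expected to be all zeros
        out[np.squeeze(ids == padding_idx, axis=-1)] = 0
    return out

table = np.random.randint(-128, 127, size=(17, 31)).astype("int16")
ids = np.random.randint(0, 17, size=(2, 4, 5, 1)).astype("int64")
print(lookup_table_ref(table, ids, padding_idx=3).shape)  # (2, 4, 5, 31)
```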
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
index be1a44120cd1ac37f39e0628fa8f93659be28176..f3546a7c50d97a7d6e83a59acd9202ae0f1aef93 100644
--- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
@@ -316,6 +316,124 @@ class TestLookupTableWithTensorIdsWIsSelectedRowsInt8(
         assert (row == result_array[idx]).all()
 
+
+@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
+class TestLookupTableOpInt16(OpTest):
+    def setUp(self):
+        self.op_type = "lookup_table"
+        table = np.random.randint(
+            low=-128, high=127, size=(17, 31)).astype("int16")
+        ids = np.random.randint(0, 17, 4).astype("int64")
+        ids_expand = np.expand_dims(ids, axis=1)
+        self.inputs = {'W': table, 'Ids': ids_expand}
+        self.outputs = {'Out': table[ids]}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
+class TestLookupTableOpWithTensorIdsInt16(OpTest):
+    def setUp(self):
+        self.op_type = "lookup_table"
+        table = np.random.randint(
+            low=-128, high=127, size=(17, 31)).astype("int16")
+        ids = np.random.randint(
+            low=0, high=17, size=(2, 4, 5, 1)).astype("int64")
+        self.inputs = {'W': table, 'Ids': ids}
+        self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
+class TestLookupTableOpWithPaddingInt16(TestLookupTableOpInt16):
+    def test_check_output(self):
+        ids = np.squeeze(self.inputs['Ids'])
+        padding_idx = np.random.choice(ids, 1)[0]
+        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
+        self.attrs = {'padding_idx': int(padding_idx)}
+        self.check_output()
+
+
+@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
+class TestLookupTableOpWithTensorIdsAndPaddingInt16(
+        TestLookupTableOpWithTensorIdsInt16):
+    def test_check_output(self):
+        ids = self.inputs['Ids']
+        flatten_idx = ids.flatten()
+        padding_idx = np.random.choice(flatten_idx, 1)[0]
+        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
+        self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
+        self.check_output()
+
+
+class TestLookupTableWIsSelectedRowsInt16(unittest.TestCase):
+    def prepare_ids(self, scope, place):
+        ids_tensor = scope.var('Ids').get_tensor()
+        ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
+        ids_tensor.set(ids_array, place)
+        return ids_array
+
+    def prepare_w(self, scope, place):
+        rows = [0, 1, 2, 3, 4, 5, 6]
+        row_numel = 12
+
+        w_selected_rows = scope.var('W').get_selected_rows()
+        w_selected_rows.set_height(len(rows))
+        w_selected_rows.set_rows(rows)
+        w_array = np.ones((len(rows), row_numel)).astype("int16")
+        for i in range(len(rows)):
+            w_array[i] *= i
+        w_tensor = w_selected_rows.get_tensor()
+        w_tensor.set(w_array, place)
+
+    def create_out_tensor(self, scope, place):
+        return scope.var('Out').get_tensor()
+
+    def check_result(self, ids_array, result_array):
+        for idx, row in enumerate(ids_array):
+            assert (row[0] == result_array[idx]).all()
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+
+        ids_array = self.prepare_ids(scope, place)
+
+        self.prepare_w(scope, place)
+
+        out_tensor = self.create_out_tensor(scope, place)
+
+        # create and run lookup_table operator
+        lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
+        lookup_table.run(scope, place)
+
+        # get result from Out
+        result_array = np.array(out_tensor)
+
+        self.check_result(ids_array, result_array)
+
+    def test_w_is_selected_rows(self):
+        places = [core.CPUPlace()]
+        # currently only support CPU
+        for place in places:
+            self.check_with_place(place)
+
+
+class TestLookupTableWithTensorIdsWIsSelectedRowsInt16(
+        TestLookupTableWIsSelectedRowsInt16):
+    def prepare_ids(self, scope, place):
+        ids_tensor = scope.var('Ids').get_tensor()
+        ids_array = np.random.randint(
+            low=0, high=6, size=(2, 4, 3, 1)).astype("int64")
+        ids_tensor.set(ids_array, place)
+        return ids_array
+
+    def check_result(self, ids_array, result_array):
+        for idx, row in np.ndenumerate(ids_array):
+            assert (row == result_array[idx]).all()
+
+
 class TestOutDtype(unittest.TestCase):
     def test_dtype(self):
         api_fn = F.embedding