From 9d6c8bdfbe38a2dd2a95f153fd7dfd73442d0278 Mon Sep 17 00:00:00 2001
From: "joanna.wozna.intel"
Date: Wed, 16 Jun 2021 19:58:05 +0200
Subject: [PATCH] Add lookup_table_v2 BF16 op (#33172)

* Add lookup_table_v2 BF16

* Reuse lookup table UT

* Change op_type to op_version

* Remove check_dygraph

* Remove skip_check_grad_ci
---
 paddle/fluid/operators/lookup_table_v2_op.cc  |  10 +-
 paddle/fluid/operators/lookup_table_v2_op.h   |  13 +-
 python/paddle/fluid/input.py                  |   2 +-
 .../unittests/test_lookup_table_bf16_op.py    |  50 ++++---
 .../unittests/test_lookup_table_v2_bf16_op.py | 126 ++++++++++++++++++
 tools/static_mode_white_list.py               |   1 +
 6 files changed, 177 insertions(+), 25 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_lookup_table_v2_bf16_op.py

diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc
index feaa33e28df..f1bb9a985f4 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op.cc
@@ -197,10 +197,12 @@ REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
                   ops::LookupTableV2OpGradVarTypeInference);
 
 REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
-                       ops::LookupTableV2Kernel<double>);
-REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad,
-                       ops::LookupTableV2GradKernel<float>,
-                       ops::LookupTableV2GradKernel<double>);
+                       ops::LookupTableV2Kernel<double>,
+                       ops::LookupTableV2Kernel<paddle::platform::bfloat16>);
+REGISTER_OP_CPU_KERNEL(
+    lookup_table_v2_grad, ops::LookupTableV2GradKernel<float>,
+    ops::LookupTableV2GradKernel<double>,
+    ops::LookupTableV2GradKernel<paddle::platform::bfloat16>);
 
 /* ==========================  register checkpoint ===========================*/
 REGISTER_OP_VERSION(lookup_table_v2)
diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h
index 877baebdb6a..4e8d96afa03 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.h
+++ b/paddle/fluid/operators/lookup_table_v2_op.h
@@ -91,8 +91,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
       int64_t row_width = table_t.value().dims()[1];
       const auto *table = table_t.value().data<T>();
       auto *output = output_t->mutable_data<T>(context.GetPlace());
+      auto input_data_type = table_t.value().type();
 
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
       for (int64_t i = 0; i < ids_numel; ++i) {
         if (padding_idx != kNoPadding && ids[i] == padding_idx) {
           memset(output + i * row_width, 0, row_width * sizeof(T));
@@ -109,8 +109,15 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
               platform::errors::InvalidArgument(
                   "the input key should be exists. But received %d.",
                   id_index));
-          blas.VCOPY(row_width, table + id_index * row_width,
-                     output + i * row_width);
+
+          if (input_data_type == framework::proto::VarType::BF16) {
+            memcpy(output + i * row_width, table + id_index * row_width,
+                   row_width * sizeof(T));
+          } else {
+            auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+            blas.VCOPY(row_width, table + id_index * row_width,
+                       output + i * row_width);
+          }
         }
       }
     }
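A note on the kernel change above: Paddle stores bfloat16 tensors bit-for-bit in uint16 buffers, and blas.VCOPY has no bfloat16 specialization (it is backed by MKL/CBlas routines for float and double), so the BF16 branch copies the selected row with a plain memcpy; padding rows are zero-filled by memset in both branches. A minimal NumPy sketch of the forward semantics (function and variable names here are illustrative, not part of the patch):

    import numpy as np

    def lookup_rows(table, ids, padding_idx=None):
        # table: (vocab, width); for BF16 the dtype is uint16 and the
        # raw bfloat16 bits are copied through unchanged.
        width = table.shape[1]
        out = np.empty((ids.size, width), dtype=table.dtype)
        for i, idx in enumerate(ids.flatten()):
            if padding_idx is not None and idx == padding_idx:
                out[i, :] = 0              # mirrors the memset branch
            else:
                out[i, :] = table[idx, :]  # mirrors memcpy / blas.VCOPY
        return out.reshape(list(ids.shape) + [width])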
diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py
index b13419ae36c..d7a8e3bcb82 100644
--- a/python/paddle/fluid/input.py
+++ b/python/paddle/fluid/input.py
@@ -309,7 +309,7 @@ def embedding(input,
     helper = LayerHelper('embedding', **locals())
     check_variable_and_dtype(input, 'input', ['int64'], 'fluid.embedding')
-    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
+    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64', 'uint16'],
                 'fluid.embedding')
     remote_prefetch = is_sparse and (not is_distributed)
     if remote_prefetch:
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py
index b423123160f..0a247b4dbe0 100644
--- a/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py
@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator
 from paddle import enable_static
 
 
-def _lookup(weights, ids, flat_ids):
+def _lookup(weights, ids, flat_ids, op_version="lookup_table"):
     w_shape = weights.shape
-    out_shape = list(ids.shape[:-1])
+    out_shape = list(ids.shape[:-1]) if op_version == "lookup_table" else list(
+        ids.shape)
     out_shape.append(w_shape[-1])
     out = weights[flat_ids].reshape(out_shape)
     return out
 
 
-def _get_grad(weights, ids, flat_ids):
+def _get_grad(weights, ids, flat_ids, op_version="lookup_table"):
     w_shape = weights.shape
     w_grad = np.zeros((w_shape), dtype=weights.dtype)
-    out_grad_shape = (np.prod(ids.shape[:-1]), w_shape[-1])
+    out_shape = list(ids.shape[:-1]) if op_version == "lookup_table" else list(
+        ids.shape)
+    out_grad_shape = (np.prod(out_shape), w_shape[-1])
     out_grad = weights[flat_ids].reshape(out_grad_shape)
     for i, idx in enumerate(flat_ids):
         w_grad[idx, :] += out_grad[i]
@@ -46,18 +49,24 @@ def _get_grad(weights, ids, flat_ids):
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16Op(OpTest):
-    def setUp(self):
+    def init_test(self):
         self.op_type = "lookup_table"
+        self.ids_shape = (4, 1)
+
+    def setUp(self):
+        self.init_test()
         self.dtype = np.uint16
         table = np.random.random((17, 31)).astype("float32")
-        self.ids = np.random.randint(0, 17, (4, 1)).astype("int64")
+        self.ids = np.random.randint(0, 17, self.ids_shape).astype("int64")
         self.flat_ids = self.ids.flatten()
         self.w_bf16 = convert_float_to_uint16(table)
-        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids)
-        self.out_fp32 = _lookup(table, self.ids, self.flat_ids)
-        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids)
+        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids,
+                                self.op_type)
+        self.out_fp32 = _lookup(table, self.ids, self.flat_ids, self.op_type)
+        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids,
+                                     self.op_type)
 
         self.inputs = {'W': self.w_bf16, 'Ids': self.ids}
         self.outputs = {'Out': self.out_fp32}
@@ -79,17 +88,22 @@ class TestLookupTableBF16Op(OpTest):
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16OpIds4D(TestLookupTableBF16Op):
-    def setUp(self):
-        super(TestLookupTableBF16OpIds4D, self).setUp()
-        self.ids = np.random.randint(0, 17, (2, 4, 5, 1)).astype("int64")
+    def init_test(self):
+        self.op_type = "lookup_table"
+        self.ids_shape = (2, 4, 5, 1)
 
 
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
+    def init_test(self):
+        self.op_type = "lookup_table"
+        self.ids_shape = (10, 1)
+
     def setUp(self):
+        self.init_test()
         self.ids = np.random.randint(
-            low=0, high=15, size=(10, 1)).astype("int64")
+            low=0, high=15, size=self.ids_shape).astype("int64")
         self.flat_ids = self.ids.flatten()
         self.w_fp32 = np.random.random((15, 32)).astype("float32")
         self.w_bf16 = convert_float_to_uint16(self.w_fp32)
@@ -120,12 +134,12 @@ class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
         out_tensor = self.scope.var('Out').get_tensor()
 
         # create and run lookup_table operator
-        lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
+        lookup_table = Operator(self.op_type, W='W', Ids='Ids', Out='Out')
         lookup_table.run(self.scope, self.place)
 
         # get result from Out
         result_array = np.array(out_tensor)
-        ref = _lookup(self.w_fp32, self.ids, self.flat_ids)
+        ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type)
         self._check_output(ref, result_array)
 
 
@@ -133,10 +147,12 @@ class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16OpWIsSelectedRows4DIds(
         TestLookupTableBF16OpWIsSelectedRows):
+    def init_test(self):
+        self.op_type = "lookup_table"
+        self.ids_shape = (3, 4, 5, 1)
+
     def setUp(self):
         super(TestLookupTableBF16OpWIsSelectedRows4DIds, self).setUp()
-        self.ids = np.random.randint(
-            low=0, high=15, size=(3, 4, 5, 1)).astype("int64")
         self.flat_ids = self.ids.flatten()
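The new op_version argument in _lookup and _get_grad above captures the shape contract that separates the two ops: lookup_table expects ids with a trailing dimension of 1 and drops it from the output shape, while lookup_table_v2 accepts ids of any integer shape and simply appends the embedding width. A small self-contained illustration (values are arbitrary):

    import numpy as np

    weights = np.random.random((17, 31)).astype("float32")

    ids_v1 = np.random.randint(0, 17, (2, 4, 5, 1))  # lookup_table: trailing 1
    ids_v2 = np.random.randint(0, 17, (2, 4, 5))     # lookup_table_v2

    # Both reference implementations yield shape (2, 4, 5, 31):
    out_v1 = weights[ids_v1.flatten()].reshape(list(ids_v1.shape[:-1]) + [31])
    out_v2 = weights[ids_v2.flatten()].reshape(list(ids_v2.shape) + [31])
    assert out_v1.shape == out_v2.shape == (2, 4, 5, 31)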
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_bf16_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_bf16_op.py
new file mode 100644
index 00000000000..0776ae852d1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_bf16_op.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+from paddle.fluid.tests.unittests.op_test import (skip_check_grad_ci,
+                                                  convert_uint16_to_float)
+from paddle.fluid.tests.unittests.test_lookup_table_bf16_op import (
+    _lookup, TestLookupTableBF16Op, TestLookupTableBF16OpIds4D,
+    TestLookupTableBF16OpWIsSelectedRows,
+    TestLookupTableBF16OpWIsSelectedRows4DIds)
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+
+class TestLookupTableV2BF16Op(TestLookupTableBF16Op):
+    def init_test(self):
+        self.op_type = "lookup_table_v2"
+        self.ids_shape = (4, )
+        self.mkldnn_data_type = "bfloat16"
+
+
+class TestLookupTableV2BF16OpIds4D(TestLookupTableBF16OpIds4D):
+    def init_test(self):
+        self.op_type = "lookup_table_v2"
+        self.ids_shape = (2, 4, 5)
+        self.mkldnn_data_type = "bfloat16"
+
+
+class TestLookupTableV2BF16OpWIsSelectedRows(
+        TestLookupTableBF16OpWIsSelectedRows):
+    def init_test(self):
+        self.op_type = "lookup_table_v2"
+        self.ids_shape = (10, )
+
+
+class TestLookupTableV2BF16OpWIsSelectedRows4DIds(
+        TestLookupTableBF16OpWIsSelectedRows4DIds):
+    def init_test(self):
+        self.op_type = "lookup_table_v2"
+        self.ids_shape = (3, 4, 5)
+
+
+class TestLookupTableBF16OpWithPadding(TestLookupTableV2BF16Op):
+    def test_check_output(self):
+        ids = np.squeeze(self.inputs['Ids'])
+        padding_idx = np.random.choice(ids, 1)[0]
+        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
+        self.attrs = {'padding_idx': int(padding_idx)}
+        self.check_output_with_place(core.CPUPlace())
+
+
+class TestLookupTableBF16OpIds4DPadding(TestLookupTableV2BF16OpIds4D):
+    def test_check_output(self):
+        ids = self.inputs['Ids']
+        flatten_idx = ids.flatten()
+        padding_idx = np.random.choice(flatten_idx, 1)[0]
+        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
+        self.attrs = {'padding_idx': int(padding_idx)}
+        self.check_output_with_place(core.CPUPlace())
+
+
+class TestEmbeddingLayerBF16ConstantInitializer(unittest.TestCase):
+    """
+    Test embedding layer from input api and results for bfloat16
+    """
+
+    def set_initializer(self):
+        self.initializer = fluid.initializer.Constant(value=self.value)
+
+    def setUp(self):
+        self.op_type = "lookup_table_v2"
+        self.ids_shape = [4]
+        self.w_shape = [10, 64]
+        self.ids = np.random.randint(
+            low=0, high=9, size=self.ids_shape).astype("int64")
+        self.flat_ids = self.ids.flatten()
+        self.value = 3.0
+        self.w_fp32 = np.full(self.w_shape, self.value)
+        self.place = fluid.CPUPlace()
+        self.prog = fluid.Program()
+        self.startup_prog = fluid.Program()
+        self.set_initializer()
+
+        with fluid.program_guard(self.prog, self.startup_prog):
+            x = fluid.layers.data(name='x', shape=self.ids_shape, dtype='int64')
+            self.emb = fluid.input.embedding(
+                input=x,
+                size=self.w_shape,
+                param_attr=fluid.ParamAttr(
+                    name="emb_weight", initializer=self.initializer),
+                is_sparse=False,
+                dtype="uint16")  # bfloat16
+        exe = fluid.Executor(self.place)
+        exe.run(self.startup_prog)
+        self.result = exe.run(self.prog,
+                              feed={'x': self.ids},
+                              fetch_list=['emb_weight', self.emb])
+
+    def test_embedding_weights(self):
+        result = convert_uint16_to_float(self.result[0])
+        self.assertTrue(np.array_equal(self.w_fp32, result))
+
+    def test_lookup_results(self):
+        lookup_result = convert_uint16_to_float(self.result[1])
+        lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids,
+                             self.op_type)
+        self.assertTrue(np.array_equal(lookup_result, lookup_ref))
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
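The TestEmbeddingLayerBF16ConstantInitializer case above doubles as usage documentation: a BF16 embedding is requested by passing dtype="uint16" to fluid.input.embedding (uint16 is the carrier type Paddle exposes for bfloat16 at the Python level), and fetched values are decoded with convert_uint16_to_float. A condensed sketch of that flow, mirroring the test, assuming static mode and a CPU place with BF16 support (weights here are randomly initialized, so only shapes are checked):

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.tests.unittests.op_test import convert_uint16_to_float

    paddle.enable_static()
    prog, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(prog, startup):
        x = fluid.layers.data(name='x', shape=[4], dtype='int64')
        emb = fluid.input.embedding(
            input=x, size=[10, 64], is_sparse=False,
            dtype="uint16")  # uint16 carries the bfloat16 bits

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    ids = np.random.randint(0, 10, [4]).astype("int64")
    out, = exe.run(prog, feed={'x': ids}, fetch_list=[emb])
    print(convert_uint16_to_float(out).shape)  # -> (4, 64)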
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index d1e4680e63f..bc6c2ce0ea2 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -22,6 +22,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_lod_reset_op',
     'test_lookup_table_op',
     'test_lookup_table_bf16_op',
+    'test_lookup_table_v2_bf16_op',
    'test_pad2d_op',
     'test_scatter_op',
     'test_sequence_concat',
-- 
GitLab