From a4a2b77defe3ef1697794ca60911be45078798da Mon Sep 17 00:00:00 2001
From: Adam Osewski
Date: Fri, 19 Mar 2021 03:54:24 +0100
Subject: [PATCH] [oneDNN] lookup_table op with support for BF16 data type.
 (#31558)

---
 .../ir/mkldnn/cpu_bfloat16_placement_pass.cc  |   4 +-
 .../ir/mkldnn/cpu_bfloat16_placement_pass.h   |   2 +-
 paddle/fluid/operators/lookup_table_op.cc     |   7 +-
 paddle/fluid/operators/lookup_table_op.h      |   6 +-
 paddle/fluid/operators/math/blas_impl.h       |  11 ++
 .../paddle/fluid/tests/unittests/op_test.py   |  16 +-
 .../unittests/test_lookup_table_bf16_op.py    | 176 ++++++++++++++++++
 tools/static_mode_white_list.py               |   1 +
 8 files changed, 213 insertions(+), 10 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py

diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
index 3d7a9c1107b..531a04e1a0d 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
@@ -53,7 +53,7 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
   gpd(graph, handler);
 }
 
-void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
+void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
     ir::Graph* graph, int* bfloat16_operators) const {
   // find orphaned bfloat16 operator that is between two float32 operators
   // revert mkldnn_data_type attr to float32
@@ -74,7 +74,7 @@ void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
 void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
   int bfloat16_operators = 0;
   SetMkldnnDataType(graph, &bfloat16_operators);
-  RemoveOrhanedOperators(graph, &bfloat16_operators);
+  RemoveOrphanedOperators(graph, &bfloat16_operators);
   PrettyLogDetail("--- marked %d operators to bfloat16 ",
                   bfloat16_operators);
 }
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
index 1911b1a3cb3..53b97f0e972 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
@@ -28,7 +28,7 @@ class CPUBfloat16PlacementPass : public Pass {
  protected:
   void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;
 
-  void RemoveOrhanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
+  void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
 
   void ApplyImpl(ir::Graph* graph) const override;
 };
diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 1b482235da5..2e8b551ea4e 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
 */
 #include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/fluid/platform/bfloat16.h"
 
 namespace paddle {
 namespace operators {
@@ -222,9 +223,11 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
 REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
                        ops::LookupTableKernel<double>,
-                       ops::LookupTableKernel<int8_t>);
+                       ops::LookupTableKernel<int8_t>,
+                       ops::LookupTableKernel<paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
-                       ops::LookupTableGradKernel<double>);
+                       ops::LookupTableGradKernel<double>,
+                       ops::LookupTableGradKernel<paddle::platform::bfloat16>);
 
 /* ========================== register checkpoint ===========================*/
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 8baa3bccceb..e385d72d1f4 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -102,7 +102,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
         auto id_index = table_t.GetIndexFromId(ids[i]);
 
         if (id_index != -1) {
-          if (input_data_type == framework::proto::VarType::INT8) {
+          if (input_data_type == framework::proto::VarType::INT8 ||
+              input_data_type == framework::proto::VarType::BF16) {
             memcpy(output + i * row_width, table + id_index * row_width,
                    row_width * sizeof(T));
           } else {
@@ -128,7 +129,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
                 "the input key should be exists. But received %d.",
                 id_index));
 
-        if (input_data_type == framework::proto::VarType::INT8) {
+        if (input_data_type == framework::proto::VarType::INT8 ||
+            input_data_type == framework::proto::VarType::BF16) {
           memcpy(output + i * row_width, table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 4847c1f05b0..64b533de098 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -21,6 +21,7 @@
 #include
 
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
 
@@ -40,6 +41,16 @@ struct CBlas {
   }
 };
 
+template <>
+struct CBlas<platform::bfloat16> {
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Blas VCOPY is not supported on CPU with bfloat16,"
+        " please check your code"));
+  }
+};
+
 #ifdef PADDLE_WITH_MKLML
 template <>
 struct CBlas<float> {
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 8ca83d08d64..939e2ac0f59 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -33,10 +33,19 @@ from paddle.fluid.backward import append_backward
 from paddle.fluid.op import Operator
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program, OpProtoHolder, Variable
-from testsuite import create_op, set_input, append_input_output, append_loss_ops
+from paddle.fluid.tests.unittests.testsuite import (
+    create_op,
+    set_input,
+    append_input_output,
+    append_loss_ops, )
 from paddle.fluid import unique_name
-from white_list import op_accuracy_white_list, check_shape_white_list, compile_vs_runtime_white_list, no_check_set_white_list
-from white_list import op_threshold_white_list, no_grad_set_white_list
+from paddle.fluid.tests.unittests.white_list import (
+    op_accuracy_white_list,
+    check_shape_white_list,
+    compile_vs_runtime_white_list,
+    no_check_set_white_list,
+    op_threshold_white_list,
+    no_grad_set_white_list, )
 
 
 def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
@@ -1452,6 +1461,7 @@ class OpTest(unittest.TestCase):
         analytic_grads = self._get_gradient(inputs_to_check, place,
                                             output_names, no_grad_set,
                                             user_defined_grad_outputs)
+        # comparison of bf16 results will happen as fp32
         # loop over list of grads and convert bf16 to fp32
         fp32_grads = []
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py
new file mode 100644
index 00000000000..13c4aa6d767
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import (
+    OpTest, convert_float_to_uint16, convert_uint16_to_float,
+    skip_check_grad_ci)
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+from paddle import enable_static
+
+
+def _lookup(weights, ids, flat_ids):
+    w_shape = weights.shape
+    out_shape = list(ids.shape[:-1])
+    out_shape.append(w_shape[-1])
+    out = weights[flat_ids].reshape(out_shape)
+    return out
+
+
+def _get_grad(weights, ids, flat_ids):
+    w_shape = weights.shape
+    w_grad = np.zeros((w_shape), dtype=weights.dtype)
+    out_grad_shape = (np.prod(ids.shape[:-1]), w_shape[-1])
+    out_grad = weights[flat_ids].reshape(out_grad_shape)
+    for i, idx in enumerate(flat_ids):
+        w_grad[idx, :] += out_grad[i]
+    return w_grad
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestLookupTableBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "lookup_table"
+        self.dtype = np.uint16
+
+        table = np.random.random((17, 31)).astype("float32")
+        self.ids = np.random.randint(0, 17, (4, 1)).astype("int64")
+        self.flat_ids = self.ids.flatten()
+
+        self.w_bf16 = convert_float_to_uint16(table)
+        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids)
+        self.out_fp32 = _lookup(table, self.ids, self.flat_ids)
+        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids)
+
+        self.inputs = {'W': self.w_bf16, 'Ids': self.ids}
+        self.outputs = {'Out': self.out_fp32}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            core.CPUPlace(), ['W'],
+            'Out',
+            no_grad_set=set('Ids'),
+            check_dygraph=False,
+            max_relative_error=1.5e-2,
+            user_defined_grads=[self.w_grad_fp32],
+            user_defined_grad_outputs=[self.out_bf16])
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestLookupTableBF16OpIds4D(TestLookupTableBF16Op):
+    def setUp(self):
+        super(TestLookupTableBF16OpIds4D, self).setUp()
+        self.ids = np.random.randint(0, 17, (2, 4, 5, 1)).astype("int64")
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
+    def setUp(self):
+        self.ids = np.random.randint(
+            low=0, high=15, size=(10, 1)).astype("int64")
+        self.flat_ids = self.ids.flatten()
+        self.w_fp32 = np.random.random((15, 32)).astype("float32")
+        self.w_bf16 = convert_float_to_uint16(self.w_fp32)
+        self.scope = core.Scope()
+        self.place = core.CPUPlace()
+
+    def prepare_w(self):
+        rows = [a for a in range(self.w_bf16.shape[0])]
+        row_numel = self.w_bf16.shape[1]
+
+        w_selected_rows = self.scope.var('W').get_selected_rows()
+        w_selected_rows.set_height(len(rows))
+        w_selected_rows.set_rows(rows)
+        w_tensor = w_selected_rows.get_tensor()
+        w_tensor.set(self.w_bf16, self.place)
+
+    def prepare_ids(self):
+        ids_tensor = self.scope.var('Ids').get_tensor()
+        ids_tensor.set(self.ids, self.place)
+
+    def _check_output(self, reference, result_array):
+        result_array_fp32 = convert_uint16_to_float(result_array)
+        np.testing.assert_allclose(result_array_fp32, reference, rtol=1.5e-2)
+
+    def test_check_output(self):
+        self.prepare_ids()
+        self.prepare_w()
+        out_tensor = self.scope.var('Out').get_tensor()
+
+        # create and run lookup_table operator
+        lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
+        lookup_table.run(self.scope, self.place)
+
+        # get result from Out
+        result_array = np.array(out_tensor)
+        ref = _lookup(self.w_fp32, self.ids, self.flat_ids)
+        self._check_output(ref, result_array)
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestLookupTableBF16OpWIsSelectedRows4DIds(
+        TestLookupTableBF16OpWIsSelectedRows):
+    def setUp(self):
+        super(TestLookupTableBF16OpWIsSelectedRows4DIds, self).setUp()
+        self.ids = np.random.randint(
+            low=0, high=15, size=(3, 4, 5, 1)).astype("int64")
+        self.flat_ids = self.ids.flatten()
+
+
+@skip_check_grad_ci(
+    reason="Since paddings are not trainable and fixed in forward, "
+    "the gradient of paddings makes no sense and we don't "
+    "test the gradient here.")
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestLookupTableBF16OpWithPadding(TestLookupTableBF16Op):
+    def test_check_output(self):
+        ids = np.squeeze(self.inputs['Ids'])
+        padding_idx = np.random.choice(ids, 1)[0]
+        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
+        self.attrs = {'padding_idx': int(padding_idx)}
+        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)
+
+
+@skip_check_grad_ci(
+    reason="Since paddings are not trainable and fixed in forward, "
+    "the gradient of paddings makes no sense and we don't "
+    "test the gradient here.")
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestLookupTableBF16OpIds4DPadding(TestLookupTableBF16OpIds4D):
+    def test_check_output(self):
+        ids = self.inputs['Ids']
+        flatten_idx = ids.flatten()
+        padding_idx = np.random.choice(flatten_idx, 1)[0]
+        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
+        self.attrs = {'padding_idx': int(padding_idx)}
+        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)
+
+
+if __name__ == "__main__":
+    enable_static()
+    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index dc537cb2684..2ea3f7654af 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -21,6 +21,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_linear_chain_crf_op',
     'test_lod_reset_op',
     'test_lookup_table_op',
+    'test_lookup_table_bf16_op',
    'test_pad2d_op',
     'test_scatter_op',
     'test_sequence_concat',
-- 
GitLab
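
A note on how the bfloat16 values in these tests are represented (not part of the patch): op_test.py stores BF16 tensors as np.uint16, where each element holds the upper 16 bits of the corresponding float32 bit pattern. The sketch below is a minimal illustration with made-up helper names and a simple truncating conversion (the real convert_float_to_uint16 helper may round instead of truncate); it shows why the BF16 lookup output can be compared against the fp32 reference with a relaxed tolerance such as the 1.5e-2 used above.

    import numpy as np


    def float_to_bf16_bits(x):
        # Reinterpret float32 as uint32 and keep the high 16 bits (simple
        # truncation); the result matches the np.uint16 storage the tests use.
        x = np.asarray(x, dtype=np.float32)
        return np.right_shift(x.view(np.uint32), 16).astype(np.uint16)


    def bf16_bits_to_float(x):
        # Move the stored 16 bits back into the high half of a float32 pattern.
        x = np.asarray(x, dtype=np.uint16)
        return np.left_shift(x.astype(np.uint32), 16).view(np.float32)


    table = np.random.random((17, 31)).astype(np.float32)
    ids = np.array([3, 5, 3, 16])

    # The BF16 lookup_table kernel copies rows of the uint16 table verbatim
    # (the memcpy path for INT8/BF16 above), so the only error versus the fp32
    # reference is the precision lost by the bf16 cast itself.
    out_bf16 = float_to_bf16_bits(table)[ids]
    out_fp32 = bf16_bits_to_float(out_bf16)

    # bfloat16 keeps roughly 8 bits of mantissa, so a relative tolerance of
    # 1.5e-2 is comfortably loose for this comparison.
    np.testing.assert_allclose(out_fp32, table[ids], rtol=1.5e-2)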