Unverified commit a4a2b77d, authored by Adam Osewski, committed by GitHub

[oneDNN] lookup_table op with support for BF16 data type. (#31558)

Parent: c86e771e
@@ -53,7 +53,7 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
   gpd(graph, handler);
 }
-void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
+void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
     ir::Graph* graph, int* bfloat16_operators) const {
   // find orphaned bfloat16 operator that is between two float32 operators
   // revert mkldnn_data_type attr to float32
@@ -74,7 +74,7 @@ void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
 void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
   int bfloat16_operators = 0;
   SetMkldnnDataType(graph, &bfloat16_operators);
-  RemoveOrhanedOperators(graph, &bfloat16_operators);
+  RemoveOrphanedOperators(graph, &bfloat16_operators);
   PrettyLogDetail("--- marked %d operators to bfloat16 ",
                   bfloat16_operators);
 }
......
@@ -28,7 +28,7 @@ class CPUBfloat16PlacementPass : public Pass {
  protected:
   void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;
-  void RemoveOrhanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
+  void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
   void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/fluid/platform/bfloat16.h"
 namespace paddle {
 namespace operators {
@@ -222,9 +223,11 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
 REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
                        ops::LookupTableKernel<double>,
-                       ops::LookupTableKernel<int8_t>);
+                       ops::LookupTableKernel<int8_t>,
+                       ops::LookupTableKernel<paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
-                       ops::LookupTableGradKernel<double>);
+                       ops::LookupTableGradKernel<double>,
+                       ops::LookupTableGradKernel<paddle::platform::bfloat16>);
 /* ========================== register checkpoint ===========================*/
......
@@ -102,7 +102,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
       auto id_index = table_t.GetIndexFromId(ids[i]);
       if (id_index != -1) {
-        if (input_data_type == framework::proto::VarType::INT8) {
+        if (input_data_type == framework::proto::VarType::INT8 ||
+            input_data_type == framework::proto::VarType::BF16) {
           memcpy(output + i * row_width, table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
@@ -128,7 +129,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
                 "the input key should be exists. But received %d.",
                 id_index));
-        if (input_data_type == framework::proto::VarType::INT8) {
+        if (input_data_type == framework::proto::VarType::INT8 ||
+            input_data_type == framework::proto::VarType::BF16) {
           memcpy(output + i * row_width, table + id_index * row_width,
                  row_width * sizeof(T));
         } else {
......
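For INT8, and now BF16, the kernel hunk above copies each selected row with a plain memcpy rather than a BLAS vector copy; the blas.h hunk below only adds a throwing VCOPY stub for bfloat16, so the raw copy is the only supported path for that type. A minimal NumPy sketch of the row gather the forward kernel performs (names and shapes are illustrative only, not Paddle API):

import numpy as np

def lookup_rows(table, flat_ids):
    # Gather one row of `table` per id, mirroring the per-row copy of
    # row_width * sizeof(T) bytes done by the C++ kernel (illustrative only).
    row_width = table.shape[1]
    out = np.empty((len(flat_ids), row_width), dtype=table.dtype)
    for i, idx in enumerate(flat_ids):
        out[i] = table[idx]  # byte-for-byte row copy; the dtype is untouched
    return out

# e.g. a bf16 table carried as uint16 bit patterns, as in the tests further below
table = np.random.randint(0, 2**16, size=(17, 31), dtype=np.uint16)
ids = np.array([3, 0, 16, 3])
print(lookup_rows(table, ids).shape)  # (4, 31)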
@@ -21,6 +21,7 @@
 #include <vector>
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -40,6 +41,16 @@ struct CBlas<int8_t> {
   }
 };
+template <>
+struct CBlas<platform::bfloat16> {
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Blas VCOPY do not supported on CPU with bfloat16,"
+        " please check your code"));
+  }
+};
 #ifdef PADDLE_WITH_MKLML
 template <>
 struct CBlas<float> {
......
@@ -33,10 +33,19 @@ from paddle.fluid.backward import append_backward
 from paddle.fluid.op import Operator
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program, OpProtoHolder, Variable
-from testsuite import create_op, set_input, append_input_output, append_loss_ops
+from paddle.fluid.tests.unittests.testsuite import (
+    create_op,
+    set_input,
+    append_input_output,
+    append_loss_ops, )
 from paddle.fluid import unique_name
-from white_list import op_accuracy_white_list, check_shape_white_list, compile_vs_runtime_white_list, no_check_set_white_list
-from white_list import op_threshold_white_list, no_grad_set_white_list
+from paddle.fluid.tests.unittests.white_list import (
+    op_accuracy_white_list,
+    check_shape_white_list,
+    compile_vs_runtime_white_list,
+    no_check_set_white_list,
+    op_threshold_white_list,
+    no_grad_set_white_list, )
 def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
@@ -1452,6 +1461,7 @@ class OpTest(unittest.TestCase):
             analytic_grads = self._get_gradient(inputs_to_check, place,
                                                 output_names, no_grad_set,
                                                 user_defined_grad_outputs)
             # comparison of bf16 results will happen as fp32
             # loop over list of grads and convert bf16 to fp32
             fp32_grads = []
......
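The comments in the hunk above indicate that bf16 gradients are compared as fp32. A rough sketch of that conversion step, assuming a convert_uint16_to_float helper like the one the new test file imports (not the exact OpTest implementation):

import numpy as np

def grads_as_fp32(grads, convert_uint16_to_float):
    # bf16 tensors are carried as uint16 arrays of bit patterns; promote them
    # to float32 so they can be compared against fp32 reference gradients.
    fp32_grads = []
    for g in grads:
        g = np.asarray(g)
        if g.dtype == np.uint16:
            g = convert_uint16_to_float(g)
        fp32_grads.append(g)
    return fp32_grads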
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import (
    OpTest, convert_float_to_uint16, convert_uint16_to_float,
    skip_check_grad_ci)
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle import enable_static


def _lookup(weights, ids, flat_ids):
    w_shape = weights.shape
    out_shape = list(ids.shape[:-1])
    out_shape.append(w_shape[-1])
    out = weights[flat_ids].reshape(out_shape)
    return out


def _get_grad(weights, ids, flat_ids):
    w_shape = weights.shape
    w_grad = np.zeros((w_shape), dtype=weights.dtype)
    out_grad_shape = (np.prod(ids.shape[:-1]), w_shape[-1])
    out_grad = weights[flat_ids].reshape(out_grad_shape)
    for i, idx in enumerate(flat_ids):
        w_grad[idx, :] += out_grad[i]
    return w_grad


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16Op(OpTest):
    def setUp(self):
        self.op_type = "lookup_table"
        self.dtype = np.uint16

        table = np.random.random((17, 31)).astype("float32")
        self.ids = np.random.randint(0, 17, (4, 1)).astype("int64")
        self.flat_ids = self.ids.flatten()

        self.w_bf16 = convert_float_to_uint16(table)
        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids)
        self.out_fp32 = _lookup(table, self.ids, self.flat_ids)
        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids)

        self.inputs = {'W': self.w_bf16, 'Ids': self.ids}
        self.outputs = {'Out': self.out_fp32}

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CPUPlace(), ['W'],
            'Out',
            no_grad_set=set('Ids'),
            check_dygraph=False,
            max_relative_error=1.5e-2,
            user_defined_grads=[self.w_grad_fp32],
            user_defined_grad_outputs=[self.out_bf16])


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpIds4D(TestLookupTableBF16Op):
    def setUp(self):
        super(TestLookupTableBF16OpIds4D, self).setUp()
        self.ids = np.random.randint(0, 17, (2, 4, 5, 1)).astype("int64")


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
    def setUp(self):
        self.ids = np.random.randint(
            low=0, high=15, size=(10, 1)).astype("int64")
        self.flat_ids = self.ids.flatten()
        self.w_fp32 = np.random.random((15, 32)).astype("float32")
        self.w_bf16 = convert_float_to_uint16(self.w_fp32)
        self.scope = core.Scope()
        self.place = core.CPUPlace()

    def prepare_w(self):
        rows = [a for a in range(self.w_bf16.shape[0])]
        row_numel = self.w_bf16.shape[1]

        w_selected_rows = self.scope.var('W').get_selected_rows()
        w_selected_rows.set_height(len(rows))
        w_selected_rows.set_rows(rows)
        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(self.w_bf16, self.place)

    def prepare_ids(self):
        ids_tensor = self.scope.var('Ids').get_tensor()
        ids_tensor.set(self.ids, self.place)

    def _check_output(self, reference, result_array):
        result_array_fp32 = convert_uint16_to_float(result_array)
        np.testing.assert_allclose(result_array_fp32, reference, rtol=1.5e-2)

    def test_check_output(self):
        self.prepare_ids()
        self.prepare_w()
        out_tensor = self.scope.var('Out').get_tensor()

        # create and run lookup_table operator
        lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
        lookup_table.run(self.scope, self.place)

        # get result from Out
        result_array = np.array(out_tensor)
        ref = _lookup(self.w_fp32, self.ids, self.flat_ids)
        self._check_output(ref, result_array)


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpWIsSelectedRows4DIds(
        TestLookupTableBF16OpWIsSelectedRows):
    def setUp(self):
        super(TestLookupTableBF16OpWIsSelectedRows4DIds, self).setUp()
        self.ids = np.random.randint(
            low=0, high=15, size=(3, 4, 5, 1)).astype("int64")
        self.flat_ids = self.ids.flatten()


@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward,"
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpWithPadding(TestLookupTableBF16Op):
    def test_check_output(self):
        ids = np.squeeze(self.inputs['Ids'])
        padding_idx = np.random.choice(ids, 1)[0]
        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)


@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward,"
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpIds4DPadding(TestLookupTableBF16OpIds4D):
    def test_check_output(self):
        ids = self.inputs['Ids']
        flatten_idx = ids.flatten()
        padding_idx = np.random.choice(flatten_idx, 1)[0]
        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)


if __name__ == "__main__":
    enable_static()
    unittest.main()
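The tests above keep bf16 data in uint16 ndarrays and move between bf16 and fp32 with convert_float_to_uint16 / convert_uint16_to_float. A self-contained sketch of the underlying bit manipulation, simplified to pure truncation (the actual helpers may round to nearest):

import numpy as np

def float_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 value; bfloat16 shares the
    # float32 exponent and truncates the mantissa to 7 bits.
    u32 = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (u32 >> 16).astype(np.uint16)

def bf16_bits_to_float(x):
    # Place the stored 16 bits back into the high half of a float32 word.
    u32 = (np.asarray(x, dtype=np.uint16).astype(np.uint32) << 16).astype(np.uint32)
    return u32.view(np.float32)

x = np.array([1.0, 3.14159, -2.5e-3], dtype=np.float32)
print(bf16_bits_to_float(float_to_bf16_bits(x)))  # close to x, ~3 decimal digits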
@@ -21,6 +21,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_linear_chain_crf_op',
     'test_lod_reset_op',
     'test_lookup_table_op',
+    'test_lookup_table_bf16_op',
     'test_pad2d_op',
     'test_scatter_op',
     'test_sequence_concat',
......
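With the entry added above, the new test module runs as part of the static-mode suite. It can also be invoked directly, assuming a Paddle build with oneDNN/BF16 support (module path inferred from the import style shown earlier):

import unittest

# Load and run the new BF16 lookup_table tests by module name.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "paddle.fluid.tests.unittests.test_lookup_table_bf16_op")
unittest.TextTestRunner(verbosity=2).run(suite)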