Unverified commit 9d6c8bdf, authored by joanna.wozna.intel, committed by GitHub

Add lookup_table_v2 BF16 op (#33172)

* Add lookup_table_v2 BF16

* Reuse lookup table UT

* Change op_type to op_version

* Remove check_dygraph

* Remove skip_check_grad_ci
Parent f9ce1b1a
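At its core, lookup_table_v2 is a row gather: each id selects one row of the embedding table W, and the output shape is the ids' shape with the row width appended. A minimal NumPy sketch of that computation (the helper name here is illustrative, not Paddle API):

# Illustrative sketch of what lookup_table_v2 computes; not part of the commit.
import numpy as np

def lookup_table_v2_ref(table, ids):
    # table: [vocab_size, row_width]; ids: any integral shape.
    # Unlike lookup_table (v1), ids need no trailing dimension of 1.
    return table[ids.flatten()].reshape(list(ids.shape) + [table.shape[-1]])

table = np.random.random((17, 31)).astype("float32")
ids = np.random.randint(0, 17, (2, 4)).astype("int64")
out = lookup_table_v2_ref(table, ids)
assert out.shape == (2, 4, 31)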
@@ -197,10 +197,12 @@ REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
                   ops::LookupTableV2OpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
-                       ops::LookupTableV2Kernel<double>);
-REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad,
-                       ops::LookupTableV2GradKernel<float>,
-                       ops::LookupTableV2GradKernel<double>);
+                       ops::LookupTableV2Kernel<double>,
+                       ops::LookupTableV2Kernel<paddle::platform::bfloat16>);
+REGISTER_OP_CPU_KERNEL(
+    lookup_table_v2_grad, ops::LookupTableV2GradKernel<float>,
+    ops::LookupTableV2GradKernel<double>,
+    ops::LookupTableV2GradKernel<paddle::platform::bfloat16>);
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(lookup_table_v2)
......
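The registration above adds paddle::platform::bfloat16 CPU kernels alongside float and double. On the Python side, bfloat16 tensors travel as np.uint16: a bfloat16 value is essentially the upper 16 bits of the corresponding float32 (the tests' convert_float_to_uint16 helper plays this role; its exact rounding behavior is not reproduced here). A NumPy sketch of the truncating conversion:

# Sketch of the float32 -> bfloat16(uint16) truncation; illustrative only.
import numpy as np

def float_to_bf16_bits(x):
    # Reinterpret float32 as uint32 and keep the high 16 bits
    # (sign, 8 exponent bits, 7 mantissa bits).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(b):
    # Widen back to float32 by zero-filling the low mantissa bits.
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.array([3.0, -1.5, 0.1], dtype=np.float32)
y = bf16_bits_to_float(float_to_bf16_bits(x))
# Exactly representable values round-trip; 0.1 loses low mantissa bits.
assert np.array_equal(y[:2], x[:2])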
@@ -91,8 +91,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
       int64_t row_width = table_t.value().dims()[1];
       const auto *table = table_t.value().data<T>();
       auto *output = output_t->mutable_data<T>(context.GetPlace());
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      auto input_data_type = table_t.value().type();
       for (int64_t i = 0; i < ids_numel; ++i) {
         if (padding_idx != kNoPadding && ids[i] == padding_idx) {
           memset(output + i * row_width, 0, row_width * sizeof(T));
@@ -109,8 +109,15 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
                             platform::errors::InvalidArgument(
                                 "the input key should be exists. But received %d.",
                                 id_index));
-          blas.VCOPY(row_width, table + id_index * row_width,
-                     output + i * row_width);
+          if (input_data_type == framework::proto::VarType::BF16) {
+            memcpy(output + i * row_width, table + id_index * row_width,
+                   row_width * sizeof(T));
+          } else {
+            auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+            blas.VCOPY(row_width, table + id_index * row_width,
+                       output + i * row_width);
+          }
         }
       }
     }
......
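The BF16 branch uses memcpy, presumably because the BLAS VCOPY routine is only provided for genuine float/double element types, while a bfloat16 row can simply be moved bit-for-bit. The padding_idx handling is unchanged: padded rows are zero-filled, all other rows are copied from the table. A NumPy model of the loop (illustrative, not the kernel itself):

# NumPy model of the kernel loop above; illustrative only.
import numpy as np

def lookup_rows(table, ids, padding_idx=None):
    # Works for any element type, including bf16 stored as uint16,
    # because rows are copied verbatim rather than computed on.
    out = np.empty((len(ids), table.shape[1]), dtype=table.dtype)
    for i, idx in enumerate(ids):
        if padding_idx is not None and idx == padding_idx:
            out[i, :] = 0  # memset path
        else:
            out[i, :] = table[idx]  # memcpy / VCOPY path
    return out

bf16_table = np.random.randint(0, 2**16, (17, 31)).astype(np.uint16)
ids = np.array([3, 0, 3, 16])
out = lookup_rows(bf16_table, ids, padding_idx=0)
assert np.array_equal(out[1], np.zeros(31, dtype=np.uint16))
assert np.array_equal(out[0], bf16_table[3])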
@@ -309,7 +309,7 @@ def embedding(input,
     helper = LayerHelper('embedding', **locals())
     check_variable_and_dtype(input, 'input', ['int64'], 'fluid.embedding')
-    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
+    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64', 'uint16'],
                 'fluid.embedding')
     remote_prefetch = is_sparse and (not is_distributed)
     if remote_prefetch:
......
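With 'uint16' accepted by check_dtype, fluid.embedding can now build a bfloat16 embedding layer in static mode. A minimal usage sketch under the old fluid API, mirroring the TestEmbeddingLayerBF16ConstantInitializer test added below (assumes a CPU place with BF16 support):

# Usage sketch for a bf16 embedding under the fluid static-graph API.
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
prog, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(prog, startup):
    x = fluid.layers.data(name='x', shape=[4], dtype='int64')
    emb = fluid.input.embedding(input=x, size=[10, 64], dtype='uint16')  # bf16

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
out, = exe.run(prog,
               feed={'x': np.random.randint(0, 9, (4, )).astype('int64')},
               fetch_list=[emb])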
@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator
 from paddle import enable_static


-def _lookup(weights, ids, flat_ids):
+def _lookup(weights, ids, flat_ids, op_version="lookup_table"):
     w_shape = weights.shape
-    out_shape = list(ids.shape[:-1])
+    out_shape = list(ids.shape[:-1]) if op_version == "lookup_table" else list(
+        ids.shape)
     out_shape.append(w_shape[-1])
     out = weights[flat_ids].reshape(out_shape)
     return out


-def _get_grad(weights, ids, flat_ids):
+def _get_grad(weights, ids, flat_ids, op_version="lookup_table"):
     w_shape = weights.shape
     w_grad = np.zeros((w_shape), dtype=weights.dtype)
-    out_grad_shape = (np.prod(ids.shape[:-1]), w_shape[-1])
+    out_shape = list(ids.shape[:-1]) if op_version == "lookup_table" else list(
+        ids.shape)
+    out_grad_shape = (np.prod(out_shape), w_shape[-1])
     out_grad = weights[flat_ids].reshape(out_grad_shape)
     for i, idx in enumerate(flat_ids):
         w_grad[idx, :] += out_grad[i]
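The op_version switch reflects a shape difference between the two ops: lookup_table requires ids of shape [..., 1] and drops that trailing dimension, while lookup_table_v2 takes ids as-is. For example:

# Shape difference behind the op_version branch; illustrative values.
import numpy as np

w = np.zeros((17, 31), dtype=np.float32)
ids_v1 = np.zeros((2, 4, 5, 1), dtype=np.int64)  # lookup_table: trailing 1
ids_v2 = np.zeros((2, 4, 5), dtype=np.int64)     # lookup_table_v2: no trailing 1

# lookup_table drops the trailing 1, lookup_table_v2 appends the row width;
# both yield the same output shape here.
out_v1 = list(ids_v1.shape[:-1]) + [w.shape[-1]]
out_v2 = list(ids_v2.shape) + [w.shape[-1]]
assert out_v1 == out_v2 == [2, 4, 5, 31]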
@@ -46,18 +49,24 @@ def _get_grad(weights, ids, flat_ids):
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16Op(OpTest):
-    def setUp(self):
+    def init_test(self):
         self.op_type = "lookup_table"
+        self.ids_shape = (4, 1)
+
+    def setUp(self):
+        self.init_test()
         self.dtype = np.uint16
         table = np.random.random((17, 31)).astype("float32")
-        self.ids = np.random.randint(0, 17, (4, 1)).astype("int64")
+        self.ids = np.random.randint(0, 17, self.ids_shape).astype("int64")
         self.flat_ids = self.ids.flatten()
         self.w_bf16 = convert_float_to_uint16(table)
-        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids)
-        self.out_fp32 = _lookup(table, self.ids, self.flat_ids)
-        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids)
+        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids,
+                                self.op_type)
+        self.out_fp32 = _lookup(table, self.ids, self.flat_ids, self.op_type)
+        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids,
+                                     self.op_type)
         self.inputs = {'W': self.w_bf16, 'Ids': self.ids}
         self.outputs = {'Out': self.out_fp32}
@@ -79,17 +88,22 @@ class TestLookupTableBF16Op(OpTest):
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16OpIds4D(TestLookupTableBF16Op):
-    def setUp(self):
-        super(TestLookupTableBF16OpIds4D, self).setUp()
-        self.ids = np.random.randint(0, 17, (2, 4, 5, 1)).astype("int64")
+    def init_test(self):
+        self.op_type = "lookup_table"
+        self.ids_shape = (2, 4, 5, 1)


 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
+    def init_test(self):
+        self.op_type = "lookup_table"
+        self.ids_shape = (10, 1)
+
     def setUp(self):
+        self.init_test()
         self.ids = np.random.randint(
-            low=0, high=15, size=(10, 1)).astype("int64")
+            low=0, high=15, size=self.ids_shape).astype("int64")
         self.flat_ids = self.ids.flatten()
         self.w_fp32 = np.random.random((15, 32)).astype("float32")
         self.w_bf16 = convert_float_to_uint16(self.w_fp32)
@@ -120,12 +134,12 @@ class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
         out_tensor = self.scope.var('Out').get_tensor()

         # create and run lookup_table operator
-        lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
+        lookup_table = Operator(self.op_type, W='W', Ids='Ids', Out='Out')
         lookup_table.run(self.scope, self.place)

         # get result from Out
         result_array = np.array(out_tensor)
-        ref = _lookup(self.w_fp32, self.ids, self.flat_ids)
+        ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type)
         self._check_output(ref, result_array)
@@ -133,10 +147,12 @@ class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
                  "place does not support BF16 evaluation")
 class TestLookupTableBF16OpWIsSelectedRows4DIds(
         TestLookupTableBF16OpWIsSelectedRows):
+    def init_test(self):
+        self.op_type = "lookup_table"
+        self.ids_shape = (3, 4, 5, 1)
+
     def setUp(self):
         super(TestLookupTableBF16OpWIsSelectedRows4DIds, self).setUp()
-        self.ids = np.random.randint(
-            low=0, high=15, size=(3, 4, 5, 1)).astype("int64")
         self.flat_ids = self.ids.flatten()
......
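The new file below, test_lookup_table_v2_bf16_op.py, reuses the BF16 tests by overriding the init_test hook so that every inherited test method runs against lookup_table_v2. The pattern in miniature (class names here are made up for illustration):

# Minimal illustration of the init_test reuse pattern; names are hypothetical.
import unittest

class TestV1(unittest.TestCase):
    def init_test(self):
        self.op_type = "lookup_table"

    def setUp(self):
        self.init_test()  # subclasses retarget the fixture here

    def test_op_type(self):
        self.assertTrue(self.op_type.startswith("lookup_table"))

class TestV2(TestV1):
    def init_test(self):
        self.op_type = "lookup_table_v2"  # inherited tests now exercise v2

if __name__ == "__main__":
    unittest.main()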
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.fluid.tests.unittests.op_test import (skip_check_grad_ci,
                                                  convert_uint16_to_float)
from paddle.fluid.tests.unittests.test_lookup_table_bf16_op import (
    _lookup, TestLookupTableBF16Op, TestLookupTableBF16OpIds4D,
    TestLookupTableBF16OpWIsSelectedRows,
    TestLookupTableBF16OpWIsSelectedRows4DIds)
import paddle.fluid as fluid
import paddle.fluid.core as core


class TestLookupTableV2BF16Op(TestLookupTableBF16Op):
    def init_test(self):
        self.op_type = "lookup_table_v2"
        self.ids_shape = (4)
        self.mkldnn_data_type = "bfloat16"


class TestLookupTableV2BF16OpIds4D(TestLookupTableBF16OpIds4D):
    def init_test(self):
        self.op_type = "lookup_table_v2"
        self.ids_shape = (2, 4, 5)
        self.mkldnn_data_type = "bfloat16"


class TestLookupTableV2BF16OpWIsSelectedRows(
        TestLookupTableBF16OpWIsSelectedRows):
    def init_test(self):
        self.op_type = "lookup_table_v2"
        self.ids_shape = (10)


class TestLookupTableV2BF16OpWIsSelectedRows4DIds(
        TestLookupTableBF16OpWIsSelectedRows4DIds):
    def init_test(self):
        self.op_type = "lookup_table_v2"
        self.ids_shape = (3, 4, 5)


class TestLookupTableBF16OpWithPadding(TestLookupTableV2BF16Op):
    def test_check_output(self):
        ids = np.squeeze(self.inputs['Ids'])
        padding_idx = np.random.choice(ids, 1)[0]
        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(core.CPUPlace())


class TestLookupTableBF16OpIds4DPadding(TestLookupTableV2BF16OpIds4D):
    def test_check_output(self):
        ids = self.inputs['Ids']
        flatten_idx = ids.flatten()
        padding_idx = np.random.choice(flatten_idx, 1)[0]
        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(core.CPUPlace())
class TestEmbeddingLayerBF16ConstantInitializer(unittest.TestCase):
    """
    Test embedding layer from input api and results for bfloat16
    """

    def set_initializer(self):
        self.initializer = fluid.initializer.Constant(value=self.value)

    def setUp(self):
        self.op_type = "lookup_table_v2"
        self.ids_shape = [4]
        self.w_shape = [10, 64]
        self.ids = np.random.randint(
            low=0, high=9, size=self.ids_shape).astype("int64")
        self.flat_ids = self.ids.flatten()
        self.value = 3.0
        self.w_fp32 = np.full(self.w_shape, self.value)
        self.place = fluid.CPUPlace()
        self.prog = fluid.Program()
        self.startup_prog = fluid.Program()
        self.set_initializer()

        with fluid.program_guard(self.prog, self.startup_prog):
            x = fluid.layers.data(name='x', shape=self.ids_shape, dtype='int64')
            self.emb = fluid.input.embedding(
                input=x,
                size=self.w_shape,
                param_attr=fluid.ParamAttr(
                    name="emb_weight", initializer=self.initializer),
                is_sparse=False,
                dtype="uint16")  # bfloat16

        exe = fluid.Executor(self.place)
        exe.run(self.startup_prog)
        self.result = exe.run(self.prog,
                              feed={'x': self.ids},
                              fetch_list=['emb_weight', self.emb])

    def test_embedding_weights(self):
        result = convert_uint16_to_float(self.result[0])
        self.assertTrue(np.array_equal(self.w_fp32, result))

    def test_lookup_results(self):
        lookup_result = convert_uint16_to_float(self.result[1])
        lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type)
        self.assertTrue(np.array_equal(lookup_result, lookup_ref))


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
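The padding tests above rely on the padding_idx contract: any id equal to padding_idx must produce an all-zero output row (31 is the table's row width in these tests). A NumPy check of that expectation (illustrative only):

# Expected padding_idx behavior checked by the tests above; illustrative only.
import numpy as np

table = np.random.random((17, 31)).astype("float32")
ids = np.array([3, 7, 3, 11])
padding_idx = 3

expected = table[ids]
expected[ids == padding_idx] = np.zeros(31)  # padded rows become zeros
assert np.array_equal(expected[0], np.zeros(31, dtype=np.float32))
assert np.array_equal(expected[1], table[7])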
@@ -22,6 +22,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_lod_reset_op',
     'test_lookup_table_op',
     'test_lookup_table_bf16_op',
+    'test_lookup_table_v2_bf16_op',
     'test_pad2d_op',
     'test_scatter_op',
     'test_sequence_concat',
......