Unverified commit 85e531a9 authored by cc, committed by GitHub

Add int16 kernel for lookup_table and dequantize_abs_max op (#34275)

* add int16 kernel for lookup_table and dequantize_abs_max op
Parent 5179853a
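For context, here is a minimal NumPy sketch (not Paddle API code; it assumes the usual abs-max scheme where dequant ≈ q * scale / max_range and max_range = 2**(num_bits - 1) - 1, the same formula the new int16 test below uses) of what the two kernels enable end to end: quantize an embedding table to int16, gather rows as lookup_table now can, and recover floats as dequantize_abs_max does.

# Reference computation only; a sketch of the abs-max int16 scheme, not Paddle code.
import numpy as np

num_bits = 16
max_range = 2 ** (num_bits - 1) - 1        # 32767, as in the new int16 test

# Float embedding table, quantized offline to int16 with abs-max.
w = np.random.randn(17, 31).astype("float32")
scale = np.abs(w).max()
w_int16 = np.round(w / scale * max_range).astype("int16")

# lookup_table with an int16 table: selected rows are copied as-is.
ids = np.random.randint(0, 17, size=4)
rows_int16 = w_int16[ids]                  # shape (4, 31), dtype int16

# dequantize_abs_max: back to float using the stored scale.
rows_fp32 = rows_int16.astype("float32") * scale / max_range
print(np.abs(rows_fp32 - w[ids]).max())    # small quantization error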
@@ -50,6 +50,7 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
};
template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;
template struct DequantizeFunctor<platform::CPUDeviceContext, int16_t>;
class DequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
@@ -79,7 +80,7 @@ class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(int8 Tensor) The input with int8 type is the "
"(Int Tensor) The input with int8/16 type is the "
"low precision tensor.");
AddInput("Scale", "(float) The scale in quantization stage.");
AddOutput("Out",
@@ -108,4 +109,5 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CPU, int8_t>);
ops::DequantizeMaxAbsKernel<CPU, int8_t>,
ops::DequantizeMaxAbsKernel<CPU, int16_t>);
@@ -45,6 +45,7 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
};
template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
template struct DequantizeFunctor<platform::CUDADeviceContext, int16_t>;
} // namespace operators
} // namespace paddle
@@ -52,4 +53,5 @@ template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CUDA, int8_t>);
ops::DequantizeMaxAbsKernel<CUDA, int8_t>,
ops::DequantizeMaxAbsKernel<CUDA, int16_t>);
@@ -229,6 +229,7 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>,
ops::LookupTableKernel<int8_t>,
ops::LookupTableKernel<int16_t>,
ops::LookupTableKernel<paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
ops::LookupTableGradKernel<double>,
......
@@ -227,7 +227,8 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
ops::LookupTableCUDAKernel<double>,
ops::LookupTableCUDAKernel<plat::float16>,
ops::LookupTableCUDAKernel<int8_t>);
ops::LookupTableCUDAKernel<int8_t>,
ops::LookupTableCUDAKernel<int16_t>);
REGISTER_OP_CUDA_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>,
ops::LookupTableGradCUDAKernel<double>,
......
@@ -103,6 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
if (id_index != -1) {
if (input_data_type == framework::proto::VarType::INT8 ||
input_data_type == framework::proto::VarType::INT16 ||
input_data_type == framework::proto::VarType::BF16) {
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
@@ -130,6 +131,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
id_index));
if (input_data_type == framework::proto::VarType::INT8 ||
input_data_type == framework::proto::VarType::INT16 ||
input_data_type == framework::proto::VarType::BF16) {
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
......
@@ -54,6 +54,15 @@ struct CBlas<int8_t> {
}
};
template <>
struct CBlas<int16_t> {
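// LookupTableKernel copies int8/int16/bf16 rows with memcpy (see the
// INT8/INT16/BF16 branch earlier in this diff), so VCOPY is never reached
// for int16 at runtime; this specialization only needs to exist so the
// template instantiates, and any actual call is reported as an error.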
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"Blas VCOPY do not supported on CPU, please check your code"));
}
};
template <>
struct CBlas<platform::bfloat16> {
template <typename... ARGS>
......
@@ -62,5 +62,12 @@ class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp):
self.data_type = "int8"
class TestDequantizeMaxAbsOpInt16(TestDequantizeMaxAbsOp):
def set_args(self):
self.num_bits = 16
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "int16"
if __name__ == "__main__":
unittest.main()
@@ -316,6 +316,124 @@ class TestLookupTableWithTensorIdsWIsSelectedRowsInt8(
assert (row == result_array[idx]).all()
@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
class TestLookupTableOpInt16(OpTest):
def setUp(self):
self.op_type = "lookup_table"
table = np.random.randint(
low=-128, high=127, size=(17, 31)).astype("int16")
ids = np.random.randint(0, 17, 4).astype("int64")
ids_expand = np.expand_dims(ids, axis=1)
self.inputs = {'W': table, 'Ids': ids_expand}
self.outputs = {'Out': table[ids]}
def test_check_output(self):
self.check_output()
@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
class TestLookupTableOpWithTensorIdsInt16(OpTest):
def setUp(self):
self.op_type = "lookup_table"
table = np.random.randint(
low=-128, high=127, size=(17, 31)).astype("int16")
ids = np.random.randint(
low=0, high=17, size=(2, 4, 5, 1)).astype("int64")
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}
def test_check_output(self):
self.check_output()
@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
class TestLookupTableOpWithPaddingInt16(TestLookupTableOpInt16):
def test_check_output(self):
ids = np.squeeze(self.inputs['Ids'])
padding_idx = np.random.choice(ids, 1)[0]
self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self.attrs = {'padding_idx': int(padding_idx)}
self.check_output()
@skip_check_grad_ci(reason="Int16 type only be used in test and inference.")
class TestLookupTableOpWithTensorIdsAndPaddingInt16(
TestLookupTableOpWithTensorIdsInt16):
def test_check_output(self):
ids = self.inputs['Ids']
flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output()
class TestLookupTableWIsSelectedRowsInt16(unittest.TestCase):
def prepare_ids(self, scope, place):
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
ids_tensor.set(ids_array, place)
return ids_array
def prepare_w(self, scope, place):
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("int16")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
def create_out_tensor(self, scope, place):
return scope.var('Out').get_tensor()
def check_result(self, ids_array, result_array):
for idx, row in enumerate(ids_array):
assert (row[0] == result_array[idx]).all()
def check_with_place(self, place):
scope = core.Scope()
ids_array = self.prepare_ids(scope, place)
self.prepare_w(scope, place)
out_tensor = self.create_out_tensor(scope, place)
# create and run lookup_table operator
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
self.check_result(ids_array, result_array)
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
class TestLookupTableWithTensorIdsWIsSelectedRowsInt16(
TestLookupTableWIsSelectedRowsInt16):
def prepare_ids(self, scope, place):
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.random.randint(
low=0, high=6, size=(2, 4, 3, 1)).astype("int64")
ids_tensor.set(ids_array, place)
return ids_array
def check_result(self, ids_array, result_array):
for idx, row in np.ndenumerate(ids_array):
assert (row == result_array[idx]).all()
class TestOutDtype(unittest.TestCase):
def test_dtype(self):
api_fn = F.embedding
......