From f0b15184388e9f8344f8047e600dad9e8430426b Mon Sep 17 00:00:00 2001 From: Liufang Sang Date: Fri, 22 Nov 2019 11:11:35 +0800 Subject: [PATCH] add dequantize_abs_max op and modify lookup_table op (#20899) * add int8 kernel to lookup_table op and add dequantize op test=develop * change paddle_enforce to paddle_enforce_eq test=develop * change copyright and change some not suitable code test=develop * remove debug log test=develop * replace GetInputType with IndicateVarDataType test=develop * fix EmptyGradMaker test=develop * fix diff between cpu and gpu test=develop * use memcopy when int8_t test=develop --- .../fluid/operators/dequantize_abs_max_op.cc | 98 +++++++++++++ .../fluid/operators/dequantize_abs_max_op.cu | 55 +++++++ .../fluid/operators/dequantize_abs_max_op.h | 50 +++++++ paddle/fluid/operators/lookup_table_op.cc | 3 +- paddle/fluid/operators/lookup_table_op.cu | 3 +- paddle/fluid/operators/lookup_table_op.h | 13 +- paddle/fluid/operators/math/blas_impl.h | 8 ++ .../unittests/test_dequantize_abs_max_op.py | 66 +++++++++ .../tests/unittests/test_lookup_table_op.py | 135 ++++++++++++++++++ 9 files changed, 425 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/operators/dequantize_abs_max_op.cc create mode 100644 paddle/fluid/operators/dequantize_abs_max_op.cu create mode 100644 paddle/fluid/operators/dequantize_abs_max_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cc b/paddle/fluid/operators/dequantize_abs_max_op.cc new file mode 100644 index 00000000000..48743f2e48c --- /dev/null +++ b/paddle/fluid/operators/dequantize_abs_max_op.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/dequantize_abs_max_op.h" +#include +#include + +namespace paddle { +namespace operators { + +template +struct DequantizeFunctor { + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + float max_range, framework::Tensor* out) { + const float* scale_factor = scale->data(); + const T* input_data = in->data(); + float* output_data = out->mutable_data(dev_ctx.GetPlace()); + int ind = in->numel(); + for (size_t i = 0; i < (unsigned)ind; i++) { + output_data[i] = scale_factor[0] * input_data[i] / max_range; + } + } +}; + +template struct DequantizeFunctor; + +class DequantizeMaxAbsOp : public framework::OperatorWithKernel { + public: + DequantizeMaxAbsOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of DequantizeMaxAbsOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of DequantizeMaxAbsOp should not be null."); + + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + auto type = framework::OpKernelType(data_type, ctx.device_context()); + return type; + } +}; + +class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(int8 Tensor) The input with int8 type is the " + "low precision tensor."); + AddInput("Scale", "(float) The scale in quantization stage."); + AddOutput("Out", + "(float32 Tensor) The output is the dequantized high " + "precision tensor."); + AddAttr("max_range", "(float) The max range in quantization stage."); + AddComment(R"DOC( +DequantizeMaxAbsOp operator. + +This calculation is an opposite operation of QuantizeMaxAbsOp: + +$$Out = \frac{scale*X}{ max\_range }$$ + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR( + dequantize_abs_max, ops::DequantizeMaxAbsOp, ops::DequantizeMaxAbsOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(dequantize_abs_max, + ops::DequantizeMaxAbsKernel); diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cu b/paddle/fluid/operators/dequantize_abs_max_op.cu new file mode 100644 index 00000000000..6554d4545ad --- /dev/null +++ b/paddle/fluid/operators/dequantize_abs_max_op.cu @@ -0,0 +1,55 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/dequantize_abs_max_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void KeDequantize(const T* in, const float* scale, float max_range, + int num, float* out) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < num) { + out[idx] = in[idx] * scale[0] / max_range; + } +} + +template +struct DequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + float max_range, framework::Tensor* out) { + const T* in_data = in->data(); + const float* scale_factor = scale->data(); + float* out_data = out->mutable_data(dev_ctx.GetPlace()); + + int num = in->numel(); + int block = 512; + int grid = (num + block - 1) / block; + + KeDequantize<<>>( + in_data, scale_factor, max_range, num, out_data); + } +}; + +template struct DequantizeFunctor; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(dequantize_abs_max, + ops::DequantizeMaxAbsKernel); diff --git a/paddle/fluid/operators/dequantize_abs_max_op.h b/paddle/fluid/operators/dequantize_abs_max_op.h new file mode 100644 index 00000000000..796ca93b000 --- /dev/null +++ b/paddle/fluid/operators/dequantize_abs_max_op.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +struct DequantizeFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor* scale, float max_range, + framework::Tensor* out); +}; + +template +class DequantizeMaxAbsKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto* in = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* out = ctx.Output("Out"); + + float max_range = ctx.Attr("max_range"); + + auto& dev_ctx = ctx.template device_context(); + out->mutable_data(dev_ctx.GetPlace()); + + DequantizeFunctor()(dev_ctx, in, scale, max_range, out); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index feed5a65d6b..9fb208662d9 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -205,6 +205,7 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad, ops::LookupTableOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel, - ops::LookupTableKernel); + ops::LookupTableKernel, + ops::LookupTableKernel); REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel, ops::LookupTableGradKernel); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 1885de53c57..43e3457fd5d 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -199,7 +199,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel, ops::LookupTableCUDAKernel, - ops::LookupTableCUDAKernel); + ops::LookupTableCUDAKernel, + ops::LookupTableCUDAKernel); REGISTER_OP_CUDA_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel, ops::LookupTableGradCUDAKernel, diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 75e2c2a9c1f..348fa52f38c 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -107,8 +107,7 @@ class LookupTableKernel : public framework::OpKernel { int64_t row_width = table_t.value().dims()[1]; const auto *table = table_t.value().data(); auto *output = output_t->mutable_data(context.GetPlace()); - - auto blas = math::GetBlas(context); + auto input_data_type = table_t.value().type(); for (int64_t i = 0; i < ids_numel; ++i) { if (padding_idx != kNoPadding && ids[i] == padding_idx) { memset(output + i * row_width, 0, row_width * sizeof(T)); @@ -122,8 +121,14 @@ class LookupTableKernel : public framework::OpKernel { PADDLE_ENFORCE_GE( id_index, 0, "the input key should be exists. But received %d.", id_index); - blas.VCOPY(row_width, table + id_index * row_width, - output + i * row_width); + if (input_data_type == framework::proto::VarType::INT8) { + memcpy(output + i * row_width, table + id_index * row_width, + row_width * sizeof(T)); + } else { + auto blas = math::GetBlas(context); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); + } } } } diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 817429be442..356445b497d 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -24,6 +24,14 @@ namespace math { template struct CBlas; +template <> +struct CBlas { + template + static void VCOPY(ARGS... args) { + PADDLE_THROW("Blas VCOPY don't support int8_t"); + } +}; + #ifdef PADDLE_WITH_MKLML template <> struct CBlas { diff --git a/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py b/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py new file mode 100644 index 00000000000..8a66bdb8d15 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py @@ -0,0 +1,66 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +from op_test import OpTest + + +def quantize_max_abs(x, max_range): + scale = np.max(np.abs(x).flatten()) + y = np.round(x / scale * max_range) + return y, scale + + +def dequantize_max_abs(x, scale, max_range): + y = (scale / max_range) * x + return y + + +class TestDequantizeMaxAbsOp(OpTest): + def set_args(self): + self.num_bits = 8 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "int8" + + def setUp(self): + self.set_args() + self.op_type = "dequantize_abs_max" + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + ydq = dequantize_max_abs(yq, scale, self.max_range) + + self.inputs = { + 'X': np.array(yq).astype(self.data_type), + 'Scale': np.array(scale).astype('float32') + } + self.attrs = {'max_range': self.max_range} + self.outputs = {'Out': ydq} + + def test_check_output(self): + self.check_output() + + +class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp): + def set_args(self): + self.num_bits = 5 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "int8" + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 17fa2c15a9e..1a9e226ac26 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -182,5 +182,140 @@ class TestEmbedOpError(OpTest): fluid.layers.embedding(input=input3, size=(10, 64), dtype='float16') +class TestLookupTableOpInt8(OpTest): + def setUp(self): + self.op_type = "lookup_table" + table = np.random.randint( + low=-128, high=127, size=(17, 31)).astype("int8") + ids = np.random.randint(0, 17, 4).astype("int64") + ids_expand = np.expand_dims(ids, axis=1) + self.inputs = {'W': table, 'Ids': ids_expand} + self.outputs = {'Out': table[ids]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + # since int8 type only be used in test and inference, there is + # no gradient implement, so we don't need to test it + pass + + +class TestLookupTableOpWithTensorIdsInt8(OpTest): + def setUp(self): + self.op_type = "lookup_table" + table = np.random.randint( + low=-128, high=127, size=(17, 31)).astype("int8") + ids = np.random.randint( + low=0, high=17, size=(2, 4, 5, 1)).astype("int64") + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + # since int8 type only be used in test and inference, there is + # no gradient implement, so we don't need to test it + pass + + +class TestLookupTableOpWithPaddingInt8(TestLookupTableOpInt8): + def test_check_output(self): + ids = np.squeeze(self.inputs['Ids']) + padding_idx = np.random.choice(ids, 1)[0] + self.outputs['Out'][ids == padding_idx] = np.zeros(31) + self.attrs = {'padding_idx': int(padding_idx)} + self.check_output() + + def test_check_grad(self): + # Since paddings are not trainable and fixed in forward, the gradient of + # paddings makes no sense and we don't test the gradient here. + pass + + +class TestLookupTableOpWithTensorIdsAndPaddingInt8( + TestLookupTableOpWithTensorIdsInt8): + def test_check_output(self): + ids = self.inputs['Ids'] + flatten_idx = ids.flatten() + padding_idx = np.random.choice(flatten_idx, 1)[0] + self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) + self.attrs = {'padding_idx': cpt.long_type(padding_idx)} + self.check_output() + + def test_check_grad(self): + # Since paddings are not trainable and fixed in forward, the gradient of + # paddings makes no sense and we don't test the gradient here. + pass + + +class TestLookupTableWIsSelectedRowsInt8(OpTest): + def prepare_ids(self, scope, place): + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.array([[0], [4], [3], [5]]).astype("int64") + ids_tensor.set(ids_array, place) + return ids_array + + def prepare_w(self, scope, place): + rows = [0, 1, 2, 3, 4, 5, 6] + row_numel = 12 + + w_selected_rows = scope.var('W').get_selected_rows() + w_selected_rows.set_height(len(rows)) + w_selected_rows.set_rows(rows) + w_array = np.ones((len(rows), row_numel)).astype("int8") + for i in range(len(rows)): + w_array[i] *= i + w_tensor = w_selected_rows.get_tensor() + w_tensor.set(w_array, place) + + def create_out_tensor(self, scope, place): + return scope.var('Out').get_tensor() + + def check_result(self, ids_array, result_array): + # all(): return True if all elements of the iterable are true (or if the iterable is empty) + for idx, row in enumerate(ids_array): + assert (row[0] == result_array[idx]).all() + + def check_with_place(self, place): + scope = core.Scope() + + ids_array = self.prepare_ids(scope, place) + + self.prepare_w(scope, place) + + out_tensor = self.create_out_tensor(scope, place) + + # create and run lookup_table operator + lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out') + lookup_table.run(scope, place) + + # get result from Out + result_array = np.array(out_tensor) + + self.check_result(ids_array, result_array) + + def test_w_is_selected_rows(self): + places = [core.CPUPlace()] + # currently only support CPU + for place in places: + self.check_with_place(place) + + +class TestLookupTableWithTensorIdsWIsSelectedRowsInt8( + TestLookupTableWIsSelectedRowsInt8): + def prepare_ids(self, scope, place): + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.random.randint( + low=0, high=6, size=(2, 4, 3, 1)).astype("int64") + ids_tensor.set(ids_array, place) + return ids_array + + def check_result(self, ids_array, result_array): + for idx, row in np.ndenumerate(ids_array): + assert (row == result_array[idx]).all() + + if __name__ == "__main__": unittest.main() -- GitLab