提交 f0b15184 编写于 作者: L Liufang Sang 提交者: whs

add dequantize_abs_max op and modify lookup_table op (#20899)

* add int8 kernel to lookup_table op and add dequantize op test=develop

* change paddle_enforce to paddle_enforce_eq test=develop

* change copyright and change some not suitable code test=develop

* remove debug log test=develop

* replace GetInputType with IndicateVarDataType test=develop

* fix EmptyGradMaker test=develop

* fix diff between cpu and gpu test=develop

* use memcopy when int8_t test=develop
上级 a6ce2306
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dequantize_abs_max_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
float max_range, framework::Tensor* out) {
const float* scale_factor = scale->data<float>();
const T* input_data = in->data<T>();
float* output_data = out->mutable_data<float>(dev_ctx.GetPlace());
int ind = in->numel();
for (size_t i = 0; i < (unsigned)ind; i++) {
output_data[i] = scale_factor[0] * input_data[i] / max_range;
}
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;
class DequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
DequantizeMaxAbsOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of DequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of DequantizeMaxAbsOp should not be null.");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
auto type = framework::OpKernelType(data_type, ctx.device_context());
return type;
}
};
class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(int8 Tensor) The input with int8 type is the "
"low precision tensor.");
AddInput("Scale", "(float) The scale in quantization stage.");
AddOutput("Out",
"(float32 Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<float>("max_range", "(float) The max range in quantization stage.");
AddComment(R"DOC(
DequantizeMaxAbsOp operator.
This calculation is an opposite operation of QuantizeMaxAbsOp:
$$Out = \frac{scale*X}{ max\_range }$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
dequantize_abs_max, ops::DequantizeMaxAbsOp, ops::DequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CPU, int8_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dequantize_abs_max_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void KeDequantize(const T* in, const float* scale, float max_range,
int num, float* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
out[idx] = in[idx] * scale[0] / max_range;
}
}
template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
float max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
const float* scale_factor = scale->data<float>();
float* out_data = out->mutable_data<float>(dev_ctx.GetPlace());
int num = in->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, out_data);
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(dequantize_abs_max,
ops::DequantizeMaxAbsKernel<CUDA, int8_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor* scale, float max_range,
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class DequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* out = ctx.Output<framework::Tensor>("Out");
float max_range = ctx.Attr<float>("max_range");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<float>(dev_ctx.GetPlace());
DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, scale, max_range, out);
}
};
} // namespace operators
} // namespace paddle
......@@ -205,6 +205,7 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
ops::LookupTableOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>);
ops::LookupTableKernel<double>,
ops::LookupTableKernel<int8_t>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
ops::LookupTableGradKernel<double>);
......@@ -199,7 +199,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
ops::LookupTableCUDAKernel<double>,
ops::LookupTableCUDAKernel<plat::float16>);
ops::LookupTableCUDAKernel<plat::float16>,
ops::LookupTableCUDAKernel<int8_t>);
REGISTER_OP_CUDA_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>,
ops::LookupTableGradCUDAKernel<double>,
......
......@@ -107,8 +107,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
auto input_data_type = table_t.value().type();
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
......@@ -122,8 +121,14 @@ class LookupTableKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GE(
id_index, 0, "the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
if (input_data_type == framework::proto::VarType::INT8) {
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
} else {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
}
}
......
......@@ -24,6 +24,14 @@ namespace math {
template <typename T>
struct CBlas;
template <>
struct CBlas<int8_t> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_THROW("Blas VCOPY don't support int8_t");
}
};
#ifdef PADDLE_WITH_MKLML
template <>
struct CBlas<float> {
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
from op_test import OpTest
def quantize_max_abs(x, max_range):
scale = np.max(np.abs(x).flatten())
y = np.round(x / scale * max_range)
return y, scale
def dequantize_max_abs(x, scale, max_range):
y = (scale / max_range) * x
return y
class TestDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "int8"
def setUp(self):
self.set_args()
self.op_type = "dequantize_abs_max"
x = np.random.randn(31, 65).astype(self.data_type)
yq, scale = quantize_max_abs(x, self.max_range)
ydq = dequantize_max_abs(yq, scale, self.max_range)
self.inputs = {
'X': np.array(yq).astype(self.data_type),
'Scale': np.array(scale).astype('float32')
}
self.attrs = {'max_range': self.max_range}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp):
def set_args(self):
self.num_bits = 5
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "int8"
if __name__ == "__main__":
unittest.main()
......@@ -182,5 +182,140 @@ class TestEmbedOpError(OpTest):
fluid.layers.embedding(input=input3, size=(10, 64), dtype='float16')
class TestLookupTableOpInt8(OpTest):
def setUp(self):
self.op_type = "lookup_table"
table = np.random.randint(
low=-128, high=127, size=(17, 31)).astype("int8")
ids = np.random.randint(0, 17, 4).astype("int64")
ids_expand = np.expand_dims(ids, axis=1)
self.inputs = {'W': table, 'Ids': ids_expand}
self.outputs = {'Out': table[ids]}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
# since int8 type only be used in test and inference, there is
# no gradient implement, so we don't need to test it
pass
class TestLookupTableOpWithTensorIdsInt8(OpTest):
def setUp(self):
self.op_type = "lookup_table"
table = np.random.randint(
low=-128, high=127, size=(17, 31)).astype("int8")
ids = np.random.randint(
low=0, high=17, size=(2, 4, 5, 1)).astype("int64")
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
# since int8 type only be used in test and inference, there is
# no gradient implement, so we don't need to test it
pass
class TestLookupTableOpWithPaddingInt8(TestLookupTableOpInt8):
def test_check_output(self):
ids = np.squeeze(self.inputs['Ids'])
padding_idx = np.random.choice(ids, 1)[0]
self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self.attrs = {'padding_idx': int(padding_idx)}
self.check_output()
def test_check_grad(self):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
class TestLookupTableOpWithTensorIdsAndPaddingInt8(
TestLookupTableOpWithTensorIdsInt8):
def test_check_output(self):
ids = self.inputs['Ids']
flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output()
def test_check_grad(self):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
class TestLookupTableWIsSelectedRowsInt8(OpTest):
def prepare_ids(self, scope, place):
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
ids_tensor.set(ids_array, place)
return ids_array
def prepare_w(self, scope, place):
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("int8")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
def create_out_tensor(self, scope, place):
return scope.var('Out').get_tensor()
def check_result(self, ids_array, result_array):
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array):
assert (row[0] == result_array[idx]).all()
def check_with_place(self, place):
scope = core.Scope()
ids_array = self.prepare_ids(scope, place)
self.prepare_w(scope, place)
out_tensor = self.create_out_tensor(scope, place)
# create and run lookup_table operator
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
self.check_result(ids_array, result_array)
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
class TestLookupTableWithTensorIdsWIsSelectedRowsInt8(
TestLookupTableWIsSelectedRowsInt8):
def prepare_ids(self, scope, place):
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.random.randint(
low=0, high=6, size=(2, 4, 3, 1)).astype("int64")
ids_tensor.set(ids_array, place)
return ids_array
def check_result(self, ids_array, result_array):
for idx, row in np.ndenumerate(ids_array):
assert (row == result_array[idx]).all()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册