未验证 提交 4db03190 编写于 作者: L Liufang Sang 提交者: GitHub

add dequantize_log_op and make pyramid hash support int8 weight (#22548)

* add dequantize_log_op and make pyramid hash support int8 weight test=develop

* add unittest and update pyramid hash op test=develop

* remove paddle_enforce test=develop

* fix error message test=develop

* remove incorrent commit test=develop

* fix error message in log_dequantize test=develop

* change 2019 to 2020 test=develop

* remove useless check_grad test=develop
上级 e5fef8f3
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dequantize_log_op.h"
#include <math.h>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* dict,
framework::Tensor* out) {
const float* dict_data = dict->data<float>();
const T* input_data = in->data<T>();
float* output_data = out->mutable_data<float>(dev_ctx.GetPlace());
int ind = in->numel();
for (size_t i = 0; i < (unsigned)ind; i++) {
if (input_data[i] < 0) {
output_data[i] = -pow(2, dict_data[input_data[i] + 128]);
} else {
output_data[i] = pow(2, dict_data[input_data[i]]);
}
}
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;
class DequantizeLogOp : public framework::OperatorWithKernel {
public:
DequantizeLogOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of DequantizeLogOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of DequantizeLogOp is not found."));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
auto type = framework::OpKernelType(data_type, ctx.device_context());
return type;
}
};
class DequantizeLogOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(int8 Tensor) The input with int8 type is the "
"low precision tensor.");
AddInput("Dict", "(float) The Dict in quantization stage.");
AddOutput("Out",
"(float32 Tensor) The output is the dequantized high "
"precision tensor.");
AddComment(R"DOC(
DequantizeLogOp operator.
This calculation is an opposite operation of QuantizeLogOp:
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
dequantize_log, ops::DequantizeLogOp, ops::DequantizeLogOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_log, ops::DequantizeLogKernel<CPU, int8_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dequantize_log_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void KeDequantize(const T* in, const float* dict, int num,
float* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
if (in[idx] < 0) {
out[idx] = -pow(2, dict[in[idx] + 128]);
} else {
out[idx] = pow(2, dict[in[idx]]);
}
}
}
template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* dict,
framework::Tensor* out) {
const T* in_data = in->data<T>();
const float* dict_data = dict->data<float>();
float* out_data = out->mutable_data<float>(dev_ctx.GetPlace());
int num = in->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(in_data, dict_data,
num, out_data);
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(dequantize_log, ops::DequantizeLogKernel<CUDA, int8_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor* dict, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class DequantizeLogKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>("X");
auto* dict = ctx.Input<framework::Tensor>("Dict");
auto* out = ctx.Output<framework::Tensor>("Out");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<float>(dev_ctx.GetPlace());
DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, dict, out);
}
};
} // namespace operators
} // namespace paddle
......@@ -84,52 +84,111 @@ class PyramidHashOP : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "X(Input) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "W(Input) should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of PyramidHashOP is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("W"), true,
platform::errors::NotFound("Input(W) of PyramidHashOP is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Out(Output) should not be null.");
platform::errors::NotFound(
"Output(Out) of PyramidHashOP is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("DropPos"), true,
"DropPos(TMP Output) should not be null.");
platform::errors::NotFound(
"Output(DropPos) of PyramidHashOP is not found."));
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of X(Input) should be 2.");
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(X) of PyramidHashOP is invalid. "
"It should be 2, but got %d",
x_dims.size()));
auto w_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "W should be 2-D tensor");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(W) of PyramidHashOP is invalid. "
"It should be 2, but got %d",
w_dims.size()));
int space_len = ctx->Attrs().Get<int>("space_len");
int rand_len = ctx->Attrs().Get<int>("rand_len");
PADDLE_ENFORCE_EQ(w_dims[0], space_len + rand_len,
"w_dims[0] should be equal to (space_len + rand_len)");
PADDLE_ENFORCE_EQ(w_dims[1], 1, "w_dims[1] should be equal to 1");
PADDLE_ENFORCE_EQ(
w_dims[0], space_len + rand_len,
platform::errors::InvalidArgument(
"The first dimension of Input(W) of PyramidHashOP is invalid. "
"It should be space_len + rand_len, but now %d != %d + %d",
w_dims[0], space_len, rand_len));
PADDLE_ENFORCE_EQ(
w_dims[1], 1,
platform::errors::InvalidArgument(
"The second dimension of Input(W) of PyramidHashOP is invalid."
" It should be 1, but got %d",
w_dims[1]));
int num_emb = ctx->Attrs().Get<int>("num_emb");
PADDLE_ENFORCE_EQ(num_emb % rand_len, 0,
"random length should mod embedding size");
PADDLE_ENFORCE_EQ(
num_emb % rand_len, 0,
platform::errors::InvalidArgument(
"The PyramidHashOP's Attr(num_emb) should mod Attr(rand_len), "
"but num_emb is %d, rand_len is %d",
num_emb, rand_len));
int white_list_len = ctx->Attrs().Get<int>("white_list_len");
if (white_list_len > 0) {
PADDLE_ENFORCE_EQ(
ctx->HasInput("WhiteList"), true,
"WhiteList(Input) should not be null when white_list_len > 0");
platform::errors::NotFound("Input(WhiteList) of PyramidHashOP is not "
"found but white_list_len > 0."));
auto wl_dims = ctx->GetInputDim("WhiteList");
PADDLE_ENFORCE_EQ(wl_dims.size(), 2, "WhiteList should be 2-D tensor");
PADDLE_ENFORCE_EQ(
wl_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(WhiteList) of PyramidHashOP is invalid."
" It should be 2, but got %d",
wl_dims.size()));
PADDLE_ENFORCE_EQ(wl_dims[0], white_list_len,
"wl_dims[0] should be equal to white_list_len");
PADDLE_ENFORCE_EQ(wl_dims[1], 1, "wl_dims[1] should be equal to 1");
platform::errors::InvalidArgument(
"The first dimension of Input(WhiteList) of "
"PyramidHashOP is invalid."
" It should be equal to Attr(white_list_len) "
", but first dimension is %d, white_list_len is %d",
wl_dims[0], white_list_len));
PADDLE_ENFORCE_EQ(wl_dims[1], 1,
platform::errors::InvalidArgument(
"The second dimension of Input(WhiteList) of "
"PyramidHashOP is invalid."
" It should be 1, but got %d",
wl_dims[1]));
}
int black_list_len = ctx->Attrs().Get<int>("black_list_len");
if (black_list_len > 0) {
PADDLE_ENFORCE_EQ(
ctx->HasInput("BlackList"), true,
"BlackList(Input) should not be null when black_list_len > 0");
platform::errors::NotFound("Input(BlackList) of PyramidHashOP is not "
"found but black_list_len > 0."));
auto bl_dims = ctx->GetInputDim("BlackList");
PADDLE_ENFORCE_EQ(bl_dims.size(), 2, "BlackList should be 2-D tensor");
PADDLE_ENFORCE_EQ(
bl_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(BlackList) of PyramidHashOP is invalid."
" It should be 2, but got %d",
bl_dims.size()));
PADDLE_ENFORCE_EQ(bl_dims[0], black_list_len,
"bl_dims[0] should be equal to black_list_len");
PADDLE_ENFORCE_EQ(bl_dims[1], 1, "bl_dims[1] should be equal to 1");
platform::errors::InvalidArgument(
"The first dimension of Input(BlackList) of "
"PyramidHashOP is invalid."
" It should be equal to Attr(black_list_len)"
", but first dimension is %d, black_list_len is %d",
bl_dims[0], black_list_len));
PADDLE_ENFORCE_EQ(bl_dims[1], 1,
platform::errors::InvalidArgument(
"The second dimension of Input(BlackList) of "
"PyramidHashOP is invalid."
" It should be 1, but got %d",
bl_dims[1]));
}
if (ctx->IsRuntime()) {
......@@ -154,20 +213,22 @@ template <typename DeviceContext, typename T>
class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
public:
bool should_use_term(math::bloomfilter* _filter,
math::bloomfilter* _black_filter, const T* word_repr,
math::bloomfilter* _black_filter, const float* word_repr,
int len) const {
return (!_filter ||
1 == math::bloomfilter_get(_filter, word_repr, len * sizeof(T))) &&
1 == math::bloomfilter_get(_filter, word_repr,
len * sizeof(float))) &&
(!_black_filter ||
0 == math::bloomfilter_get(_black_filter, word_repr,
len * sizeof(T)));
len * sizeof(float)));
}
void hash_embedding_ff(const T* hash_id, int len, T* top_pos,
void hash_embedding_ff(const float* hash_id, int len, T* top_pos,
const T* weights, int _num_emb, int _rand_len,
int _space_len) const {
unsigned int pos1 = XXH32(hash_id, len * sizeof(T), 0) % _space_len;
unsigned int pos2 = XXH32(hash_id, len * sizeof(T), _rand_len) % _space_len;
unsigned int pos1 = XXH32(hash_id, len * sizeof(float), 0) % _space_len;
unsigned int pos2 =
XXH32(hash_id, len * sizeof(float), _rand_len) % _space_len;
for (int j = 0; j != _num_emb; j += _rand_len) {
if (j + _rand_len < _num_emb) {
......@@ -176,8 +237,8 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
}
unsigned int pos3 =
XXH32(hash_id, len * sizeof(T), j + 2 * _rand_len) % _space_len;
memcpy(top_pos + j, const_cast<float*>(weights + pos1),
XXH32(hash_id, len * sizeof(float), j + 2 * _rand_len) % _space_len;
memcpy(top_pos + j, const_cast<T*>(weights + pos1),
_rand_len * sizeof(T));
pos1 = pos2;
pos2 = pos3;
......@@ -208,7 +269,7 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
const auto* bottom_data_ori = bottom->data<int32_t>();
auto* buff = ctx.Output<LoDTensor>("X_Temp_Out");
buff->Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]}));
T* bottom_data = buff->mutable_data<T>(ctx.GetPlace());
float* bottom_data = buff->mutable_data<float>(ctx.GetPlace());
for (int i = 0; i < bottom->dims()[0]; i++) {
bottom_data[i] = bottom_data_ori[i];
}
......@@ -223,12 +284,12 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
math::bloomfilter* _black_filter = NULL;
if (use_filter) {
if (white_list_len != 0) {
_filter = (math::bloomfilter*)_blobs_1->data<T>();
_filter = (math::bloomfilter*)_blobs_1->data<float>();
PADDLE_ENFORCE_EQ(math::bloomfilter_check(_filter), 1,
"white filter not load");
}
if (black_list_len != 0) {
_black_filter = (math::bloomfilter*)_blobs_2->data<T>();
_black_filter = (math::bloomfilter*)_blobs_2->data<float>();
PADDLE_ENFORCE_EQ(math::bloomfilter_check(_black_filter), 1,
"black filter not load");
}
......@@ -251,11 +312,11 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) {
for (int l = 0; l < w - ilayer; ++l) {
if (should_use_term(_filter, _black_filter,
(const T*)(bottom_data + offset[i] + l),
(const float*)(bottom_data + offset[i] + l),
ilayer + 1)) {
if (_is_training != 0) {
unsigned int rand_val = rand_r(&_seed);
T rate = static_cast<T>(rand_val) / (RAND_MAX);
float rate = static_cast<float>(rand_val) / (RAND_MAX);
*(iter_end++) = (rate < _drop_out_percent ? 0 : 1);
} else {
*(iter_end++) = 1;
......@@ -311,7 +372,7 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
// do nothing
} else {
auto* top_pos = top_data + top_counter++ * _num_emb;
hash_embedding_ff((const T*)(bottom_data + offset[i] + l),
hash_embedding_ff((const float*)(bottom_data + offset[i] + l),
ilayer + 1, top_pos, weights, _num_emb,
_rand_len, _space_len);
}
......@@ -322,7 +383,8 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
if (iter != iter_end) {
exit(1);
}
if (_is_training == 0) {
auto weight_type = _blobs_0->type();
if (_is_training == 0 && weight_type != framework::proto::VarType::INT8) {
avx_axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1],
_drop_out_percent);
}
......@@ -334,15 +396,23 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "Input(W) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of PyramidHashOpGrad is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of PyramidHashOpGrad is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), true,
"Input(DropPos) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X_Temp_Out"), true,
"Input(X_Temp_Out) should not be null.");
platform::errors::NotFound(
"Input(DropPos) of PyramidHashOpGrad is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("X_Temp_Out"), true,
platform::errors::NotFound(
"Input(X_Temp_Out) of PyramidHashOpGrad is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) of PyramidHashGradOp should not be null.");
platform::errors::NotFound(
"Input(Out@Grad) of PyramidHashOpGrad is not found."));
}
protected:
......@@ -410,6 +480,7 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
auto& drop_pos_offset = drop_pos->lod()[0];
const auto* top_diff = top->data<T>();
// in-place update weight, so need const_cast
T* weights = const_cast<T*>(_blobs->data<T>());
T mlr = -1.0 * _lr;
......@@ -453,7 +524,10 @@ REGISTER_OPERATOR(pyramid_hash, ops::PyramidHashOP, ops::PyramidHashOpMaker,
REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad);
REGISTER_OP_CPU_KERNEL(
pyramid_hash, ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, float>);
pyramid_hash, ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, float>,
ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, double>,
ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, int8_t>);
REGISTER_OP_CPU_KERNEL(
pyramid_hash_grad,
ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, float>);
ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, float>,
ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, double>);
......@@ -83,8 +83,13 @@ static const unsigned int AVX_CUT_LEN_MASK = 7U;
#define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss
template <typename T>
inline void avx_axpy(const T* x, T* y, size_t len, const T alpha) {
#define _mm256_mul_pd _mm256_mul_pd
#define _mm256_add_pd _mm256_add_pd
#define _mm256_load_pd _mm256_loadu_pd
#define _mm256_store_pd _mm256_storeu_pd
#define _mm256_broadcast_sd _mm256_broadcast_sd
inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
......@@ -102,8 +107,43 @@ inline void avx_axpy(const T* x, T* y, size_t len, const T alpha) {
}
}
template <typename T>
inline void avx_axpy_noadd(const T* x, T* y, size_t len, const T alpha) {
inline void avx_axpy(const double* x, double* y, size_t len,
const float alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
lll = len & ~AVX_CUT_LEN_MASK;
double alpha_d = static_cast<double>(alpha);
__m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
_mm256_store_pd(
y + jjj,
_mm256_add_pd(_mm256_load_pd(y + jjj),
_mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj))));
}
for (; jjj < len; jjj++) {
y[jjj] += alpha * x[jjj];
}
}
inline void avx_axpy_noadd(const double* x, double* y, size_t len,
const float alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
double alpha_d = static_cast<double>(alpha);
lll = len & ~AVX_CUT_LEN_MASK;
__m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
_mm256_store_pd(y + jjj, _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj)));
}
for (; jjj < len; jjj++) {
y[jjj] = alpha * x[jjj];
}
}
inline void avx_axpy_noadd(const float* x, float* y, size_t len,
const float alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
......@@ -117,6 +157,11 @@ inline void avx_axpy_noadd(const T* x, T* y, size_t len, const T alpha) {
y[jjj] = alpha * x[jjj];
}
}
inline void avx_axpy_noadd(const int8_t* x, int8_t* y, size_t len,
const float alpha) {
PADDLE_THROW(platform::errors::Unimplemented(
"int8_t input of avx_axpy_noadd is not supported"));
}
} // namespace operators
} // namespace paddle
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
from op_test import OpTest
def dequantize_log(x, dict_data):
output_data = np.zeros_like(x).astype('float32')
x_f = x.flatten()
output_data_f = output_data.flatten()
for i in range(x_f.size):
if x_f[i] < 0:
output_data_f[i] = -np.power(2, dict_data[x_f[i] + 128])
else:
output_data_f[i] = np.power(2, dict_data[x_f[i]])
return output_data_f.reshape(x.shape)
class TestDequantizeLogOp(OpTest):
def setUp(self):
self.op_type = "dequantize_log"
x = np.random.randint(low=-128, high=127, size=(20, 10)).astype('int8')
dict_data = np.random.random(128).astype('float32')
xdq = dequantize_log(x, dict_data)
self.inputs = {
'X': np.array(x).astype('int8'),
'Dict': np.array(dict_data).astype('float32')
}
self.outputs = {'Out': xdq}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册