From fd8e833b57039124c8292a943cef6f4d66c3fdbe Mon Sep 17 00:00:00 2001
From: Liufang Sang
Date: Fri, 3 Apr 2020 01:03:49 -0500
Subject: [PATCH] Cherry pick pyramid hash support int8 weight and add
 dequantize_log op (#23371)

* cherry-pick add dequantize_log_op and make pyramid hash support int8 weight (#22548)

* add dequantize_log_op and make pyramid hash support int8 weight test=develop

* add unittest and update pyramid hash op test=develop

* remove paddle_enforce test=develop

* fix error message test=develop

* remove incorrent commit test=develop

* fix error message in log_dequantize test=develop

* change 2019 to 2020 test=develop

* remove useless check_grad test=develop

* cherry-pick fix compile error in win gpu (#23196)

* fix compile error in win gpu test=develop

* fix compile error in win gpu test=develop

* fix compile error in win gpu test=develop
---
 paddle/fluid/operators/dequantize_log_op.cc        | 103 ++++++++++++
 paddle/fluid/operators/dequantize_log_op.cu        |  61 +++++++
 paddle/fluid/operators/dequantize_log_op.h         |  46 +++++
 paddle/fluid/operators/pyramid_hash_op.cc          | 158 +++++++++++++-----
 paddle/fluid/operators/search_compute.h            |  53 +++++-
 .../tests/unittests/test_dequantize_log_op.py      |  53 ++++++
 6 files changed, 428 insertions(+), 46 deletions(-)
 create mode 100644 paddle/fluid/operators/dequantize_log_op.cc
 create mode 100644 paddle/fluid/operators/dequantize_log_op.cu
 create mode 100644 paddle/fluid/operators/dequantize_log_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_dequantize_log_op.py
diff --git a/paddle/fluid/operators/dequantize_log_op.cc b/paddle/fluid/operators/dequantize_log_op.cc
new file mode 100644
index 0000000000..bfd26061e3
--- /dev/null
+++ b/paddle/fluid/operators/dequantize_log_op.cc
@@ -0,0 +1,103 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/dequantize_log_op.h"
+#include <cmath>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct DequantizeFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& dev_ctx,
+                  const framework::Tensor* in, const framework::Tensor* dict,
+                  framework::Tensor* out) {
+    const float* dict_data = dict->data<float>();
+    const T* input_data = in->data<T>();
+    float* output_data = out->mutable_data<float>(dev_ctx.GetPlace());
+    int ind = in->numel();
+    for (size_t i = 0; i < (unsigned)ind; i++) {
+      if (input_data[i] < 0) {
+        output_data[i] = -std::pow(2.0, dict_data[input_data[i] + 128]);
+      } else {
+        output_data[i] = std::pow(2.0, dict_data[input_data[i]]);
+      }
+    }
+  }
+};
+
+template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;
+
+class DequantizeLogOp : public framework::OperatorWithKernel {
+ public:
+  DequantizeLogOp(const std::string& type,
+                  const framework::VariableNameMap& inputs,
+                  const framework::VariableNameMap& outputs,
+                  const framework::AttributeMap& attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::NotFound(
+                          "Input(X) of DequantizeLogOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::NotFound(
+                          "Output(Out) of DequantizeLogOp is not found."));
+
+    ctx->ShareDim("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    auto type = framework::OpKernelType(data_type, ctx.device_context());
+    return type;
+  }
+};
+
+class DequantizeLogOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(int8 Tensor) The input with int8 type is the "
+             "low precision tensor.");
+    AddInput("Dict", "(float) The Dict in quantization stage.");
+    AddOutput("Out",
+              "(float32 Tensor) The output is the dequantized high "
+              "precision tensor.");
+    AddComment(R"DOC(
+DequantizeLogOp operator.
+
+This calculation is an opposite operation of QuantizeLogOp:
+
+
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(
+    dequantize_log, ops::DequantizeLogOp, ops::DequantizeLogOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(dequantize_log, ops::DequantizeLogKernel<CPU, int8_t>);
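
For orientation before the GPU version: the functor above maps each int8 code to sign(code) * 2 ** Dict[index], where non-negative codes index Dict directly and negative codes use code + 128, so the sign lives in the code itself and the 128-entry Dict only stores exponents. Below is a short NumPy sketch of the same mapping; it is illustrative only, it mirrors the reference helper in the new unit test at the end of this patch, and the variable names are not from the patch.

import numpy as np

def dequantize_log(codes, dict_data):
    # codes: int8 array of quantized values; dict_data: 128 float32 exponents ("Dict")
    codes = codes.astype(np.int32)                 # so that code + 128 cannot wrap in int8
    idx = np.where(codes < 0, codes + 128, codes)  # negative codes use the upper half of Dict
    out = np.power(2.0, dict_data[idx]).astype('float32')
    return np.where(codes < 0, -out, out)          # the sign is carried by the code

dict_data = np.random.random(128).astype('float32')
codes = np.random.randint(-128, 127, size=(4, 5)).astype('int8')
print(dequantize_log(codes, dict_data))

The int32 cast exists only so that code + 128 does not overflow int8 arithmetic; the element-wise result is the same as the loop in the CPU functor.
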
diff --git a/paddle/fluid/operators/dequantize_log_op.cu b/paddle/fluid/operators/dequantize_log_op.cu
new file mode 100644
index 0000000000..57bad318ab
--- /dev/null
+++ b/paddle/fluid/operators/dequantize_log_op.cu
@@ -0,0 +1,61 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/dequantize_log_op.h"
+#include "paddle/fluid/operators/math.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void KeDequantize(const T* in, const float* dict, int num,
+                             float* out) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < num) {
+    if (in[idx] < 0) {
+      out[idx] = -std::pow(static_cast<float>(2.0), dict[in[idx] + 128]);
+    } else {
+      out[idx] = std::pow(static_cast<float>(2.0), dict[in[idx]]);
+    }
+  }
+}
+
+template <typename T>
+struct DequantizeFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& dev_ctx,
+                  const framework::Tensor* in, const framework::Tensor* dict,
+                  framework::Tensor* out) {
+    const T* in_data = in->data<T>();
+    const float* dict_data = dict->data<float>();
+    float* out_data = out->mutable_data<float>(dev_ctx.GetPlace());
+
+    int num = in->numel();
+    int block = 512;
+    int grid = (num + block - 1) / block;
+
+    KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(in_data, dict_data,
+                                                          num, out_data);
+  }
+};
+
+template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CUDA = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(dequantize_log,
+                        ops::DequantizeLogKernel<CUDA, int8_t>);
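
The kernel launch above sizes the grid with the usual ceiling division, so the last, possibly partial block is still launched, and the idx < num guard keeps the surplus threads in that block from touching memory. A small stand-alone sketch of that arithmetic (illustrative only, not part of the patch):

def ceil_div(num, block):
    # smallest grid such that grid * block >= num
    return (num + block - 1) // block

block = 512
for num in (1, 511, 512, 513, 2048):
    grid = ceil_div(num, block)
    assert grid * block >= num and (grid - 1) * block < num
    print(num, "elements ->", grid, "block(s) of", block, "threads")

With block = 512 and num = 513, for example, two blocks are launched and the 511 extra threads simply fail the guard.
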
diff --git a/paddle/fluid/operators/dequantize_log_op.h b/paddle/fluid/operators/dequantize_log_op.h
new file mode 100644
index 0000000000..f6590ecf61
--- /dev/null
+++ b/paddle/fluid/operators/dequantize_log_op.h
@@ -0,0 +1,46 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+struct DequantizeFunctor {
+  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
+                  const framework::Tensor* dict, framework::Tensor* out);
+};
+
+template <typename DeviceContext, typename T>
+class DequantizeLogKernel : public framework::OpKernel<T> {
+ public:
+  virtual void Compute(const framework::ExecutionContext& ctx) const {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* dict = ctx.Input<framework::Tensor>("Dict");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    out->mutable_data<float>(dev_ctx.GetPlace());
+
+    DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, dict, out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc
index fc6aab07fa..0cc4fdafb7 100644
--- a/paddle/fluid/operators/pyramid_hash_op.cc
+++ b/paddle/fluid/operators/pyramid_hash_op.cc
@@ -84,52 +84,111 @@ class PyramidHashOP : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "X(Input) should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "W(Input) should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::NotFound("Input(X) of PyramidHashOP is not found."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("W"), true,
+        platform::errors::NotFound("Input(W) of PyramidHashOP is not found."));
     PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      "Out(Output) should not be null.");
+                      platform::errors::NotFound(
+                          "Output(Out) of PyramidHashOP is not found."));
     PADDLE_ENFORCE_EQ(ctx->HasOutput("DropPos"), true,
-                      "DropPos(TMP Output) should not be null.");
+                      platform::errors::NotFound(
+                          "Output(DropPos) of PyramidHashOP is not found."));
 
     auto x_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of X(Input) should be 2.");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2,
+                      platform::errors::InvalidArgument(
+                          "The rank of Input(X) of PyramidHashOP is invalid. "
+                          "It should be 2, but got %d",
+                          x_dims.size()));
 
     auto w_dims = ctx->GetInputDim("W");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 2, "W should be 2-D tensor");
+    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+                      platform::errors::InvalidArgument(
+                          "The rank of Input(W) of PyramidHashOP is invalid. "
+                          "It should be 2, but got %d",
+                          w_dims.size()));
 
     int space_len = ctx->Attrs().Get<int>("space_len");
     int rand_len = ctx->Attrs().Get<int>("rand_len");
 
-    PADDLE_ENFORCE_EQ(w_dims[0], space_len + rand_len,
-                      "w_dims[0] should be equal to (space_len + rand_len)");
-    PADDLE_ENFORCE_EQ(w_dims[1], 1, "w_dims[1] should be equal to 1");
+    PADDLE_ENFORCE_EQ(
+        w_dims[0], space_len + rand_len,
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(W) of PyramidHashOP is invalid. "
+            "It should be space_len + rand_len, but now %d != %d + %d",
+            w_dims[0], space_len, rand_len));
+    PADDLE_ENFORCE_EQ(
+        w_dims[1], 1,
+        platform::errors::InvalidArgument(
+            "The second dimension of Input(W) of PyramidHashOP is invalid."
+ " It should be 1, but got %d", + w_dims[1])); int num_emb = ctx->Attrs().Get("num_emb"); - PADDLE_ENFORCE_EQ(num_emb % rand_len, 0, - "random length should mod embedding size"); + PADDLE_ENFORCE_EQ( + num_emb % rand_len, 0, + platform::errors::InvalidArgument( + "The PyramidHashOP's Attr(num_emb) should mod Attr(rand_len), " + "but num_emb is %d, rand_len is %d", + num_emb, rand_len)); int white_list_len = ctx->Attrs().Get("white_list_len"); if (white_list_len > 0) { PADDLE_ENFORCE_EQ( ctx->HasInput("WhiteList"), true, - "WhiteList(Input) should not be null when white_list_len > 0"); + platform::errors::NotFound("Input(WhiteList) of PyramidHashOP is not " + "found but white_list_len > 0.")); auto wl_dims = ctx->GetInputDim("WhiteList"); - PADDLE_ENFORCE_EQ(wl_dims.size(), 2, "WhiteList should be 2-D tensor"); + PADDLE_ENFORCE_EQ( + wl_dims.size(), 2, + platform::errors::InvalidArgument( + "The rank of Input(WhiteList) of PyramidHashOP is invalid." + " It should be 2, but got %d", + wl_dims.size())); PADDLE_ENFORCE_EQ(wl_dims[0], white_list_len, - "wl_dims[0] should be equal to white_list_len"); - PADDLE_ENFORCE_EQ(wl_dims[1], 1, "wl_dims[1] should be equal to 1"); + platform::errors::InvalidArgument( + "The first dimension of Input(WhiteList) of " + "PyramidHashOP is invalid." + " It should be equal to Attr(white_list_len) " + ", but first dimension is %d, white_list_len is %d", + wl_dims[0], white_list_len)); + PADDLE_ENFORCE_EQ(wl_dims[1], 1, + platform::errors::InvalidArgument( + "The second dimension of Input(WhiteList) of " + "PyramidHashOP is invalid." + " It should be 1, but got %d", + wl_dims[1])); } int black_list_len = ctx->Attrs().Get("black_list_len"); if (black_list_len > 0) { PADDLE_ENFORCE_EQ( ctx->HasInput("BlackList"), true, - "BlackList(Input) should not be null when black_list_len > 0"); + platform::errors::NotFound("Input(BlackList) of PyramidHashOP is not " + "found but black_list_len > 0.")); auto bl_dims = ctx->GetInputDim("BlackList"); - PADDLE_ENFORCE_EQ(bl_dims.size(), 2, "BlackList should be 2-D tensor"); + PADDLE_ENFORCE_EQ( + bl_dims.size(), 2, + platform::errors::InvalidArgument( + "The rank of Input(BlackList) of PyramidHashOP is invalid." + " It should be 2, but got %d", + bl_dims.size())); PADDLE_ENFORCE_EQ(bl_dims[0], black_list_len, - "bl_dims[0] should be equal to black_list_len"); - PADDLE_ENFORCE_EQ(bl_dims[1], 1, "bl_dims[1] should be equal to 1"); + platform::errors::InvalidArgument( + "The first dimension of Input(BlackList) of " + "PyramidHashOP is invalid." + " It should be equal to Attr(black_list_len)" + ", but first dimension is %d, black_list_len is %d", + bl_dims[0], black_list_len)); + PADDLE_ENFORCE_EQ(bl_dims[1], 1, + platform::errors::InvalidArgument( + "The second dimension of Input(BlackList) of " + "PyramidHashOP is invalid." 
+ " It should be 1, but got %d", + bl_dims[1])); } if (ctx->IsRuntime()) { @@ -154,20 +213,22 @@ template class CPUPyramidHashOPKernel : public framework::OpKernel { public: bool should_use_term(math::bloomfilter* _filter, - math::bloomfilter* _black_filter, const T* word_repr, + math::bloomfilter* _black_filter, const float* word_repr, int len) const { return (!_filter || - 1 == math::bloomfilter_get(_filter, word_repr, len * sizeof(T))) && + 1 == math::bloomfilter_get(_filter, word_repr, + len * sizeof(float))) && (!_black_filter || 0 == math::bloomfilter_get(_black_filter, word_repr, - len * sizeof(T))); + len * sizeof(float))); } - void hash_embedding_ff(const T* hash_id, int len, T* top_pos, + void hash_embedding_ff(const float* hash_id, int len, T* top_pos, const T* weights, int _num_emb, int _rand_len, int _space_len) const { - unsigned int pos1 = XXH32(hash_id, len * sizeof(T), 0) % _space_len; - unsigned int pos2 = XXH32(hash_id, len * sizeof(T), _rand_len) % _space_len; + unsigned int pos1 = XXH32(hash_id, len * sizeof(float), 0) % _space_len; + unsigned int pos2 = + XXH32(hash_id, len * sizeof(float), _rand_len) % _space_len; for (int j = 0; j != _num_emb; j += _rand_len) { if (j + _rand_len < _num_emb) { @@ -176,8 +237,8 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { } unsigned int pos3 = - XXH32(hash_id, len * sizeof(T), j + 2 * _rand_len) % _space_len; - memcpy(top_pos + j, const_cast(weights + pos1), + XXH32(hash_id, len * sizeof(float), j + 2 * _rand_len) % _space_len; + memcpy(top_pos + j, const_cast(weights + pos1), _rand_len * sizeof(T)); pos1 = pos2; pos2 = pos3; @@ -208,7 +269,7 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { const auto* bottom_data_ori = bottom->data(); auto* buff = ctx.Output("X_Temp_Out"); buff->Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]})); - T* bottom_data = buff->mutable_data(ctx.GetPlace()); + float* bottom_data = buff->mutable_data(ctx.GetPlace()); for (int i = 0; i < bottom->dims()[0]; i++) { bottom_data[i] = bottom_data_ori[i]; } @@ -223,12 +284,12 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { math::bloomfilter* _black_filter = NULL; if (use_filter) { if (white_list_len != 0) { - _filter = (math::bloomfilter*)_blobs_1->data(); + _filter = (math::bloomfilter*)_blobs_1->data(); PADDLE_ENFORCE_EQ(math::bloomfilter_check(_filter), 1, "white filter not load"); } if (black_list_len != 0) { - _black_filter = (math::bloomfilter*)_blobs_2->data(); + _black_filter = (math::bloomfilter*)_blobs_2->data(); PADDLE_ENFORCE_EQ(math::bloomfilter_check(_black_filter), 1, "black filter not load"); } @@ -251,11 +312,11 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { for (int l = 0; l < w - ilayer; ++l) { if (should_use_term(_filter, _black_filter, - (const T*)(bottom_data + offset[i] + l), + (const float*)(bottom_data + offset[i] + l), ilayer + 1)) { if (_is_training != 0) { unsigned int rand_val = rand_r(&_seed); - T rate = static_cast(rand_val) / (RAND_MAX); + float rate = static_cast(rand_val) / (RAND_MAX); *(iter_end++) = (rate < _drop_out_percent ? 
0 : 1); } else { *(iter_end++) = 1; @@ -311,7 +372,7 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { // do nothing } else { auto* top_pos = top_data + top_counter++ * _num_emb; - hash_embedding_ff((const T*)(bottom_data + offset[i] + l), + hash_embedding_ff((const float*)(bottom_data + offset[i] + l), ilayer + 1, top_pos, weights, _num_emb, _rand_len, _space_len); } @@ -322,7 +383,8 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { if (iter != iter_end) { exit(1); } - if (_is_training == 0) { + auto weight_type = _blobs_0->type(); + if (_is_training == 0 && weight_type != framework::proto::VarType::INT8) { avx_axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1], _drop_out_percent); } @@ -334,15 +396,23 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "Input(W) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::NotFound( + "Input(X) of PyramidHashOpGrad is not found.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, + platform::errors::NotFound( + "Input(W) of PyramidHashOpGrad is not found.")); PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), true, - "Input(DropPos) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("X_Temp_Out"), true, - "Input(X_Temp_Out) should not be null."); + platform::errors::NotFound( + "Input(DropPos) of PyramidHashOpGrad is not found.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("X_Temp_Out"), true, + platform::errors::NotFound( + "Input(X_Temp_Out) of PyramidHashOpGrad is not found.")); PADDLE_ENFORCE_EQ( ctx->HasInput(framework::GradVarName("Out")), true, - "Input(Out@GRAD) of PyramidHashGradOp should not be null."); + platform::errors::NotFound( + "Input(Out@Grad) of PyramidHashOpGrad is not found.")); } protected: @@ -412,6 +482,7 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel { auto& drop_pos_offset = drop_pos->lod()[0]; const auto* top_diff = top->data(); + // in-place update weight, so need const_cast T* weights = const_cast(_blobs->data()); T mlr = -1.0 * _lr; @@ -455,7 +526,10 @@ REGISTER_OPERATOR(pyramid_hash, ops::PyramidHashOP, ops::PyramidHashOpMaker, REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad); REGISTER_OP_CPU_KERNEL( - pyramid_hash, ops::CPUPyramidHashOPKernel); + pyramid_hash, ops::CPUPyramidHashOPKernel, + ops::CPUPyramidHashOPKernel, + ops::CPUPyramidHashOPKernel); REGISTER_OP_CPU_KERNEL( pyramid_hash_grad, - ops::CPUPyramidHashOPGradKernel); + ops::CPUPyramidHashOPGradKernel, + ops::CPUPyramidHashOPGradKernel); diff --git a/paddle/fluid/operators/search_compute.h b/paddle/fluid/operators/search_compute.h index a2bf3a6bed..1caec6d393 100644 --- a/paddle/fluid/operators/search_compute.h +++ b/paddle/fluid/operators/search_compute.h @@ -83,8 +83,13 @@ static const unsigned int AVX_CUT_LEN_MASK = 7U; #define _mm256_store_px _mm256_storeu_ps #define _mm256_broadcast_sx _mm256_broadcast_ss -template -inline void avx_axpy(const T* x, T* y, size_t len, const T alpha) { +#define _mm256_mul_pd _mm256_mul_pd +#define _mm256_add_pd _mm256_add_pd +#define _mm256_load_pd _mm256_loadu_pd +#define _mm256_store_pd _mm256_storeu_pd +#define _mm256_broadcast_sd _mm256_broadcast_sd + +inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) { unsigned int jjj, lll; jjj = 
lll = 0; @@ -102,8 +107,43 @@ inline void avx_axpy(const T* x, T* y, size_t len, const T alpha) { } } -template -inline void avx_axpy_noadd(const T* x, T* y, size_t len, const T alpha) { +inline void avx_axpy(const double* x, double* y, size_t len, + const float alpha) { + unsigned int jjj, lll; + jjj = lll = 0; + + lll = len & ~AVX_CUT_LEN_MASK; + double alpha_d = static_cast(alpha); + + __m256d mm_alpha = _mm256_broadcast_sd(&alpha_d); + for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) { + _mm256_store_pd( + y + jjj, + _mm256_add_pd(_mm256_load_pd(y + jjj), + _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj)))); + } + + for (; jjj < len; jjj++) { + y[jjj] += alpha * x[jjj]; + } +} +inline void avx_axpy_noadd(const double* x, double* y, size_t len, + const float alpha) { + unsigned int jjj, lll; + jjj = lll = 0; + double alpha_d = static_cast(alpha); + lll = len & ~AVX_CUT_LEN_MASK; + __m256d mm_alpha = _mm256_broadcast_sd(&alpha_d); + for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) { + _mm256_store_pd(y + jjj, _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj))); + } + + for (; jjj < len; jjj++) { + y[jjj] = alpha * x[jjj]; + } +} +inline void avx_axpy_noadd(const float* x, float* y, size_t len, + const float alpha) { unsigned int jjj, lll; jjj = lll = 0; @@ -117,6 +157,11 @@ inline void avx_axpy_noadd(const T* x, T* y, size_t len, const T alpha) { y[jjj] = alpha * x[jjj]; } } +inline void avx_axpy_noadd(const int8_t* x, int8_t* y, size_t len, + const float alpha) { + PADDLE_THROW(platform::errors::Unimplemented( + "int8_t input of avx_axpy_noadd is not supported")); +} } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_dequantize_log_op.py b/python/paddle/fluid/tests/unittests/test_dequantize_log_op.py new file mode 100644 index 0000000000..6c6f0811bb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dequantize_log_op.py @@ -0,0 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +from op_test import OpTest + + +def dequantize_log(x, dict_data): + output_data = np.zeros_like(x).astype('float32') + x_f = x.flatten() + output_data_f = output_data.flatten() + for i in range(x_f.size): + if x_f[i] < 0: + output_data_f[i] = -np.power(2, dict_data[x_f[i] + 128]) + else: + output_data_f[i] = np.power(2, dict_data[x_f[i]]) + return output_data_f.reshape(x.shape) + + +class TestDequantizeLogOp(OpTest): + def setUp(self): + self.op_type = "dequantize_log" + x = np.random.randint(low=-128, high=127, size=(20, 10)).astype('int8') + dict_data = np.random.random(128).astype('float32') + xdq = dequantize_log(x, dict_data) + + self.inputs = { + 'X': np.array(x).astype('int8'), + 'Dict': np.array(dict_data).astype('float32') + } + self.outputs = {'Out': xdq} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() -- GitLab
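
For readers unfamiliar with pyramid_hash: the int8-weight support in this patch changes the storage type of W and the axpy helpers, not the lookup scheme itself. The sketch below is a rough Python illustration of the hashed-embedding fetch that hash_embedding_ff performs. It is illustrative only: it assumes the third-party xxhash package for XXH32, it skips the white/black-list bloom filtering and the dropout (DropPos) path, the rolling pos1/pos2/pos3 prefetch in the C++ reduces here to using seed j for the chunk at offset j, and all names are made up.

import numpy as np
import xxhash  # assumed third-party package exposing XXH32; not used by the patch itself

def hash_embedding_ff(token_ids, weights, num_emb, rand_len, space_len):
    # token_ids: the (ilayer + 1) consecutive ids forming one term, hashed as raw float bytes,
    #            matching XXH32(hash_id, len * sizeof(float), seed) in the C++ code
    # weights:   flat table of length space_len + rand_len (float32, or int8 with this patch)
    key = np.asarray(token_ids, dtype=np.float32).tobytes()
    out = np.empty(num_emb, dtype=weights.dtype)
    for j in range(0, num_emb, rand_len):
        pos = xxhash.xxh32(key, seed=j).intdigest() % space_len
        out[j:j + rand_len] = weights[pos:pos + rand_len]  # raw copy, like the memcpy above
    return out

space_len, rand_len = 1000, 16
w = np.random.randint(-128, 127, size=space_len + rand_len).astype(np.int8)
print(hash_embedding_ff([3.0, 7.0], w, num_emb=64, rand_len=rand_len, space_len=space_len))

Each chunk read starts at most at index space_len - 1 and copies rand_len entries, which is why InferShape requires w_dims[0] to equal space_len + rand_len.
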