From 9bd933d3fb355cc12a884c522a6fa28b47c8742a Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Mon, 3 Sep 2018 10:31:22 +0800
Subject: [PATCH] Improve and fix fake_quantize_op (#13092)

* Improve and fix fake_quantize_op.
---
 paddle/fluid/operators/CMakeLists.txt              |   3 +
 paddle/fluid/operators/fake_quantize_op.cc         | 224 +++++++++----
 paddle/fluid/operators/fake_quantize_op.cu         | 293 ++++++------------
 paddle/fluid/operators/fake_quantize_op.h          | 180 ++++-------
 .../tests/unittests/test_fake_quantize_op.py       |  45 ++-
 5 files changed, 369 insertions(+), 376 deletions(-)

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 2f8fadcefe5..7ec1e78da4e 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -178,6 +178,8 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(relu);\n")
   elseif(${TARGET} STREQUAL "fake_dequantize")
     file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
+  elseif(${TARGET} STREQUAL "fake_quantize")
+    file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
   elseif(${TARGET} STREQUAL "tensorrt_engine_op")
     message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
   elseif(${TARGET} STREQUAL "fc")
@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
 op_library(sequence_pad_op DEPS sequence_padding)
 op_library(unstack_op DEPS stack_op)
+op_library(fake_quantize_op DEPS memory)

 if (WITH_GPU)
   op_library(conv_op DEPS vol2col depthwise_conv im2col)
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index a91e0f520e9..e608eba05d5 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -14,86 +14,198 @@ limitations under the License.
 */

 #include "paddle/fluid/operators/fake_quantize_op.h"
 #include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/clip_op.h"
+#include "paddle/fluid/platform/transform.h"

 namespace paddle {
 namespace operators {

-class FakeQuantizeOp : public framework::OperatorWithKernel {
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVectorArrayMap =
+    Eigen::TensorMap<Eigen::Tensor<T, 1, MajorType, IndexType>>;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using ConstEigenVectorArrayMap =
+    Eigen::TensorMap<const Eigen::Tensor<T, 1, MajorType, IndexType>>;
+
+template <typename T>
+struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
+                  const int num, T* out) {
+    Eigen::DSizes<Eigen::DenseIndex, 1> idim(num);
+    Eigen::DSizes<Eigen::DenseIndex, 1> odim(1);
+    Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>> in_e(in, idim);
+    Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> out_e(out, odim);
+
+    out_e = in_e.abs().maximum();
+  }
+};
+
+template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
+
+template <typename T>
+struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    T s = scale.data<T>()[0];
+    platform::Transform<platform::CPUDeviceContext> trans;
+    trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
+          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+    auto in_e = framework::EigenVector<T>::Flatten(in);
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+
+    out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round();
+  }
+};
+
+template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
+
+template <typename T>
+struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& cur_scale,
+                  const framework::Tensor& last_scale,
+                  const framework::Tensor& iter, const int window_size,
+                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
+    T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
+    int64_t it = iter.data<int64_t>()[0];
+    int idx = it % window_size;
+    T removed = scale_arr[idx];
+    T cur = cur_scale.data<T>()[0];
+    scale_arr[idx] = cur;
+
+    T max = last_scale.data<T>()[0];
+    if (max < cur) {
+      max = cur;
+    } else if (fabs(removed - max) < 1e-6) {
+      int size = (it > window_size) ? window_size : it;
+      FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
+                                                         &max);
+    }
+    out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
+  }
+};
+
+template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
+
+class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
  public:
-  FakeQuantizeOp(const std::string &type,
-                 const framework::VariableNameMap &inputs,
-                 const framework::VariableNameMap &outputs,
-                 const framework::AttributeMap &attrs)
+  FakeQuantizeAbsMaxOp(const std::string& type,
+                       const framework::VariableNameMap& inputs,
+                       const framework::VariableNameMap& outputs,
+                       const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of FakeQuantizeOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of FakeQuantizeOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
-                   "OutMovingScale(Out) of FakeQuantizeOp should not be null");
-    // if (ctx->HasInput("InMovingScale")) {
-    ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
-    //}
-    // if (ctx->HasInput("InScales")) {
-    PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
-                   "OutScales(Out) of FakeQuantizeOp should not be null");
-    ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
-    // PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
-    //                   ctx->Outputs("OutScales")[0],
-    //                   "Mean and MeanOut should share the same memory");
-    //}
+    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
+                   "Output(Scale) of FakeQuantizeOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScale", {1});
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.device_context());
+  }
 };

-class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
+class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("X", "(Tensor) Input tensor of scale operator.");
-    AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
-        .AsDispensable();
-    AddInput("InMovingScale", "Last scale, used in static quantization.")
-        .AsDispensable();
-    AddInput("InCurrentIter",
-             "Last iteration number, used in static quantization.")
-        .AsDispensable();
-    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
-    AddOutput("OutScales",
-              "(Tensor) scale buffer, used in static quantization.")
-        .AsDispensable();
-    AddOutput("OutMovingScale", " Current scale");
-    AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
-    AddAttr<std::string>("quantize_type",
-                         "(string, default abs_max)"
-                         "The scaling tpe of the quantize operator.")
-        .SetDefault("abs_max");
-    AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddOutput("Out",
+              "(Tensor) Output of quantized low level tensor, "
+              "but also saved as float data type.");
+    AddOutput("OutScale", "(Tensor) Current scale");
     AddAttr<int>("bit_length", "(int, default 8)")
         .SetDefault(8)
-        .AddCustomChecker([](const int &bit_length) {
+        .AddCustomChecker([](const int& bit_length) {
           PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                          "'bit_length' should be between 1 and 16.");
         });
-    AddAttr<bool>("is_test", "").SetDefault(false);
     AddComment(R"DOC(
 FakeQuantize operator

-quantize_type = abs_max:
+$$scale = max(abs(X))$$
+$$range = 2^{bit_length - 1} - 1$$
+$$Out = round(X/scale * range)$$

-  $$scale = max(abs(x))$$
+)DOC");
+  }
+};

-quantize_type = range_abs_max:
+class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
+ public:
+  FakeQuantizeRangeAbsMaxOp(const std::string& type,
+                            const framework::VariableNameMap& inputs,
+                            const framework::VariableNameMap& outputs,
+                            const framework::AttributeMap& attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}

-  $$scale = max(max(abs(x)), history_abs_max)$$
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("OutScale"),
+        "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
+    if (ctx->HasOutput("OutScales")) {
+      int window_size = ctx->Attrs().Get<int>("window_size");
+      ctx->SetOutputDim("OutScales", {window_size});
+    }
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScale", {1});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }

-quantize_type = moving_average_abs_max:
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.device_context());
+  }
+};

-  $$scale = 0.1*scale+0.9*new_abs_max)$$
+class FakeQuantizeRangeAbsMaxOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddInput("InScale", "Last scale.");
+    AddInput("Iter", "Global step iteration.").AsDispensable();
+    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
+    AddOutput("OutScale", " Current scale");
+    AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
+    AddAttr<int>("window_size", "(int, default 10000) window range size.")
+        .SetDefault(10000);
+    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
+        .SetDefault(8)
+        .AddCustomChecker([](const int& bit_length) {
+          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+                         "'bit_length' should be between 1 and 16.");
+        });
+    AddAttr<bool>("is_test", "").SetDefault(false);
+    AddComment(R"DOC(
+FakeQuantize operator is used in static quantization.

-$$Out = scale*X$$
+$$scale = max(max(abs(x)), history_abs_max)$$
+$$range = 2^{bit_length - 1} - 1$$
+$$Out = round(X/scale * range)$$

 )DOC");
   }
@@ -103,10 +215,16 @@ $$Out = scale*X$$
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
+                  ops::FakeQuantizeAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
+                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
+REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
+                  ops::FakeQuantizeRangeAbsMaxOpMaker,
                   paddle::framework::EmptyGradOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    fake_quantize,
-    ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
+                       ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
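Note: the abs_max behavior registered above can be checked numerically outside Paddle. Below is a minimal NumPy sketch of the formulas in the op comment (scale = max(abs(X)), range = 2^{bit_length-1} - 1, Out = round(X/scale * range)); it is an illustration only, not part of the patch, and the function name is made up:

import numpy as np

def fake_quantize_abs_max(x, bit_length=8):
    # scale is the largest absolute value in the tensor
    scale = np.max(np.abs(x)).astype("float32")
    # bin_cnt = 2^(bit_length - 1) - 1, e.g. 127 for 8 bits
    bin_cnt = (1 << (bit_length - 1)) - 1
    # quantized values stay in [-bin_cnt, bin_cnt], stored as floats
    out = np.round(x / scale * bin_cnt)
    return out, scale

x = np.random.random((4, 4)).astype("float32")
out, scale = fake_quantize_abs_max(x)
assert np.abs(out).max() <= (1 << 7) - 1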
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index be0c6730a51..7c65d6dba7d 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include <string>
+#include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/fake_quantize_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

@@ -20,7 +21,7 @@ namespace paddle {
 namespace operators {

 template <typename T>
-__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
+__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;

@@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
   __syncthreads();

   for (int i = blockDim.x / 2; i > 0; i >>= 1) {
-    if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
+    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
       shared_max_data[tid] = shared_max_data[tid + i];
     }
     __syncthreads();
@@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
   }
 }

-float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
-                    int length) {
-  float host_max;
-  int kNumTheads = 1024;
-  int gridDimx = (kNumTheads - 1 + length) / kNumTheads;
-  gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
-  framework::Tensor t;
-  float* device_max = t.mutable_data<float>(framework::make_ddim({gridDimx}),
-                                            platform::CUDAPlace());
-  FindAbsMaxKernel<float><<<gridDimx, kNumTheads, kNumTheads * sizeof(float),
-                            ctx.stream()>>>(length, array, device_max);
-  FindAbsMaxKernel<
-      float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>(
-      gridDimx, device_max, device_max);
-  PADDLE_ENFORCE_EQ(
-      cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
-      cudaSuccess, "cudaMemcpy failed");
-  return host_max;
-}
+template <typename T>
+struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
+                  const int num, T* out) {
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+    grid = (grid > block) ? block : grid;
+
+    framework::Tensor max;
+    T* max_data =
+        max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
+    FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
+        in, num, max_data);
+    FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>(
+        max_data, grid, out);
+  }
+};
+
+template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;

 template <typename T>
-__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
-                                    int* num_saturate, const T min,
-                                    const T max) {
+__global__ void ClipAndQuantKernel(const T* in, const T* scale,
+                                   const int bin_cnt, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;

-  extern __shared__ int shared_count[];
-  shared_count[tid] = 0;
+  T s = scale[0];
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    if (in[i] > max) {
-      out[i] = max;
-      shared_count[tid] += 1;
-    } else if (in[i] < min) {
-      out[i] = min;
-      shared_count[tid] += 1;
-    } else {
-      out[i] = in[i];
-    }
-  }
-  __syncthreads();
-
-  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
-    if (tid < i) {
-      shared_count[tid] += shared_count[tid + i];
-    }
-    __syncthreads();
-  }
-  if (tid == 0) {
-    num_saturate[blockIdx.x] = shared_count[0];
+    T x = in[i];
+    T v = x > s ? s : x;
+    v = v < -s ? -s : v;
+    v = bin_cnt / s * v;
+    out[i] = round(v);
   }
 }

 template <typename T>
-__global__ void ReduceKernel(const int n, const T* in, T* out) {
-  int tid = threadIdx.x;
-  extern __shared__ T shared_sum[];
-  if (tid < n) {
-    shared_sum[tid] = in[tid];
+__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
+                                            const T* last_scale,
+                                            const int64_t* iter,
+                                            const int window_size, T* scale_arr,
+                                            T* out_scale, int* need_find_max,
+                                            int* out_size) {
+  int it = iter[0];
+  int idx = it % window_size;
+  T removed = scale_arr[idx];
+  T cur = cur_scale[0];
+  scale_arr[idx] = cur;
+  T max = last_scale[0];
+  out_scale[0] = max < cur ? cur : max;
+  if (fabs(removed - max) < 1e-6) {
+    need_find_max[0] = 1;
+    out_size[0] = it > window_size ? window_size : it;
   } else {
-    shared_sum[tid] = T(0);
-  }
-  __syncthreads();
-  // blockDim.x must >= n
-  for (int i = (n + 1) / 2; i > 0; i >>= 1) {
-    if (tid < i) {
-      shared_sum[tid] += shared_sum[tid + i];
-    }
-    __syncthreads();
-  }
-  if (tid == 0) {
-    out[0] = shared_sum[0];
+    need_find_max[0] = 0;
   }
 }

 template <typename T>
-int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
-                     const T* in, T* out, const T min, const T max) {
-  int host_num_saturate;
-  int kNumTheads = 1024;
-  int gridDimx = (n + kNumTheads - 1) / kNumTheads;
-  gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
-  framework::Tensor t;
-  int* device_num_saturate = t.mutable_data<int>(
-      framework::make_ddim({gridDimx}), platform::CUDAPlace());
-  ApplySaturateKernel<
-      T><<<gridDimx, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
-      n, in, out, device_num_saturate, min, max);
-  ReduceKernel<int><<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
-      gridDimx, device_num_saturate, device_num_saturate);
-  PADDLE_ENFORCE_EQ(cudaSuccess,
-                    cudaMemcpy(&host_num_saturate, device_num_saturate,
-                               sizeof(int), cudaMemcpyDeviceToHost),
-                    "cudaMemcpy failed");
-  return host_num_saturate;
-}
-
-template <typename DeviceContext, typename T>
-class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
- public:
-  T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
-                    framework::Tensor* scale_list, framework::Tensor* out_scale,
-                    const T& cur_scale, int window_size,
-                    int current_iter) const {
-    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
-    T remove_tmp = sl[current_iter];
-    sl[current_iter] = cur_scale;
-    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
-    if (max_scale < cur_scale) {
-      max_scale = cur_scale;
-    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
-      int size = (current_iter > window_size) ? window_size : current_iter;
-      max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
+struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& cur_scale,
+                  const framework::Tensor& last_scale,
+                  const framework::Tensor& iter, const int window_size,
+                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
+    auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
+    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
+
+    framework::Tensor need_find_max, out_size;
+    int* find_max = need_find_max.mutable_data<int>(gpu_place);
+    int* out_size_data = out_size.mutable_data<int>(gpu_place);
+
+    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
+        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
+        window_size, scale_arr, out_scale_data, find_max, out_size_data);
+
+    int g_find_max;
+    memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
+                 sizeof(int), 0);
+    if (g_find_max) {
+      int len;
+      memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
+                   sizeof(int), 0);
+      FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
+                                                          out_scale_data);
     }
-    return max_scale;
-  }
-
-  T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
-                             framework::Tensor* out_scale,
-                             const T& cur_scale) const {
-    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
-    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
-    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
-    return T(outs[0]);
   }
+};

-  virtual void Compute(const framework::ExecutionContext& context) const {
-    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
-                   "This kernel only runs on GPU device.");
-    auto& device_ctx = context.cuda_device_context();
-    auto* tensor = context.Output<framework::Tensor>("Out");
-    auto* in = context.Input<framework::Tensor>("X");
-    const bool is_test = context.Attr<bool>("is_test");
-    tensor->mutable_data<T>(in->place());
-    context.Output<framework::Tensor>("OutMovingScale")
-        ->mutable_data<T>(
-            context.Input<framework::Tensor>("InMovingScale")->place());
-    auto quantize_type =
-        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
-    if (quantize_type == std::string("range_abs_max")) {
-      context.Output<framework::Tensor>("OutScales")
-          ->mutable_data<T>(
-              context.Input<framework::Tensor>("InScales")->place());
-      context.Output<framework::Tensor>("OutCurrentIter")
-          ->mutable_data<int>(
-              context.Input<framework::Tensor>("InCurrentIter")->place());
-    }
-
-    T scale = T(1);
-    int window_size = context.Attr<int>("window_size");
-    T bin_cnt = (T)((1 << (context.Attr<int>("bit_length") - 1)) - 1);
-    if (quantize_type == std::string("abs_max")) {
-      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
-      scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
-      saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
-
-      auto& device_ctx = context.template device_context<DeviceContext>();
-      auto* scale_list = context.Output<framework::Tensor>("OutScales");
-      math::SetConstant<DeviceContext, T> scalar;
-      scale_list->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, scale_list, static_cast<T>(0));
-      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-      iter->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, iter, static_cast<T>(0));
-    } else if (quantize_type == std::string("range_abs_max")) {
-      auto* moving_scale = const_cast<framework::Tensor*>(
-          context.Input<framework::Tensor>("InMovingScale"));
-      if (is_test) {
-        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
-      } else {
-        auto* it = const_cast<framework::Tensor*>(
-            context.Input<framework::Tensor>("InCurrentIter"));
-        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-        int* last_iter = it->mutable_data<int>(platform::CPUPlace());
-        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
-        auto* scale_list = context.Output<framework::Tensor>("OutScales");
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
-        scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
-                                window_size, current_iter[0]);
-        (*current_iter) = (*last_iter) + 1;
-      }
-    } else if (quantize_type == std::string("moving_average_abs_max")) {
-      auto* moving_scale = const_cast<framework::Tensor*>(
-          context.Input<framework::Tensor>("InMovingScale"));
-      if (is_test) {
-        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
-      } else {
-        scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        scale = FindMovingAverageAbsMmax(
-            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
-      }
-    }
-
-    ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
-                        tensor->mutable_data<T>(in->place()), -scale, scale);
-    scale = bin_cnt / scale;
+template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

-    auto& dev =
-        *context.template device_context<DeviceContext>().eigen_device();
-    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
-    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
-    eigen_out.device(dev) = (scale * eigen_in).round();
+template <typename T>
+struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, out_data);
   }
 };

+template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
+
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_CUDA_KERNEL(fake_quantize,
-                        paddle::operators::FakeQuantizeCUDAKernel<
-                            paddle::platform::CUDADeviceContext, float>,
-                        paddle::operators::FakeQuantizeCUDAKernel<
-                            paddle::platform::CUDADeviceContext, double>);
+namespace ops = paddle::operators;
+using CUDA = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
+                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
+                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
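Note: the GPU FindRangeAbsMaxFunctor above mirrors the CPU logic in fake_quantize_op.cc: keep a circular window of recent scales, and only rescan the window when the evicted entry was the running maximum. A Python sketch of that update rule (illustration only, not part of the patch; names are made up):

import numpy as np

def find_range_abs_max(cur_scale, last_scale, it, window_size, scale_arr):
    # overwrite the oldest slot of the circular window
    idx = it % window_size
    removed = scale_arr[idx]
    scale_arr[idx] = cur_scale
    max_scale = last_scale
    if max_scale < cur_scale:
        max_scale = cur_scale
    elif abs(removed - max_scale) < 1e-6:
        # the evicted entry was the maximum: rescan the valid part of the window
        size = window_size if it > window_size else it
        max_scale = np.max(scale_arr[:size])
    return max_scale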
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 80f71d85dde..7ace7573ec5 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -17,137 +17,91 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/platform/transform.h"

 namespace paddle {
 namespace operators {

-using platform::Transform;
+template <typename DeviceContext, typename T>
+struct FindAbsMaxFunctor {
+  void operator()(const DeviceContext& ctx, const T* in, const int num, T* out);
+};

 template <typename DeviceContext, typename T>
-class FakeQuantizeKernel : public framework::OpKernel<T> {
+struct ClipAndFakeQuantFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& scale, const int bin_cnt,
+                  framework::Tensor* out);
+};
+
+template <typename DeviceContext, typename T>
+struct FindRangeAbsMaxFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,
+                  const framework::Tensor& last_scale,
+                  const framework::Tensor& iter, const int window_size,
+                  framework::Tensor* scales_arr, framework::Tensor* out_scale);
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
  public:
-  T FindAbsMax(framework::Tensor* in, int n) const {
-    T* p = in->mutable_data<T>(platform::CPUPlace());
-    T abs_max = (T)0.00000001;
-    for (int i = 0; i < n; i++) {
-      T tmp = fabs(p[i]);
-      if (tmp > abs_max) abs_max = tmp;
-    }
-    return T(abs_max);
-  }
-  T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
-                    const T& cur_scale, int window_size,
-                    int current_iter) const {
-    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
-    T remove_tmp = sl[current_iter];
-    sl[current_iter] = cur_scale;
-    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
-    if (max_scale < cur_scale) {
-      max_scale = cur_scale;
-    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
-      int size = (current_iter > window_size) ? window_size : current_iter;
-      max_scale = T(FindAbsMax(scale_list, size));
-    }
-    return max_scale;
-  }
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>("X");

-  T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
-                             framework::Tensor* out_scale,
-                             const T& cur_scale) const {
-    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
-    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
-    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
-    return T(outs[0]);
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    T* out_s = out_scale->mutable_data<T>(context.GetPlace());
+
+    int bit_length = context.Attr<int>("bit_length");
+    int bin_cnt = std::pow(2, bit_length - 1) - 1;
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    const T* in_data = in->data<T>();
+    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
+                                                bin_cnt, out);
   }
+};

-  virtual void Compute(const framework::ExecutionContext& context) const {
-    auto* tensor = context.Output<framework::Tensor>("Out");
+template <typename DeviceContext, typename T>
+class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
-    const bool is_test = context.Attr<bool>("is_test");
-    tensor->mutable_data<T>(in->place());
-
-    auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
-    oms_tensor->mutable_data<T>(in->place());
-
-    auto quantize_type =
-        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
-    if (quantize_type == std::string("range_abs_max")) {
-      auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
-      oss_tensor->mutable_data<T>(
-          context.Input<framework::Tensor>("InScales")->place());
-      auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
-      oci_tensor->mutable_data<T>(
-          context.Input<framework::Tensor>("InCurrentIter")->place());
-    }
+    auto* in_scale = context.Input<framework::Tensor>("InScale");

-    T scale = static_cast<T>(1);
-    int window_size = context.Attr<int>("window_size");
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    bool is_test = context.Attr<bool>("is_test");
     int bit_length = context.Attr<int>("bit_length");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
+    auto& dev_ctx = context.template device_context<DeviceContext>();

-    auto& dev =
-        *context.template device_context<DeviceContext>().eigen_device();
-    auto raw_in = framework::EigenVector<T>::Flatten(*in);
-    if (quantize_type == std::string("abs_max")) {
-      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
-      auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
-      scale_out.device(dev) = raw_in.abs().maximum();
-      scale = scale_out(0);
-
-      auto& device_ctx = context.template device_context<DeviceContext>();
-      auto* scale_list = context.Output<framework::Tensor>("OutScales");
-      math::SetConstant<DeviceContext, T> scalar;
-      scale_list->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, scale_list, static_cast<T>(0));
-      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-      iter->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, iter, static_cast<T>(0));
-    } else if (quantize_type == std::string("range_abs_max")) {
-      auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
-      if (is_test) {
-        scale = moving_scale->data<T>()[0];
-      } else {
-        auto* it = context.Input<framework::Tensor>("InCurrentIter");
-        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-        const int* last_iter = it->data<int>();
-        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
-        auto* scale_list = context.Output<framework::Tensor>("OutScales");
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
-        scale_out.device(dev) = raw_in.abs().maximum();
-        scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
-        scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
-                                current_iter[0]);
-        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
-        (*current_iter) = (*last_iter) + 1;
-      }
-    } else if (quantize_type == std::string("moving_average_abs_max")) {
-      auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
-      if (is_test) {
-        scale = moving_scale->data<T>()[0];
-      } else {
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
-        scale_out.device(dev) = raw_in.abs().maximum();
-        scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
-        scale = FindMovingAverageAbsMmax(
-            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
-        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
-      }
+    // testing
+    if (is_test) {
+      ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
+                                                  bin_cnt, out);
+      return;
     }

-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), in->data<T>(),
-          in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
-          ClipFunctor<T>(-scale, scale));
-    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
-    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
-    eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
+    // training
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    auto* out_scales = context.Output<framework::Tensor>("OutScales");
+    auto* iter = context.Input<framework::Tensor>("Iter");
+
+    int window_size = context.Attr<int>("window_size");
+    out_scale->mutable_data<T>(context.GetPlace());
+
+    framework::Tensor cur_scale;
+    T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
+    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
+                                          cur_scale_data);
+    FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx, cur_scale, *in_scale,
+                                               *iter, window_size, out_scales,
+                                               out_scale);
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
+                                                bin_cnt, out);
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index cc0494774a5..820ad4af88e 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -21,28 +21,41 @@ from op_test import OpTest

 class TestFakeQuantizeOp(OpTest):
     def setUp(self):
-        self.op_type = "fake_quantize"
+        self.op_type = "fake_quantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        self.outputs = {
+            'Out': np.round(self.inputs['X'] / scale * (
+                (1 << (self.attrs['bit_length'] - 1)) - 1)),
+            'OutScale': np.array(scale).astype("float32"),
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestFakeQuantizeRangeAbsMaxOp(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_range_abs_max"
         self.attrs = {
-            'bit_length': 8,
-            'quantize_type': 'abs_max',
-            'window_size': 10000
+            'bit_length': int(5),
+            'window_size': int(1),
+            'is_test': False
         }
         self.inputs = {
-            'X': np.random.random((10, 10)).astype("float32"),
-            'InScales': np.zeros(self.attrs['window_size']).astype("float32"),
-            'InCurrentIter': np.zeros(1).astype("float32"),
-            'InMovingScale': np.zeros(1).astype("float32")
-        }
-        self.scale = {
-            'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
+            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'Iter': np.zeros(1).astype("int64"),
+            'InScale': np.zeros(1).astype("float32")
         }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
+        out_scales[0] = scale
         self.outputs = {
-            'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * (
+            'Out': np.round(self.inputs['X'] / scale * (
                 (1 << (self.attrs['bit_length'] - 1)) - 1)),
-            'OutScales': np.zeros(self.attrs['window_size']).astype("float32"),
-            'OutMovingScale':
-            np.array([self.scale['abs_max']]).astype("float32"),
-            'OutCurrentIter': np.zeros(1).astype("float32")
+            'OutScale': scale,
+            'OutScales': out_scales,
         }

     def test_check_output(self):
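Note: a quick way to sanity-check the test references above is to invert the quantization, as a dequantize step would. A NumPy sketch (illustration only, not part of the patch):

import numpy as np

bit_length = 8
bin_cnt = (1 << (bit_length - 1)) - 1
x = np.random.uniform(-1, 1, (3, 3)).astype("float32")
scale = np.max(np.abs(x))
quantized = np.round(x / scale * bin_cnt)  # matches the 'Out' reference above
recovered = quantized * scale / bin_cnt    # approximate inverse
# rounding error is at most half a quantization step
assert np.max(np.abs(recovered - x)) <= 0.5 * scale / bin_cnt + 1e-7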