Commit 93c034ee authored by T tensor-tang

Merge remote-tracking branch 'ups/develop' into optimize/op/fusion_lstm

@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <sys/time.h>
 #include <algorithm>
 #include <map>
 #include <set>
@@ -23,32 +22,14 @@ limitations under the License. */
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/inference/api/api_impl.h"
+#include "paddle/fluid/inference/api/timer.h"
 #include "paddle/fluid/platform/profiler.h"

 DEFINE_bool(profile, false, "Turn on profiler for fluid");

 namespace paddle {
 namespace {
-// Timer for timer
-class Timer {
- public:
-  double start;
-  double startu;
-  void tic() {
-    struct timeval tp;
-    gettimeofday(&tp, NULL);
-    start = tp.tv_sec;
-    startu = tp.tv_usec;
-  }
-  double toc() {
-    struct timeval tp;
-    gettimeofday(&tp, NULL);
-    double used_time_ms =
-        (tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0;
-    return used_time_ms;
-  }
-};
+using paddle::inference::Timer;

 template <class T>
 std::string num2str(T a) {
@@ -80,7 +61,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
 bool NativePaddlePredictor::Init(
     std::shared_ptr<framework::Scope> parent_scope) {
   VLOG(3) << "Predictor::init()";
+#if !defined(_WIN32)
   if (FLAGS_profile) {
     LOG(WARNING) << "Profiler is actived, might affect the performance";
     LOG(INFO) << "You can turn off by set gflags '-profile false'";
@@ -89,6 +70,7 @@ bool NativePaddlePredictor::Init(
                                  : platform::ProfilerState::kCPU;
     platform::EnableProfiler(tracking_device);
   }
+#endif

   if (config_.use_gpu) {
     place_ = paddle::platform::CUDAPlace(config_.device);
@@ -133,10 +115,12 @@ bool NativePaddlePredictor::Init(
 }

 NativePaddlePredictor::~NativePaddlePredictor() {
+#if !defined(_WIN32)
   if (FLAGS_profile) {
     platform::DisableProfiler(platform::EventSortingKey::kTotal,
                               "./profile.log");
   }
+#endif
   if (sub_scope_) {
     scope_->DeleteScope(sub_scope_);
   }
......
@@ -3,6 +3,11 @@ cmake_minimum_required(VERSION 3.0)
 project(cpp_inference_demo CXX C)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+if (WIN32)
+  set(CMAKE_STATIC_LIBRARY_PREFIX "lib")
+else()
+  set(CMAKE_STATIC_LIBRARY_PREFIX "")
+endif()

 if(NOT DEFINED PADDLE_LIB)
   message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
@@ -32,44 +37,56 @@ endif(NOT WIN32)
 include_directories("${PADDLE_LIB}/third_party/boost")
 include_directories("${PADDLE_LIB}/third_party/eigen3")

+if (NOT WIN32)
 link_directories("${PADDLE_LIB}/third_party/install/snappy/lib")
 link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
+link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
+endif(NOT WIN32)

 link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
 link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
 link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
-link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
+link_directories("${PADDLE_LIB}/paddle/fluid/inference")

 add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)

 if(WITH_MKL)
   include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
-  set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so
-      ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5.so)
+  set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+      ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
   set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
   if(EXISTS ${MKLDNN_PATH})
     include_directories("${MKLDNN_PATH}/include")
     set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
   endif()
 else()
-  set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
+  set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
 endif()

 # Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
 if(WITH_STATIC_LIB)
   set(DEPS
-      ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.a)
+      ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
 else()
   set(DEPS
-      ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so)
+      ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
 endif()

-set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+if (NOT WIN32)
+set(EXTERNAL_LIB "-lrt -ldl -lpthread")
 set(DEPS ${DEPS}
     ${MATH_LIB} ${MKLDNN_LIB}
     glog gflags protobuf snappystream snappy z
     ${EXTERNAL_LIB})
+else()
+set(DEPS ${DEPS}
+    ${MATH_LIB} ${MKLDNN_LIB}
+    ${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf
+    ${EXTERNAL_LIB})
+endif(NOT WIN32)

 if(WITH_GPU)
-  set(DEPS ${DEPS} ${CUDA_LIB}/libcudart.so)
+  set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
 endif()

 target_link_libraries(${DEMO_NAME} ${DEPS})
@@ -16,35 +16,15 @@
 #include <sys/time.h>
 #include <algorithm>
+#include <numeric>
 #include <sstream>
 #include <string>
 #include <vector>

 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/api/timer.h"

 namespace paddle {
 namespace inference {

-// Timer for timer
-class Timer {
- public:
-  double start;
-  double startu;
-  void tic() {
-    struct timeval tp;
-    gettimeofday(&tp, NULL);
-    start = tp.tv_sec;
-    startu = tp.tv_usec;
-  }
-  double toc() {
-    struct timeval tp;
-    gettimeofday(&tp, NULL);
-    double used_time_ms =
-        (tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0;
-    return used_time_ms;
-  }
-};
 static void split(const std::string &str, char sep,
                   std::vector<std::string> *pieces) {
   pieces->clear();
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono> // NOLINT
namespace paddle {
namespace inference {
// Timer for timer
class Timer {
public:
std::chrono::high_resolution_clock::time_point start;
std::chrono::high_resolution_clock::time_point startu;
void tic() { start = std::chrono::high_resolution_clock::now(); }
double toc() {
startu = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span =
std::chrono::duration_cast<std::chrono::duration<double>>(startu -
start);
double used_time_ms = static_cast<double>(time_span.count()) * 1000.0;
return used_time_ms;
}
};
} // namespace inference
} // namespace paddle
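For orientation, here is a minimal usage sketch of the relocated Timer (illustrative only, not part of the commit; it assumes the header above is on the include path):

#include <iostream>
#include "paddle/fluid/inference/api/timer.h"

int main() {
  paddle::inference::Timer timer;
  timer.tic();  // record the start time_point
  double acc = 0;
  for (int i = 0; i < 1000000; ++i) acc += i;  // workload being measured
  // toc() returns the wall-clock time elapsed since tic(), in milliseconds.
  std::cout << "elapsed: " << timer.toc() << " ms (acc=" << acc << ")\n";
  return 0;
}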
@@ -178,6 +178,8 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(relu);\n")
   elseif(${TARGET} STREQUAL "fake_dequantize")
     file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
+  elseif(${TARGET} STREQUAL "fake_quantize")
+    file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
   elseif(${TARGET} STREQUAL "tensorrt_engine_op")
     message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
   elseif(${TARGET} STREQUAL "fc")
@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
 op_library(sequence_pad_op DEPS sequence_padding)
 op_library(unstack_op DEPS stack_op)
+op_library(fake_quantize_op DEPS memory)
 if (WITH_GPU)
   op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
@@ -14,86 +14,198 @@ limitations under the License. */
 #include "paddle/fluid/operators/fake_quantize_op.h"
 #include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/clip_op.h"
+#include "paddle/fluid/platform/transform.h"

 namespace paddle {
 namespace operators {

-class FakeQuantizeOp : public framework::OperatorWithKernel {
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVectorArrayMap =
+    Eigen::TensorMap<Eigen::Tensor<T, 1, MajorType, IndexType>>;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using ConstEigenVectorArrayMap =
+    Eigen::TensorMap<const Eigen::Tensor<T, 1, MajorType, IndexType>>;
+
+template <typename T>
+struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
+                  const int num, T* out) {
+    Eigen::DSizes<Eigen::DenseIndex, 1> idim(num);
+    Eigen::DSizes<Eigen::DenseIndex, 1> odim(1);
+    Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>> in_e(in, idim);
+    Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> out_e(out, odim);
+    out_e = in_e.abs().maximum();
+  }
+};
+
+template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
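The Eigen expression above is an absolute-value max reduction; an equivalent plain loop, shown only to make the contract explicit (hypothetical helper name):

#include <cmath>

// Reference semantics of FindAbsMaxFunctor: out[0] = max_i |in[i]|.
template <typename T>
void FindAbsMaxRef(const T* in, int num, T* out) {
  T max_abs = static_cast<T>(0);
  for (int i = 0; i < num; ++i) {
    T v = std::fabs(in[i]);
    if (v > max_abs) max_abs = v;
  }
  out[0] = max_abs;
}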
+template <typename T>
+struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    T s = scale.data<T>()[0];
+    platform::Transform<platform::CPUDeviceContext> trans;
+    trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
+          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+    auto in_e = framework::EigenVector<T>::Flatten(in);
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+    out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round();
+  }
+};
+
+template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
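ClipAndFakeQuantFunctor performs two steps: clip the input to [-s, s], then map onto the integer grid and round, i.e. out = round(clip(x, -s, s) * bin_cnt / s) with bin_cnt = 2^(bit_length-1) - 1. A reference loop with the same semantics (hypothetical helper name):

#include <algorithm>
#include <cmath>

template <typename T>
void ClipAndFakeQuantRef(const T* in, int num, T s, int bin_cnt, T* out) {
  for (int i = 0; i < num; ++i) {
    T clipped = std::min(std::max(in[i], -s), s);  // clip to [-s, s]
    out[i] = std::round(bin_cnt / s * clipped);    // scale and round
  }
}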
+template <typename T>
+struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& cur_scale,
+                  const framework::Tensor& last_scale,
+                  const framework::Tensor& iter, const int window_size,
+                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
+    T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
+    int64_t it = iter.data<int64_t>()[0];
+    int idx = it % window_size;
+    T removed = scale_arr[idx];
+    T cur = cur_scale.data<T>()[0];
+    scale_arr[idx] = cur;
+
+    T max = last_scale.data<T>()[0];
+    if (max < cur) {
+      max = cur;
+    } else if (fabs(removed - max) < 1e-6) {
+      int size = (it > window_size) ? window_size : it;
+      FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
+                                                         &max);
+    }
+    out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
+  }
+};
+
+template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
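The functor above maintains a circular buffer of the last window_size scales and only rescans it when the scale falling out of the window may have been the running maximum. The same update rule as a self-contained sketch (hypothetical helper name):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// One step of the range_abs_max window update (reference semantics).
float UpdateRangeAbsMax(std::vector<float>* scales, float last_max,
                        float cur_scale, int64_t iter) {
  const int window_size = static_cast<int>(scales->size());
  const int idx = static_cast<int>(iter % window_size);
  const float removed = (*scales)[idx];
  (*scales)[idx] = cur_scale;               // overwrite the oldest entry
  if (last_max < cur_scale) return cur_scale;
  if (std::fabs(removed - last_max) < 1e-6f) {
    // The previous maximum may have left the window: rescan the valid part.
    const int valid =
        iter > window_size ? window_size : static_cast<int>(iter);
    return *std::max_element(scales->begin(), scales->begin() + valid);
  }
  return last_max;                          // maximum unchanged
}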
+class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
  public:
-  FakeQuantizeOp(const std::string &type,
-                 const framework::VariableNameMap &inputs,
-                 const framework::VariableNameMap &outputs,
-                 const framework::AttributeMap &attrs)
+  FakeQuantizeAbsMaxOp(const std::string& type,
+                       const framework::VariableNameMap& inputs,
+                       const framework::VariableNameMap& outputs,
+                       const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of FakeQuantizeOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of FakeQuantizeOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
-                   "OutMovingScale(Out) of FakeQuantizeOp should not be null");
-    // if (ctx->HasInput("InMovingScale")) {
-    ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
-    //}
-    // if (ctx->HasInput("InScales")) {
-    PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
-                   "OutScales(Out) of FakeQuantizeOp should not be null");
-    ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
-    // PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
-    //                   ctx->Outputs("OutScales")[0],
-    //                   "Mean and MeanOut should share the same memory");
-    //}
+    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
+                   "Output(Scale) of FakeQuantizeOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScale", {1});
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
 };

-class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
+class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("X", "(Tensor) Input tensor of scale operator.");
-    AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
-        .AsDispensable();
-    AddInput("InMovingScale", "Last scale, used in static quantization.")
-        .AsDispensable();
-    AddInput("InCurrentIter",
-             "Last iteration number, used in static quantization.")
-        .AsDispensable();
-    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
-    AddOutput("OutScales",
-              "(Tensor) scale buffer, used in static quantization.")
-        .AsDispensable();
-    AddOutput("OutMovingScale", " Current scale");
-    AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
-    AddAttr<std::string>("quantize_type",
-                         "(string, default abs_max)"
-                         "The scaling tpe of the quantize operator.")
-        .SetDefault("abs_max");
-    AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddOutput("Out",
+              "(Tensor) Output of quantized low level tensor, "
+              "but also saved as float data type.");
+    AddOutput("OutScale", "(Tensor) Current scale");
     AddAttr<int>("bit_length", "(int, default 8)")
         .SetDefault(8)
-        .AddCustomChecker([](const int &bit_length) {
+        .AddCustomChecker([](const int& bit_length) {
           PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                          "'bit_length' should be between 1 and 16.");
         });
-    AddAttr<bool>("is_test", "").SetDefault(false);
     AddComment(R"DOC(
 FakeQuantize operator
-
-quantize_type = abs_max:
-$$scale = max(abs(x))$$
-
-quantize_type = range_abs_max:
-$$scale = max(max(abs(x)), history_abs_max)$$
-
-quantize_type = moving_average_abs_max:
-$$scale = 0.1*scale+0.9*new_abs_max)$$
-
-$$Out = scale*X$$
-)DOC");
-  }
+
+$$scale = max(abs(X))$$
+$$range = 2^{bit_length - 1} - 1$$
+$$Out = round(X/scale * range)$$
+)DOC");
+  }
+};
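A worked example of the formulas above with bit_length = 8, so range = 2^7 - 1 = 127 (illustrative only, not commit code):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float x[4] = {0.5f, -1.0f, 0.25f, 2.0f};
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127
  float scale = 0.f;
  for (float v : x) scale = std::max(scale, std::fabs(v));  // scale = 2.0
  for (float v : x)
    std::printf("%g -> %g\n", v, std::round(v / scale * range));
  // prints: 0.5 -> 32, -1 -> -64, 0.25 -> 16, 2 -> 127
  return 0;
}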
+class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
+ public:
+  FakeQuantizeRangeAbsMaxOp(const std::string& type,
+                            const framework::VariableNameMap& inputs,
+                            const framework::VariableNameMap& outputs,
+                            const framework::AttributeMap& attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("OutScale"),
+        "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
+    if (ctx->HasOutput("OutScales")) {
+      int window_size = ctx->Attrs().Get<int>("window_size");
+      ctx->SetOutputDim("OutScales", {window_size});
+    }
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScale", {1});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
+class FakeQuantizeRangeAbsMaxOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddInput("InScale", "Last scale.");
+    AddInput("Iter", "Global step iteration.").AsDispensable();
+    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
+    AddOutput("OutScale", " Current scale");
+    AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
+    AddAttr<int>("window_size", "(int, default 10000) window range size.")
+        .SetDefault(10000);
+    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
+        .SetDefault(8)
+        .AddCustomChecker([](const int& bit_length) {
+          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+                         "'bit_length' should be between 1 and 16.");
+        });
+    AddAttr<bool>("is_test", "").SetDefault(false);
+    AddComment(R"DOC(
+FakeQuantize operator is used in static quantization.
+
+$$scale = max(max(abs(x)), history_abs_max)$$
+$$range = 2^{bit_length - 1} - 1$$
+$$Out = round(X/scale * range)$$
+)DOC");
+  }
@@ -103,10 +215,16 @@ $$Out = scale*X$$
 }  // namespace paddle

 namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
+                  ops::FakeQuantizeAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
+                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
+REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
+                  ops::FakeQuantizeRangeAbsMaxOpMaker,
                   paddle::framework::EmptyGradOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    fake_quantize,
-    ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
+                       ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include <string>
+#include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/fake_quantize_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -20,7 +21,7 @@ namespace paddle {
 namespace operators {

 template <typename T>
-__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
+__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
@@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
   __syncthreads();

   for (int i = blockDim.x / 2; i > 0; i >>= 1) {
-    if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
+    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
       shared_max_data[tid] = shared_max_data[tid + i];
     }
     __syncthreads();
@@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
   }
 }
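FindAbsMaxKernel is a standard two-pass shared-memory tree reduction: each block reduces its strided slice into shared memory, halving the number of active threads per step, and a second one-block launch reduces the per-block maxima. A host-side sketch of the same two-pass scheme (illustrative, not CUDA):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

float TwoPassAbsMax(const std::vector<float>& in, int block) {
  // First pass: one partial maximum per "block" of the input.
  std::vector<float> partial((in.size() + block - 1) / block, 0.f);
  for (std::size_t i = 0; i < in.size(); ++i) {
    float v = std::fabs(in[i]);
    std::size_t b = i / block;
    if (v > partial[b]) partial[b] = v;
  }
  // Second pass: reduce the per-block maxima to a single value.
  float result = 0.f;
  for (float v : partial) result = std::max(result, v);
  return result;
}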
-float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
-                    int length) {
-  float host_max;
-  int kNumTheads = 1024;
-  int gridDimx = (kNumTheads - 1 + length) / kNumTheads;
-  gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
-  framework::Tensor t;
-  float* device_max = t.mutable_data<float>(framework::make_ddim({gridDimx}),
-                                            platform::CUDAPlace());
-  FindAbsMaxKernel<float><<<gridDimx, kNumTheads, kNumTheads * sizeof(float),
-                            ctx.stream()>>>(length, array, device_max);
-  FindAbsMaxKernel<
-      float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>(
-      gridDimx, device_max, device_max);
-  PADDLE_ENFORCE_EQ(
-      cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
-      cudaSuccess, "cudaMemcpy failed");
-  return host_max;
-}
+template <typename T>
+struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
+                  const int num, T* out) {
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+    grid = (grid > block) ? block : grid;
+
+    framework::Tensor max;
+    T* max_data =
+        max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
+    FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
+        in, num, max_data);
+    FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>(
+        max_data, grid, out);
+  }
+};
+
+template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;

 template <typename T>
-__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
-                                    int* num_saturate, const T min,
-                                    const T max) {
+__global__ void ClipAndQuantKernel(const T* in, const T* scale,
+                                   const int bin_cnt, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;

-  extern __shared__ int shared_count[];
-  shared_count[tid] = 0;
+  T s = scale[0];
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    if (in[i] > max) {
-      out[i] = max;
-      shared_count[tid] += 1;
-    } else if (in[i] < min) {
-      out[i] = min;
-      shared_count[tid] += 1;
-    } else {
-      out[i] = in[i];
-    }
-  }
-  __syncthreads();
-
-  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
-    if (tid < i) {
-      shared_count[tid] += shared_count[tid + i];
-    }
-    __syncthreads();
-  }
-  if (tid == 0) {
-    num_saturate[blockIdx.x] = shared_count[0];
+    T x = in[bid];
+    T v = x > s ? s : x;
+    v = v < -s ? -s : v;
+    v = bin_cnt / s * v;
+    out[bid] = round(v);
   }
 }
 template <typename T>
-__global__ void ReduceKernel(const int n, const T* in, T* out) {
-  int tid = threadIdx.x;
-  extern __shared__ T shared_sum[];
-  if (tid < n) {
-    shared_sum[tid] = in[tid];
-  } else {
-    shared_sum[tid] = T(0);
-  }
-  __syncthreads();
-  // blockDim.x must >= n
-  for (int i = (n + 1) / 2; i > 0; i >>= 1) {
-    if (tid < i) {
-      shared_sum[tid] += shared_sum[tid + i];
-    }
-    __syncthreads();
-  }
-  if (tid == 0) {
-    out[0] = shared_sum[0];
-  }
-}
+__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
+                                            const T* last_scale,
+                                            const int64_t* iter,
+                                            const int window_size,
+                                            T* scale_arr, T* out_scale,
+                                            int* need_find_max,
+                                            int* out_size) {
+  int it = iter[0];
+  int idx = it % window_size;
+  T removed = scale_arr[idx];
+  T cur = cur_scale[0];
+  scale_arr[idx] = cur;
+  T max = last_scale[0];
+  out_scale[0] = max < cur ? cur : max;
+  if (fabs(removed - max) < 1e-6) {
+    need_find_max[0] = 1;
+    out_size[0] = it > window_size ? window_size : it;
+  } else {
+    need_find_max[0] = 0;
+  }
+}

 template <typename T>
-int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
-                     const T* in, T* out, const T min, const T max) {
-  int host_num_saturate;
-  int kNumTheads = 1024;
-  int gridDimx = (n + kNumTheads - 1) / kNumTheads;
-  gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
-  framework::Tensor t;
-  int* device_num_saturate = t.mutable_data<int>(
-      framework::make_ddim({gridDimx}), platform::CUDAPlace());
-  ApplySaturateKernel<
-      T><<<gridDimx, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
-      n, in, out, device_num_saturate, min, max);
-  ReduceKernel<int><<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
-      gridDimx, device_num_saturate, device_num_saturate);
-  PADDLE_ENFORCE_EQ(cudaSuccess,
-                    cudaMemcpy(&host_num_saturate, device_num_saturate,
-                               sizeof(int), cudaMemcpyDeviceToHost),
-                    "cudaMemcpy failed");
-  return host_num_saturate;
-}
+struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& cur_scale,
+                  const framework::Tensor& last_scale,
+                  const framework::Tensor& iter, const int window_size,
+                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
+    auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
+    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
+
+    framework::Tensor need_find_max, out_size;
+    int* find_max = need_find_max.mutable_data<int>(gpu_place);
+    int* out_size_data = out_size.mutable_data<int>(gpu_place);
+
+    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
+        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
+        window_size, scale_arr, out_scale_data, find_max, out_size_data);
+
+    int g_find_max;
+    memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
+                 sizeof(int), 0);
+    if (g_find_max) {
+      int len;
+      memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
+                   sizeof(int), 0);
+      FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
+                                                          out_scale_data);
+    }
+  }
+};
+template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

-template <typename DeviceContext, typename T>
-class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
- public:
-  T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
-                    framework::Tensor* scale_list,
-                    framework::Tensor* out_scale, const T& cur_scale,
-                    int window_size, int current_iter) const {
-    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
-    T remove_tmp = sl[current_iter];
-    sl[current_iter] = cur_scale;
-    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
-    if (max_scale < cur_scale) {
-      max_scale = cur_scale;
-    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
-      int size = (current_iter > window_size) ? window_size : current_iter;
-      max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
-    }
-    return max_scale;
-  }
-
-  T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
-                             framework::Tensor* out_scale,
-                             const T& cur_scale) const {
-    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
-    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
-    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
-    return T(outs[0]);
-  }
-
-  virtual void Compute(const framework::ExecutionContext& context) const {
-    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
-                   "This kernel only runs on GPU device.");
-    auto& device_ctx = context.cuda_device_context();
-    auto* tensor = context.Output<framework::Tensor>("Out");
-    auto* in = context.Input<framework::Tensor>("X");
-    const bool is_test = context.Attr<bool>("is_test");
-    tensor->mutable_data<T>(in->place());
-    context.Output<framework::Tensor>("OutMovingScale")
-        ->mutable_data<T>(
-            context.Input<framework::Tensor>("InMovingScale")->place());
-    auto quantize_type =
-        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
-    if (quantize_type == std::string("range_abs_max")) {
-      context.Output<framework::Tensor>("OutScales")
-          ->mutable_data<T>(
-              context.Input<framework::Tensor>("InScales")->place());
-      context.Output<framework::Tensor>("OutCurrentIter")
-          ->mutable_data<T>(
-              context.Input<framework::Tensor>("InCurrentIter")->place());
-    }
-
-    T scale = T(1);
-    int window_size = context.Attr<int>("window_size");
-    T bin_cnt = (T)((1 << (context.Attr<int>("bit_length") - 1)) - 1);
-    if (quantize_type == std::string("abs_max")) {
-      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
-      scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
-      saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
-
-      auto& device_ctx = context.template device_context<DeviceContext>();
-      auto* scale_list = context.Output<framework::Tensor>("OutScales");
-      math::SetConstant<DeviceContext, T> scalar;
-      scale_list->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, scale_list, static_cast<T>(0));
-      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-      iter->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, iter, static_cast<T>(0));
-    } else if (quantize_type == std::string("range_abs_max")) {
-      auto* moving_scale = const_cast<framework::Tensor*>(
-          context.Input<framework::Tensor>("InMovingScale"));
-      if (is_test) {
-        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
-      } else {
-        auto* it = const_cast<framework::Tensor*>(
-            context.Input<framework::Tensor>("InCurrentIter"));
-        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-        int* last_iter = it->mutable_data<int>(platform::CPUPlace());
-        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
-        auto* scale_list = context.Output<framework::Tensor>("OutScales");
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
-        scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
-                                window_size, current_iter[0]);
-        (*current_iter) = (*last_iter) + 1;
-      }
-    } else if (quantize_type == std::string("moving_average_abs_max")) {
-      auto* moving_scale = const_cast<framework::Tensor*>(
-          context.Input<framework::Tensor>("InMovingScale"));
-      if (is_test) {
-        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
-      } else {
-        scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        scale = FindMovingAverageAbsMmax(
-            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
-      }
-    }
-
-    ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
-                        tensor->mutable_data<T>(in->place()), -scale, scale);
-    scale = bin_cnt / scale;
-
-    auto& dev =
-        *context.template device_context<DeviceContext>().eigen_device();
-    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
-    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
-    eigen_out.device(dev) = (scale * eigen_in).round();
-  }
-};
+template <typename T>
+struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, out_data);
+  }
+};
+
+template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_CUDA_KERNEL(fake_quantize,
-                        paddle::operators::FakeQuantizeCUDAKernel<
-                            paddle::platform::CUDADeviceContext, float>,
-                        paddle::operators::FakeQuantizeCUDAKernel<
-                            paddle::platform::CUDADeviceContext, double>);
+namespace ops = paddle::operators;
+using CUDA = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
+                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
+                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
@@ -17,137 +17,91 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/platform/transform.h"

 namespace paddle {
 namespace operators {

-using platform::Transform;
+template <typename DeviceContext, typename T>
+struct FindAbsMaxFunctor {
+  void operator()(const DeviceContext& ctx, const T* in, const int num,
+                  T* out);
+};

 template <typename DeviceContext, typename T>
-class FakeQuantizeKernel : public framework::OpKernel<T> {
+struct ClipAndFakeQuantFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& scale, const int bin_cnt,
+                  framework::Tensor* out);
+};
+
+template <typename DeviceContext, typename T>
+struct FindRangeAbsMaxFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,
+                  const framework::Tensor& last_scale,
+                  const framework::Tensor& iter, const int window_size,
+                  framework::Tensor* scales_arr, framework::Tensor* out_scale);
+};
+template <typename DeviceContext, typename T>
+class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
  public:
-  T FindAbsMax(framework::Tensor* in, int n) const {
-    T* p = in->mutable_data<T>(platform::CPUPlace());
-    T abs_max = (T)0.00000001;
-    for (int i = 0; i < n; i++) {
-      T tmp = fabs(p[i]);
-      if (tmp > abs_max) abs_max = tmp;
-    }
-    return T(abs_max);
-  }
-
-  T FindRangeAbsMax(framework::Tensor* scale_list,
-                    framework::Tensor* out_scale, const T& cur_scale,
-                    int window_size, int current_iter) const {
-    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
-    T remove_tmp = sl[current_iter];
-    sl[current_iter] = cur_scale;
-    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
-    if (max_scale < cur_scale) {
-      max_scale = cur_scale;
-    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
-      int size = (current_iter > window_size) ? window_size : current_iter;
-      max_scale = T(FindAbsMax(scale_list, size));
-    }
-    return max_scale;
-  }
-
-  T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
-                             framework::Tensor* out_scale,
-                             const T& cur_scale) const {
-    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
-    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
-    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
-    return T(outs[0]);
-  }
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    T* out_s = out_scale->mutable_data<T>(context.GetPlace());
+
+    int bit_length = context.Attr<int>("bit_length");
+    int bin_cnt = std::pow(2, bit_length - 1) - 1;
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    const T* in_data = in->data<T>();
+    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
+                                                bin_cnt, out);
+  }
+};
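Composed end to end, fake_quantize_abs_max reduces to the following reference (illustrative sketch; assumes the input is not all zeros, so the scale is positive):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> FakeQuantizeAbsMaxRef(const std::vector<float>& x,
                                         int bit_length, float* scale) {
  const int bin_cnt = (1 << (bit_length - 1)) - 1;
  float s = 0.f;
  for (float v : x) s = std::max(s, std::fabs(v));  // FindAbsMaxFunctor
  *scale = s;
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {      // ClipAndFakeQuantFunctor
    float c = std::min(std::max(x[i], -s), s);
    out[i] = std::round(bin_cnt / s * c);
  }
  return out;
}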
-  virtual void Compute(const framework::ExecutionContext& context) const {
-    auto* tensor = context.Output<framework::Tensor>("Out");
+template <typename DeviceContext, typename T>
+class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
-    const bool is_test = context.Attr<bool>("is_test");
-    tensor->mutable_data<T>(in->place());
-
-    auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
-    oms_tensor->mutable_data<T>(in->place());
-
-    auto quantize_type =
-        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
-    if (quantize_type == std::string("range_abs_max")) {
-      auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
-      oss_tensor->mutable_data<T>(
-          context.Input<framework::Tensor>("InScales")->place());
-      auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
-      oci_tensor->mutable_data<T>(
-          context.Input<framework::Tensor>("InCurrentIter")->place());
-    }
-
-    T scale = static_cast<T>(1);
-    int window_size = context.Attr<int>("window_size");
+    auto* in_scale = context.Input<framework::Tensor>("InScale");
+
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    bool is_test = context.Attr<bool>("is_test");
     int bit_length = context.Attr<int>("bit_length");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
+    auto& dev_ctx = context.template device_context<DeviceContext>();

-    auto& dev =
-        *context.template device_context<DeviceContext>().eigen_device();
-    auto raw_in = framework::EigenVector<T>::Flatten(*in);
-    if (quantize_type == std::string("abs_max")) {
-      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
-      auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
-      scale_out.device(dev) = raw_in.abs().maximum();
-      scale = scale_out(0);
-
-      auto& device_ctx = context.template device_context<DeviceContext>();
-      auto* scale_list = context.Output<framework::Tensor>("OutScales");
-      math::SetConstant<DeviceContext, T> scalar;
-      scale_list->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, scale_list, static_cast<T>(0));
-      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-      iter->mutable_data<T>(context.GetPlace());
-      scalar(device_ctx, iter, static_cast<T>(0));
-    } else if (quantize_type == std::string("range_abs_max")) {
-      auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
-      if (is_test) {
-        scale = moving_scale->data<T>()[0];
-      } else {
-        auto* it = context.Input<framework::Tensor>("InCurrentIter");
-        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
-        const int* last_iter = it->data<int>();
-        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
-        auto* scale_list = context.Output<framework::Tensor>("OutScales");
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
-        scale_out.device(dev) = raw_in.abs().maximum();
-        scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
-        scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
-                                current_iter[0]);
-        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
-        (*current_iter) = (*last_iter) + 1;
-      }
-    } else if (quantize_type == std::string("moving_average_abs_max")) {
-      auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
-      if (is_test) {
-        scale = moving_scale->data<T>()[0];
-      } else {
-        auto* saving_scale =
-            context.Output<framework::Tensor>("OutMovingScale");
-        auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
-        scale_out.device(dev) = raw_in.abs().maximum();
-        scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
-        scale = FindMovingAverageAbsMmax(
-            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
-        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
-      }
-    }
-
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), in->data<T>(),
-          in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
-          ClipFunctor<T>(-scale, scale));
-    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
-    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
-    eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
+    // testing
+    if (is_test) {
+      ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
+                                                  bin_cnt, out);
+      return;
+    }
+
+    // training
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    auto* out_scales = context.Output<framework::Tensor>("OutScales");
+    auto* iter = context.Input<framework::Tensor>("Iter");
+
+    int window_size = context.Attr<int>("window_size");
+    out_scale->mutable_data<T>(context.GetPlace());
+
+    framework::Tensor cur_scale;
+    T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
+    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
+                                          cur_scale_data);
+    FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx, cur_scale, *in_scale,
+                                               *iter, window_size, out_scales,
+                                               out_scale);
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
+                                                bin_cnt, out);
   }
 };
......
@@ -53,7 +53,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
         static_cast<T>(context.Attr<float>("min")),
         static_cast<T>(context.Attr<float>("max")));

-    std::vector<T> ids(batch_size);
+    std::vector<int64_t> ids(batch_size);
     for (int i = 0; i < batch_size; ++i) {
       T r = dist(engine);
       int idx = width - 1;
@@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
           break;
         }
       }
-      ids[i] = ins_vector[idx];
+      ids[i] = int64_t(idx);
     }

     std::vector<int64_t> out_dim;
......
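The comparison inside the loop is elided by the diff view; in the usual formulation each row of X is a normalized probability distribution, and the uniform draw r is walked down by successive probabilities until it turns negative. A standalone sketch under that assumption (hypothetical helper name) — the change above makes the op return the chosen index as int64 rather than the probability value at that index:

#include <cstdint>
#include <random>
#include <vector>

// Multinomial draw from one row of (normalized) probabilities.
int64_t SampleId(const std::vector<float>& probs, std::mt19937* engine) {
  std::uniform_real_distribution<float> dist(0.f, 1.f);
  float r = dist(*engine);
  int64_t idx = static_cast<int64_t>(probs.size()) - 1;  // fallback: last bucket
  for (std::size_t j = 0; j < probs.size(); ++j) {
    r -= probs[j];
    if (r < 0.f) {  // first bucket whose cumulative mass exceeds the draw
      idx = static_cast<int64_t>(j);
      break;
    }
  }
  return idx;
}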
@@ -98,10 +98,9 @@ class Inferencer(object):
             raise ValueError(
                 "inputs should be a map of {'input_name': input_var}")

-        with executor.scope_guard(self.scope):
-            results = self.exe.run(self.inference_program,
-                                   feed=inputs,
-                                   fetch_list=[self.predict_var],
+        with self._prog_and_scope_guard():
+            results = self.exe.run(feed=inputs,
+                                   fetch_list=[self.predict_var.name],
                                    return_numpy=return_numpy)

         return results
......
@@ -16,7 +16,9 @@ from __future__ import print_function

 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import numpy
+import os

 import cifar10_small_test_set
@@ -89,7 +91,7 @@ def optimizer_func():
     return fluid.optimizer.Adam(learning_rate=0.001)


-def train(use_cuda, train_program, params_dirname):
+def train(use_cuda, train_program, parallel, params_dirname):
     BATCH_SIZE = 128
     EPOCH_NUM = 1
@@ -116,7 +118,10 @@ def train(use_cuda, train_program, params_dirname):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

     trainer = fluid.Trainer(
-        train_func=train_program, optimizer_func=optimizer_func, place=place)
+        train_func=train_program,
+        optimizer_func=optimizer_func,
+        place=place,
+        parallel=parallel)

     trainer.train(
         reader=train_reader,
@@ -125,10 +130,13 @@ def train(use_cuda, train_program, params_dirname):
         feed_order=['pixel', 'label'])


-def infer(use_cuda, inference_program, params_dirname=None):
+def infer(use_cuda, inference_program, parallel, params_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     inferencer = fluid.Inferencer(
-        infer_func=inference_program, param_path=params_dirname, place=place)
+        infer_func=inference_program,
+        param_path=params_dirname,
+        place=place,
+        parallel=parallel)

     # The input's dimension of conv should be 4-D or 5-D.
     # Use normalized image pixels as input data, which should be in the range
@@ -139,22 +147,34 @@ def infer(use_cuda, inference_program, params_dirname=None):
     print("infer results: ", results)


-def main(use_cuda):
+def main(use_cuda, parallel):
     if use_cuda and not fluid.core.is_compiled_with_cuda():
         return

     save_path = "image_classification_resnet.inference.model"

+    os.environ['CPU_NUM'] = str(4)
     train(
         use_cuda=use_cuda,
         train_program=train_network,
-        params_dirname=save_path)
+        params_dirname=save_path,
+        parallel=parallel)

+    # FIXME(zcd): in the inference stage, the number of
+    # input data is one, it is not appropriate to use parallel.
+    if parallel and use_cuda:
+        return
+
+    os.environ['CPU_NUM'] = str(1)
     infer(
         use_cuda=use_cuda,
         inference_program=inference_network,
-        params_dirname=save_path)
+        params_dirname=save_path,
+        parallel=parallel)


 if __name__ == '__main__':
     for use_cuda in (False, True):
-        main(use_cuda=use_cuda)
+        for parallel in (False, True):
+            if use_cuda and not core.is_compiled_with_cuda():
+                continue
+            main(use_cuda=use_cuda, parallel=parallel)
@@ -16,7 +16,9 @@ from __future__ import print_function

 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import numpy
+import os

 import cifar10_small_test_set
@@ -68,7 +70,7 @@ def optimizer_func():
     return fluid.optimizer.Adam(learning_rate=0.001)


-def train(use_cuda, train_program, params_dirname):
+def train(use_cuda, train_program, parallel, params_dirname):
     BATCH_SIZE = 128
     train_reader = paddle.batch(
         paddle.reader.shuffle(
@@ -93,7 +95,10 @@ def train(use_cuda, train_program, params_dirname):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

     trainer = fluid.Trainer(
-        train_func=train_program, place=place, optimizer_func=optimizer_func)
+        train_func=train_program,
+        place=place,
+        optimizer_func=optimizer_func,
+        parallel=parallel)

     trainer.train(
         reader=train_reader,
@@ -102,10 +107,13 @@ def train(use_cuda, train_program, params_dirname):
         feed_order=['pixel', 'label'])


-def infer(use_cuda, inference_program, params_dirname=None):
+def infer(use_cuda, inference_program, parallel, params_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     inferencer = fluid.Inferencer(
-        infer_func=inference_program, param_path=params_dirname, place=place)
+        infer_func=inference_program,
+        param_path=params_dirname,
+        place=place,
+        parallel=parallel)

     # The input's dimension of conv should be 4-D or 5-D.
     # Use normalized image pixels as input data, which should be in the range
@@ -116,22 +124,31 @@ def infer(use_cuda, inference_program, params_dirname=None):
     print("infer results: ", results)


-def main(use_cuda):
-    if use_cuda and not fluid.core.is_compiled_with_cuda():
-        return
+def main(use_cuda, parallel):
     save_path = "image_classification_vgg.inference.model"

+    os.environ['CPU_NUM'] = str(4)
     train(
         use_cuda=use_cuda,
         train_program=train_network,
-        params_dirname=save_path)
+        params_dirname=save_path,
+        parallel=parallel)

+    # FIXME(zcd): in the inference stage, the number of
+    # input data is one, it is not appropriate to use parallel.
+    if parallel and use_cuda:
+        return
+
+    os.environ['CPU_NUM'] = str(1)
     infer(
         use_cuda=use_cuda,
         inference_program=inference_network,
-        params_dirname=save_path)
+        params_dirname=save_path,
+        parallel=parallel)


 if __name__ == '__main__':
     for use_cuda in (False, True):
-        main(use_cuda=use_cuda)
+        for parallel in (False, True):
+            if use_cuda and not core.is_compiled_with_cuda():
+                continue
+            main(use_cuda=use_cuda, parallel=parallel)
@@ -64,14 +64,14 @@ def optimizer_func():
     return fluid.optimizer.Adam(learning_rate=0.001)


-def train(use_cuda, train_program, params_dirname):
+def train(use_cuda, train_program, parallel, params_dirname):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

     trainer = fluid.Trainer(
         train_func=train_program,
         place=place,
         optimizer_func=optimizer_func,
-        parallel=True)
+        parallel=parallel)

     def event_handler(event):
         if isinstance(event, fluid.EndEpochEvent):
@@ -108,11 +108,14 @@ def train(use_cuda, train_program, params_dirname):
         feed_order=['img', 'label'])


-def infer(use_cuda, inference_program, params_dirname=None):
+def infer(use_cuda, inference_program, parallel, params_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

     inferencer = fluid.Inferencer(
-        infer_func=inference_program, param_path=params_dirname, place=place)
+        infer_func=inference_program,
+        param_path=params_dirname,
+        place=place,
+        parallel=parallel)

     batch_size = 1
     tensor_img = numpy.random.uniform(-1.0, 1.0,
@@ -123,20 +126,32 @@ def infer(use_cuda, inference_program, params_dirname=None):
     print("infer results: ", results[0])


-def main(use_cuda):
+def main(use_cuda, parallel):
     params_dirname = "recognize_digits_conv.inference.model"

     # call train() with is_local argument to run distributed train
+    os.environ['CPU_NUM'] = str(4)
     train(
         use_cuda=use_cuda,
         train_program=train_program,
-        params_dirname=params_dirname)
+        params_dirname=params_dirname,
+        parallel=parallel)

+    # FIXME(zcd): in the inference stage, the number of
+    # input data is one, it is not appropriate to use parallel.
+    if parallel and use_cuda:
+        return
+
+    os.environ['CPU_NUM'] = str(1)
     infer(
         use_cuda=use_cuda,
         inference_program=inference_program,
-        params_dirname=params_dirname)
+        params_dirname=params_dirname,
+        parallel=parallel)


 if __name__ == '__main__':
-    # for use_cuda in (False, True):
-    main(use_cuda=core.is_compiled_with_cuda())
+    for use_cuda in (False, True):
+        for parallel in (False, True):
+            if use_cuda and not core.is_compiled_with_cuda():
+                continue
+            main(use_cuda=use_cuda, parallel=parallel)
@@ -16,6 +16,7 @@ from __future__ import print_function

 import argparse
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import paddle
 import sys
 import numpy
@@ -50,11 +51,14 @@ def optimizer_func():
     return fluid.optimizer.Adam(learning_rate=0.001)


-def train(use_cuda, train_program, params_dirname):
+def train(use_cuda, train_program, params_dirname, parallel):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

     trainer = fluid.Trainer(
-        train_func=train_program, place=place, optimizer_func=optimizer_func)
+        train_func=train_program,
+        place=place,
+        optimizer_func=optimizer_func,
+        parallel=parallel)

     def event_handler(event):
         if isinstance(event, fluid.EndEpochEvent):
@@ -86,11 +90,14 @@ def train(use_cuda, train_program, params_dirname):
         feed_order=['img', 'label'])


-def infer(use_cuda, inference_program, params_dirname=None):
+def infer(use_cuda, inference_program, parallel, params_dirname=None):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

     inferencer = fluid.Inferencer(
-        infer_func=inference_program, param_path=params_dirname, place=place)
+        infer_func=inference_program,
+        param_path=params_dirname,
+        place=place,
+        parallel=parallel)

     batch_size = 1
     tensor_img = numpy.random.uniform(-1.0, 1.0,
@@ -101,20 +108,32 @@ def infer(use_cuda, inference_program, params_dirname=None):
     print("infer results: ", results[0])


-def main(use_cuda):
+def main(use_cuda, parallel):
     params_dirname = "recognize_digits_mlp.inference.model"

     # call train() with is_local argument to run distributed train
+    os.environ['CPU_NUM'] = str(4)
     train(
         use_cuda=use_cuda,
         train_program=train_program,
-        params_dirname=params_dirname)
+        params_dirname=params_dirname,
+        parallel=parallel)

+    # FIXME(zcd): in the inference stage, the number of
+    # input data is one, it is not appropriate to use parallel.
+    if parallel and use_cuda:
+        return
+
+    os.environ['CPU_NUM'] = str(1)
     infer(
         use_cuda=use_cuda,
         inference_program=inference_program,
-        params_dirname=params_dirname)
+        params_dirname=params_dirname,
+        parallel=parallel)


 if __name__ == '__main__':
-    # for use_cuda in (False, True):
-    main(use_cuda=False)
+    for use_cuda in (False, True):
+        for parallel in (False, True):
+            if use_cuda and not core.is_compiled_with_cuda():
+                continue
+            main(use_cuda=use_cuda, parallel=parallel)
@@ -21,28 +21,41 @@ from op_test import OpTest


 class TestFakeQuantizeOp(OpTest):
     def setUp(self):
-        self.op_type = "fake_quantize"
+        self.op_type = "fake_quantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        self.outputs = {
+            'Out': np.round(self.inputs['X'] / scale * (
+                (1 << (self.attrs['bit_length'] - 1)) - 1)),
+            'OutScale': np.array(scale).astype("float32"),
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+# Renamed from a duplicate "TestFakeQuantizeOp" definition, which would
+# otherwise shadow the abs_max test above.
+class TestFakeQuantizeRangeAbsMaxOp(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_range_abs_max"
         self.attrs = {
-            'bit_length': 8,
-            'quantize_type': 'abs_max',
-            'window_size': 10000
+            'bit_length': int(5),
+            'window_size': int(1),
+            'is_test': False
         }
         self.inputs = {
-            'X': np.random.random((10, 10)).astype("float32"),
-            'InScales': np.zeros(self.attrs['window_size']).astype("float32"),
-            'InCurrentIter': np.zeros(1).astype("float32"),
-            'InMovingScale': np.zeros(1).astype("float32")
+            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'Iter': np.zeros(1).astype("int64"),
+            'InScale': np.zeros(1).astype("float32")
         }
-        self.scale = {
-            'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
-        }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
+        out_scales[0] = scale
         self.outputs = {
-            'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * (
+            'Out': np.round(self.inputs['X'] / scale * (
                 (1 << (self.attrs['bit_length'] - 1)) - 1)),
-            'OutScales': np.zeros(self.attrs['window_size']).astype("float32"),
-            'OutMovingScale':
-            np.array([self.scale['abs_max']]).astype("float32"),
-            'OutCurrentIter': np.zeros(1).astype("float32")
+            'OutScale': scale,
+            'OutScales': out_scales,
         }

     def test_check_output(self):
......
@@ -25,9 +25,9 @@ class TestSamplingIdOp(OpTest):
         self.op_type = "sampling_id"
         self.use_mkldnn = False
         self.init_kernel_type()
-        self.X = np.random.random((8, 4)).astype('float32')
+        self.X = np.random.random((100, 10)).astype('float32')
         self.inputs = {"X": self.X}
-        self.Y = np.random.random(8).astype('float32')
+        self.Y = np.random.random(100).astype('int64')
         self.outputs = {'Out': self.Y}
         self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1}
@@ -36,6 +36,16 @@ class TestSamplingIdOp(OpTest):
         y1 = self.out
         self.check_output_customized(self.verify_output)
         y2 = self.out
+
+        # check dtype
+        assert y1.dtype == np.int64
+        assert y2.dtype == np.int64
+
+        # check output is index ids of inputs
+        inputs_ids = np.arange(self.X.shape[1])
+        assert np.isin(y1, inputs_ids).all()
+        assert np.isin(y2, inputs_ids).all()
+
         self.assertTrue(np.array_equal(y1, y2))
         self.assertEqual(len(y1), len(self.Y))
......