提交 93c034ee 编写于 作者: T tensor-tang

Merge remote-tracking branch 'ups/develop' into optimize/op/fusion_lstm

......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sys/time.h>
#include <algorithm>
#include <map>
#include <set>
......@@ -23,32 +22,14 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/timer.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool(profile, false, "Turn on profiler for fluid");
namespace paddle {
namespace {
// Timer for timer
class Timer {
public:
double start;
double startu;
void tic() {
struct timeval tp;
gettimeofday(&tp, NULL);
start = tp.tv_sec;
startu = tp.tv_usec;
}
double toc() {
struct timeval tp;
gettimeofday(&tp, NULL);
double used_time_ms =
(tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0;
return used_time_ms;
}
};
using paddle::inference::Timer;
template <class T>
std::string num2str(T a) {
......@@ -80,7 +61,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
bool NativePaddlePredictor::Init(
std::shared_ptr<framework::Scope> parent_scope) {
VLOG(3) << "Predictor::init()";
#if !defined(_WIN32)
if (FLAGS_profile) {
LOG(WARNING) << "Profiler is actived, might affect the performance";
LOG(INFO) << "You can turn off by set gflags '-profile false'";
......@@ -89,6 +70,7 @@ bool NativePaddlePredictor::Init(
: platform::ProfilerState::kCPU;
platform::EnableProfiler(tracking_device);
}
#endif
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
......@@ -133,10 +115,12 @@ bool NativePaddlePredictor::Init(
}
NativePaddlePredictor::~NativePaddlePredictor() {
#if !defined(_WIN32)
if (FLAGS_profile) {
platform::DisableProfiler(platform::EventSortingKey::kTotal,
"./profile.log");
}
#endif
if (sub_scope_) {
scope_->DeleteScope(sub_scope_);
}
......
......@@ -3,6 +3,11 @@ cmake_minimum_required(VERSION 3.0)
project(cpp_inference_demo CXX C)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
if (WIN32)
set(CMAKE_STATIC_LIBRARY_PREFIX "lib")
else()
set(CMAKE_STATIC_LIBRARY_PREFIX "")
endif()
if(NOT DEFINED PADDLE_LIB)
message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
......@@ -32,44 +37,56 @@ endif(NOT WIN32)
include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3")
if (NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/snappy/lib")
link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
endif(NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
link_directories("${PADDLE_LIB}/paddle/fluid/inference")
add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
if(WITH_MKL)
include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5.so)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
if(EXISTS ${MKLDNN_PATH})
include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
endif()
else()
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
endif()
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
if(WITH_STATIC_LIB)
set(DEPS
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.a)
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so)
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
if (NOT WIN32)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf snappystream snappy z
${EXTERNAL_LIB})
else()
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf
${EXTERNAL_LIB})
endif(NOT WIN32)
if(WITH_GPU)
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart.so)
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
target_link_libraries(${DEMO_NAME} ${DEPS})
......@@ -16,35 +16,15 @@
#include <sys/time.h>
#include <algorithm>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/timer.h"
namespace paddle {
namespace inference {
// Timer for timer
class Timer {
public:
double start;
double startu;
void tic() {
struct timeval tp;
gettimeofday(&tp, NULL);
start = tp.tv_sec;
startu = tp.tv_usec;
}
double toc() {
struct timeval tp;
gettimeofday(&tp, NULL);
double used_time_ms =
(tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0;
return used_time_ms;
}
};
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono> // NOLINT
namespace paddle {
namespace inference {
// Timer for timer
class Timer {
public:
std::chrono::high_resolution_clock::time_point start;
std::chrono::high_resolution_clock::time_point startu;
void tic() { start = std::chrono::high_resolution_clock::now(); }
double toc() {
startu = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span =
std::chrono::duration_cast<std::chrono::duration<double>>(startu -
start);
double used_time_ms = static_cast<double>(time_span.count()) * 1000.0;
return used_time_ms;
}
};
} // namespace inference
} // namespace paddle
......@@ -178,6 +178,8 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "fake_quantize")
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
......@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op)
op_library(fake_quantize_op DEPS memory)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
......@@ -14,86 +14,198 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
class FakeQuantizeOp : public framework::OperatorWithKernel {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVectorArrayMap =
Eigen::TensorMap<Eigen::Tensor<T, 1, MajorType, IndexType>>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using ConstEigenVectorArrayMap =
Eigen::TensorMap<const Eigen::Tensor<T, 1, MajorType, IndexType>>;
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* in,
const int num, T* out) {
Eigen::DSizes<Eigen::DenseIndex, 1> idim(num);
Eigen::DSizes<Eigen::DenseIndex, 1> odim(1);
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>> in_e(in, idim);
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> out_e(out, odim);
out_e = in_e.abs().maximum();
}
};
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
T s = scale.data<T>()[0];
platform::Transform<platform::CPUDeviceContext> trans;
trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
auto in_e = framework::EigenVector<T>::Flatten(in);
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round();
}
};
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& cur_scale,
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) {
T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
int64_t it = iter.data<int64_t>()[0];
int idx = it % window_size;
T removed = scale_arr[idx];
T cur = cur_scale.data<T>()[0];
scale_arr[idx] = cur;
T max = last_scale.data<T>()[0];
if (max < cur) {
max = cur;
} else if (fabs(removed - max) < 1e-6) {
int size = (it > window_size) ? window_size : it;
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
&max);
}
out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
}
};
template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
FakeQuantizeAbsMaxOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
"OutMovingScale(Out) of FakeQuantizeOp should not be null");
// if (ctx->HasInput("InMovingScale")) {
ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
"OutScales(Out) of FakeQuantizeOp should not be null");
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
"Output(Scale) of FakeQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
.AsDispensable();
AddInput("InMovingScale", "Last scale, used in static quantization.")
.AsDispensable();
AddInput("InCurrentIter",
"Last iteration number, used in static quantization.")
.AsDispensable();
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
AddOutput("OutScales",
"(Tensor) scale buffer, used in static quantization.")
.AsDispensable();
AddOutput("OutMovingScale", " Current scale");
AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
AddAttr<std::string>("quantize_type",
"(string, default abs_max)"
"The scaling tpe of the quantize operator.")
.SetDefault("abs_max");
AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
AddInput("X", "(Tensor) Input is float data type.");
AddOutput("Out",
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScale", "(Tensor) Current scale");
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int &bit_length) {
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
AddAttr<bool>("is_test", "").SetDefault(false);
AddComment(R"DOC(
FakeQuantize operator
quantize_type = abs_max:
$$scale = max(abs(X))$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
$$scale = max(abs(x))$$
)DOC");
}
};
quantize_type = range_abs_max:
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeRangeAbsMaxOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
$$scale = max(max(abs(x)), history_abs_max)$$
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
if (ctx->HasOutput("OutScales")) {
int window_size = ctx->Attrs().Get<int>("window_size");
ctx->SetOutputDim("OutScales", {window_size});
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
quantize_type = moving_average_abs_max:
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
$$scale = 0.1*scale+0.9*new_abs_max)$$
class FakeQuantizeRangeAbsMaxOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input is float data type.");
AddInput("InScale", "Last scale.");
AddInput("Iter", "Global step iteration.").AsDispensable();
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
AddOutput("OutScale", " Current scale");
AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
AddAttr<int>("window_size", "(int, default 10000) window range size.")
.SetDefault(10000);
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
AddAttr<bool>("is_test", "").SetDefault(false);
AddComment(R"DOC(
FakeQuantize operator is used in static quantization.
$$Out = scale*X$$
$$scale = max(max(abs(x)), history_abs_max)$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC");
}
......@@ -103,10 +215,16 @@ $$Out = scale*X$$
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
ops::FakeQuantizeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
ops::FakeQuantizeRangeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fake_quantize,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -20,7 +21,7 @@ namespace paddle {
namespace operators {
template <typename T>
__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
......@@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
shared_max_data[tid] = shared_max_data[tid + i];
}
__syncthreads();
......@@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
}
}
float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
int length) {
float host_max;
int kNumTheads = 1024;
int gridDimx = (kNumTheads - 1 + length) / kNumTheads;
gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
framework::Tensor t;
float* device_max = t.mutable_data<float>(framework::make_ddim({gridDimx}),
platform::CUDAPlace());
FindAbsMaxKernel<float><<<gridDimx, kNumTheads, kNumTheads * sizeof(float),
ctx.stream()>>>(length, array, device_max);
FindAbsMaxKernel<
float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>(
gridDimx, device_max, device_max);
PADDLE_ENFORCE_EQ(
cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
cudaSuccess, "cudaMemcpy failed");
return host_max;
}
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* in,
const int num, T* out) {
int block = 1024;
int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
framework::Tensor max;
T* max_data =
max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
in, num, max_data);
FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>(
max_data, grid, out);
}
};
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
int* num_saturate, const T min,
const T max) {
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
extern __shared__ int shared_count[];
shared_count[tid] = 0;
T s = scale[0];
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
if (in[i] > max) {
out[i] = max;
shared_count[tid] += 1;
} else if (in[i] < min) {
out[i] = min;
shared_count[tid] += 1;
} else {
out[i] = in[i];
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i) {
shared_count[tid] += shared_count[tid + i];
}
__syncthreads();
}
if (tid == 0) {
num_saturate[blockIdx.x] = shared_count[0];
T x = in[bid];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out[bid] = round(v);
}
}
template <typename T>
__global__ void ReduceKernel(const int n, const T* in, T* out) {
int tid = threadIdx.x;
extern __shared__ T shared_sum[];
if (tid < n) {
shared_sum[tid] = in[tid];
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
const T* last_scale,
const int64_t* iter,
const int window_size, T* scale_arr,
T* out_scale, int* need_find_max,
int* out_size) {
int it = iter[0];
int idx = it % window_size;
T removed = scale_arr[idx];
T cur = cur_scale[0];
scale_arr[idx] = cur;
T max = last_scale[0];
out_scale[0] = max < cur ? cur : max;
if (fabs(removed - max) < 1e-6) {
need_find_max[0] = 1;
out_size[0] = it > window_size ? window_size : it;
} else {
shared_sum[tid] = T(0);
}
__syncthreads();
// blockDim.x must >= n
for (int i = (n + 1) / 2; i > 0; i >>= 1) {
if (tid < i) {
shared_sum[tid] += shared_sum[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[0] = shared_sum[0];
need_find_max[0] = 0;
}
}
template <typename T>
int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
const T* in, T* out, const T min, const T max) {
int host_num_saturate;
int kNumTheads = 1024;
int gridDimx = (n + kNumTheads - 1) / kNumTheads;
gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
framework::Tensor t;
int* device_num_saturate = t.mutable_data<int>(
framework::make_ddim({gridDimx}), platform::CUDAPlace());
ApplySaturateKernel<
T><<<gridDimx, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
n, in, out, device_num_saturate, min, max);
ReduceKernel<int><<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
gridDimx, device_num_saturate, device_num_saturate);
PADDLE_ENFORCE_EQ(cudaSuccess,
cudaMemcpy(&host_num_saturate, device_num_saturate,
sizeof(int), cudaMemcpyDeviceToHost),
"cudaMemcpy failed");
return host_num_saturate;
}
template <typename DeviceContext, typename T>
class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
public:
T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& cur_scale,
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) {
auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
framework::Tensor need_find_max, out_size;
int* find_max = need_find_max.mutable_data<int>(gpu_place);
int* out_size_data = out_size.mutable_data<int>(gpu_place);
FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
window_size, scale_arr, out_scale_data, find_max, out_size_data);
int g_find_max;
memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
sizeof(int), 0);
if (g_find_max) {
int len;
memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
sizeof(int), 0);
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
out_scale_data);
}
return max_scale;
}
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
}
};
virtual void Compute(const framework::ExecutionContext& context) const {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto& device_ctx = context.cuda_device_context();
auto* tensor = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
context.Output<framework::Tensor>("OutMovingScale")
->mutable_data<T>(
context.Input<framework::Tensor>("InMovingScale")->place());
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
context.Output<framework::Tensor>("OutScales")
->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
context.Output<framework::Tensor>("OutCurrentIter")
->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
T scale = T(1);
int window_size = context.Attr<int>("window_size");
T bin_cnt = (T)((1 << (context.Attr<int>("bit_length") - 1)) - 1);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InMovingScale"));
if (is_test) {
scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
} else {
auto* it = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InCurrentIter"));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
int* last_iter = it->mutable_data<int>(platform::CPUPlace());
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
window_size, current_iter[0]);
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InMovingScale"));
if (is_test) {
scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
} else {
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
}
}
ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
tensor->mutable_data<T>(in->place()), -scale, scale);
scale = bin_cnt / scale;
template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (scale * eigen_in).round();
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(fake_quantize,
paddle::operators::FakeQuantizeCUDAKernel<
paddle::platform::CUDADeviceContext, float>,
paddle::operators::FakeQuantizeCUDAKernel<
paddle::platform::CUDADeviceContext, double>);
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
......@@ -17,137 +17,91 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
using platform::Transform;
template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num, T* out);
};
template <typename DeviceContext, typename T>
class FakeQuantizeKernel : public framework::OpKernel<T> {
struct ClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale);
};
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
public:
T FindAbsMax(framework::Tensor* in, int n) const {
T* p = in->mutable_data<T>(platform::CPUPlace());
T abs_max = (T)0.00000001;
for (int i = 0; i < n; i++) {
T tmp = fabs(p[i]);
if (tmp > abs_max) abs_max = tmp;
}
return T(abs_max);
}
T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMax(scale_list, size));
}
return max_scale;
}
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>();
const T* in_data = in->data<T>();
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out);
}
};
virtual void Compute(const framework::ExecutionContext& context) const {
auto* tensor = context.Output<framework::Tensor>("Out");
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
oms_tensor->mutable_data<T>(in->place());
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
oss_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
oci_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
auto* in_scale = context.Input<framework::Tensor>("InScale");
T scale = static_cast<T>(1);
int window_size = context.Attr<int>("window_size");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto raw_in = framework::EigenVector<T>::Flatten(*in);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = scale_out(0);
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* it = context.Input<framework::Tensor>("InCurrentIter");
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
const int* last_iter = it->data<int>();
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
current_iter[0]);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
}
// testing
if (is_test) {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
bin_cnt, out);
return;
}
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), in->data<T>(),
in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
ClipFunctor<T>(-scale, scale));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
// training
auto* out_scale = context.Output<framework::Tensor>("OutScale");
auto* out_scales = context.Output<framework::Tensor>("OutScales");
auto* iter = context.Input<framework::Tensor>("Iter");
int window_size = context.Attr<int>("window_size");
out_scale->mutable_data<T>(context.GetPlace());
framework::Tensor cur_scale;
T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
cur_scale_data);
FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx, cur_scale, *in_scale,
*iter, window_size, out_scales,
out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out);
}
};
......
......@@ -53,7 +53,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
static_cast<T>(context.Attr<float>("min")),
static_cast<T>(context.Attr<float>("max")));
std::vector<T> ids(batch_size);
std::vector<int64_t> ids(batch_size);
for (int i = 0; i < batch_size; ++i) {
T r = dist(engine);
int idx = width - 1;
......@@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
break;
}
}
ids[i] = ins_vector[idx];
ids[i] = int64_t(idx);
}
std::vector<int64_t> out_dim;
......
......@@ -98,10 +98,9 @@ class Inferencer(object):
raise ValueError(
"inputs should be a map of {'input_name': input_var}")
with executor.scope_guard(self.scope):
results = self.exe.run(self.inference_program,
feed=inputs,
fetch_list=[self.predict_var],
with self._prog_and_scope_guard():
results = self.exe.run(feed=inputs,
fetch_list=[self.predict_var.name],
return_numpy=return_numpy)
return results
......
......@@ -16,7 +16,9 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy
import os
import cifar10_small_test_set
......@@ -89,7 +91,7 @@ def optimizer_func():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(use_cuda, train_program, params_dirname):
def train(use_cuda, train_program, parallel, params_dirname):
BATCH_SIZE = 128
EPOCH_NUM = 1
......@@ -116,7 +118,10 @@ def train(use_cuda, train_program, params_dirname):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = fluid.Trainer(
train_func=train_program, optimizer_func=optimizer_func, place=place)
train_func=train_program,
optimizer_func=optimizer_func,
place=place,
parallel=parallel)
trainer.train(
reader=train_reader,
......@@ -125,10 +130,13 @@ def train(use_cuda, train_program, params_dirname):
feed_order=['pixel', 'label'])
def infer(use_cuda, inference_program, params_dirname=None):
def infer(use_cuda, inference_program, parallel, params_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
infer_func=inference_program,
param_path=params_dirname,
place=place,
parallel=parallel)
# The input's dimension of conv should be 4-D or 5-D.
# Use normilized image pixels as input data, which should be in the range
......@@ -139,22 +147,34 @@ def infer(use_cuda, inference_program, params_dirname=None):
print("infer results: ", results)
def main(use_cuda):
def main(use_cuda, parallel):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
save_path = "image_classification_resnet.inference.model"
os.environ['CPU_NUM'] = str(4)
train(
use_cuda=use_cuda,
train_program=train_network,
params_dirname=save_path)
params_dirname=save_path,
parallel=parallel)
# FIXME(zcd): in the inference stage, the number of
# input data is one, it is not appropriate to use parallel.
if parallel and use_cuda:
return
os.environ['CPU_NUM'] = str(1)
infer(
use_cuda=use_cuda,
inference_program=inference_network,
params_dirname=save_path)
params_dirname=save_path,
parallel=parallel)
if __name__ == '__main__':
for use_cuda in (False, True):
main(use_cuda=use_cuda)
for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
main(use_cuda=use_cuda, parallel=parallel)
......@@ -16,7 +16,9 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy
import os
import cifar10_small_test_set
......@@ -68,7 +70,7 @@ def optimizer_func():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(use_cuda, train_program, params_dirname):
def train(use_cuda, train_program, parallel, params_dirname):
BATCH_SIZE = 128
train_reader = paddle.batch(
paddle.reader.shuffle(
......@@ -93,7 +95,10 @@ def train(use_cuda, train_program, params_dirname):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = fluid.Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_func)
train_func=train_program,
place=place,
optimizer_func=optimizer_func,
parallel=parallel)
trainer.train(
reader=train_reader,
......@@ -102,10 +107,13 @@ def train(use_cuda, train_program, params_dirname):
feed_order=['pixel', 'label'])
def infer(use_cuda, inference_program, params_dirname=None):
def infer(use_cuda, inference_program, parallel, params_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
infer_func=inference_program,
param_path=params_dirname,
place=place,
parallel=parallel)
# The input's dimension of conv should be 4-D or 5-D.
# Use normilized image pixels as input data, which should be in the range
......@@ -116,22 +124,31 @@ def infer(use_cuda, inference_program, params_dirname=None):
print("infer results: ", results)
def main(use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
def main(use_cuda, parallel):
save_path = "image_classification_vgg.inference.model"
os.environ['CPU_NUM'] = str(4)
train(
use_cuda=use_cuda,
train_program=train_network,
params_dirname=save_path)
params_dirname=save_path,
parallel=parallel)
# FIXME(zcd): in the inference stage, the number of
# input data is one, it is not appropriate to use parallel.
if parallel and use_cuda:
return
os.environ['CPU_NUM'] = str(1)
infer(
use_cuda=use_cuda,
inference_program=inference_network,
params_dirname=save_path)
params_dirname=save_path,
parallel=parallel)
if __name__ == '__main__':
for use_cuda in (False, True):
main(use_cuda=use_cuda)
for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
main(use_cuda=use_cuda, parallel=parallel)
......@@ -64,14 +64,14 @@ def optimizer_func():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(use_cuda, train_program, params_dirname):
def train(use_cuda, train_program, parallel, params_dirname):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = fluid.Trainer(
train_func=train_program,
place=place,
optimizer_func=optimizer_func,
parallel=True)
parallel=parallel)
def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
......@@ -108,11 +108,14 @@ def train(use_cuda, train_program, params_dirname):
feed_order=['img', 'label'])
def infer(use_cuda, inference_program, params_dirname=None):
def infer(use_cuda, inference_program, parallel, params_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
infer_func=inference_program,
param_path=params_dirname,
place=place,
parallel=parallel)
batch_size = 1
tensor_img = numpy.random.uniform(-1.0, 1.0,
......@@ -123,20 +126,32 @@ def infer(use_cuda, inference_program, params_dirname=None):
print("infer results: ", results[0])
def main(use_cuda):
def main(use_cuda, parallel):
params_dirname = "recognize_digits_conv.inference.model"
# call train() with is_local argument to run distributed train
os.environ['CPU_NUM'] = str(4)
train(
use_cuda=use_cuda,
train_program=train_program,
params_dirname=params_dirname)
params_dirname=params_dirname,
parallel=parallel)
# FIXME(zcd): in the inference stage, the number of
# input data is one, it is not appropriate to use parallel.
if parallel and use_cuda:
return
os.environ['CPU_NUM'] = str(1)
infer(
use_cuda=use_cuda,
inference_program=inference_program,
params_dirname=params_dirname)
params_dirname=params_dirname,
parallel=parallel)
if __name__ == '__main__':
# for use_cuda in (False, True):
main(use_cuda=core.is_compiled_with_cuda())
for use_cuda in (False, True):
for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
main(use_cuda=use_cuda, parallel=parallel)
......@@ -16,6 +16,7 @@ from __future__ import print_function
import argparse
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
import sys
import numpy
......@@ -50,11 +51,14 @@ def optimizer_func():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(use_cuda, train_program, params_dirname):
def train(use_cuda, train_program, params_dirname, parallel):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = fluid.Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_func)
train_func=train_program,
place=place,
optimizer_func=optimizer_func,
parallel=parallel)
def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
......@@ -86,11 +90,14 @@ def train(use_cuda, train_program, params_dirname):
feed_order=['img', 'label'])
def infer(use_cuda, inference_program, params_dirname=None):
def infer(use_cuda, inference_program, parallel, params_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
infer_func=inference_program,
param_path=params_dirname,
place=place,
parallel=parallel)
batch_size = 1
tensor_img = numpy.random.uniform(-1.0, 1.0,
......@@ -101,20 +108,32 @@ def infer(use_cuda, inference_program, params_dirname=None):
print("infer results: ", results[0])
def main(use_cuda):
def main(use_cuda, parallel):
params_dirname = "recognize_digits_mlp.inference.model"
# call train() with is_local argument to run distributed train
os.environ['CPU_NUM'] = str(4)
train(
use_cuda=use_cuda,
train_program=train_program,
params_dirname=params_dirname)
params_dirname=params_dirname,
parallel=parallel)
# FIXME(zcd): in the inference stage, the number of
# input data is one, it is not appropriate to use parallel.
if parallel and use_cuda:
return
os.environ['CPU_NUM'] = str(1)
infer(
use_cuda=use_cuda,
inference_program=inference_program,
params_dirname=params_dirname)
params_dirname=params_dirname,
parallel=parallel)
if __name__ == '__main__':
# for use_cuda in (False, True):
main(use_cuda=False)
for use_cuda in (False, True):
for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
main(use_cuda=use_cuda, parallel=parallel)
......@@ -21,28 +21,41 @@ from op_test import OpTest
class TestFakeQuantizeOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize"
self.op_type = "fake_quantize_abs_max"
self.attrs = {'bit_length': 8}
self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
self.outputs = {
'Out': np.round(self.inputs['X'] / scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScale': np.array(scale).astype("float32"),
}
def test_check_output(self):
self.check_output()
class TestFakeQuantizeOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_range_abs_max"
self.attrs = {
'bit_length': 8,
'quantize_type': 'abs_max',
'window_size': 10000
'bit_length': int(5),
'window_size': int(1),
'is_test': False
}
self.inputs = {
'X': np.random.random((10, 10)).astype("float32"),
'InScales': np.zeros(self.attrs['window_size']).astype("float32"),
'InCurrentIter': np.zeros(1).astype("float32"),
'InMovingScale': np.zeros(1).astype("float32")
}
self.scale = {
'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
'X': np.random.random((8, 16, 7, 7)).astype("float32"),
'Iter': np.zeros(1).astype("int64"),
'InScale': np.zeros(1).astype("float32")
}
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
out_scales = np.zeros(self.attrs['window_size']).astype("float32")
out_scales[0] = scale
self.outputs = {
'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * (
'Out': np.round(self.inputs['X'] / scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScales': np.zeros(self.attrs['window_size']).astype("float32"),
'OutMovingScale':
np.array([self.scale['abs_max']]).astype("float32"),
'OutCurrentIter': np.zeros(1).astype("float32")
'OutScale': scale,
'OutScales': out_scales,
}
def test_check_output(self):
......
......@@ -25,9 +25,9 @@ class TestSamplingIdOp(OpTest):
self.op_type = "sampling_id"
self.use_mkldnn = False
self.init_kernel_type()
self.X = np.random.random((8, 4)).astype('float32')
self.X = np.random.random((100, 10)).astype('float32')
self.inputs = {"X": self.X}
self.Y = np.random.random(8).astype('float32')
self.Y = np.random.random(100).astype('int64')
self.outputs = {'Out': self.Y}
self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1}
......@@ -36,6 +36,16 @@ class TestSamplingIdOp(OpTest):
y1 = self.out
self.check_output_customized(self.verify_output)
y2 = self.out
# check dtype
assert y1.dtype == np.int64
assert y2.dtype == np.int64
# check output is index ids of inputs
inputs_ids = np.arange(self.X.shape[1])
assert np.isin(y1, inputs_ids).all()
assert np.isin(y2, inputs_ids).all()
self.assertTrue(np.array_equal(y1, y2))
self.assertEqual(len(y1), len(self.Y))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册