提交 251eb372 编写于 作者: D Dang Qingqing

Improve and fix fake_quantize_op.

上级 bf85cded
...@@ -178,6 +178,8 @@ function(op_library TARGET) ...@@ -178,6 +178,8 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(relu);\n") file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "fake_dequantize") elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "fake_quantize")
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op") elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc") elseif(${TARGET} STREQUAL "fc")
...@@ -291,6 +293,7 @@ op_library(unsqueeze_op DEPS reshape_op) ...@@ -291,6 +293,7 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory) op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op) op_library(flatten_op DEPS reshape_op)
op_library(fake_quantize_op DEPS memory)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
...@@ -15,43 +15,55 @@ limitations under the License. */ ...@@ -15,43 +15,55 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/operators/fake_quantize_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, int MajorType = Eigen::RowMajor,
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>; typename IndexType = Eigen::DenseIndex>
template <typename T> using EigenVectorArrayMap =
Eigen::TensorMap<Eigen::Tensor<T, 1, MajorType, IndexType>>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using ConstEigenVectorArrayMap = using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>; Eigen::TensorMap<const Eigen::Tensor<T, 1, MajorType, IndexType>>;
template <typename T> template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> { struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const CPUDeviceContext& ctx, const T* in, const int num, void operator()(const platform::CPUDeviceContext& ctx, const T* in,
T* out) { const int num, T* out) {
ConstEigenVectorArrayMap<T> in_e(in, num); Eigen::DSizes<Eigen::DenseIndex, 1> idim(num);
EigenVectorArrayMap<T> out_e(out, 1); Eigen::DSizes<Eigen::DenseIndex, 1> odim(1);
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>> in_e(in, idim);
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> out_e(out, odim);
auto& dev = ctx.eigen_device();
out_e = in_e.abs().maximum(); out_e = in_e.abs().maximum();
} }
}; };
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T> template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> { struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const CPUDeviceContext& ctx, const framework::Tensor& in, void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor* scale, const int bin_cnt, const framework::Tensor& in, const framework::Tensor& scale,
framework::Tensor* out) { const int bin_cnt, framework::Tensor* out) {
T s = scale->data<T>()[0]; T s = scale.data<T>()[0];
Transform<DeviceContext> trans; platform::Transform<platform::CPUDeviceContext> trans;
trans(ctx, in.data<T>(), in.data<T>() + in.numel(), trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s)); out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
auto in_e = framework::EigenVector<T>::Flatten(in); auto in_e = framework::EigenVector<T>::Flatten(in);
auto out_e = framework::EigenVector<T>::Flatten(*out); auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(dev) = (bin_cnt / s * in_e).round();
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round();
} }
}; };
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T> template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> { struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, void operator()(const platform::CPUDeviceContext& ctx,
...@@ -59,10 +71,10 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> { ...@@ -59,10 +71,10 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& last_scale, const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) { framework::Tensor* scales_arr, framework::Tensor* out_scale) {
T* scale_arr = scales_arr->mutable_data<T>(cxt.GetPlace()); T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
int it = iter.data<int>()[0]; int64_t it = iter.data<int64_t>()[0];
int idx = it % window_size; int idx = it % window_size;
T removd = scale_arr[idx]; T removed = scale_arr[idx];
T cur = cur_scale.data<T>()[0]; T cur = cur_scale.data<T>()[0];
scale_arr[idx] = cur; scale_arr[idx] = cur;
...@@ -74,10 +86,12 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> { ...@@ -74,10 +86,12 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size, FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
&max); &max);
} }
out_scale->mutable_data<T>()[0] = max; out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
} }
}; };
template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
public: public:
FakeQuantizeAbsMaxOp(const std::string& type, FakeQuantizeAbsMaxOp(const std::string& type,
...@@ -97,6 +111,14 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -97,6 +111,14 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("OutScale", {1}); ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
}; };
class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -126,10 +148,10 @@ $$Out = round(X/scale * range)$$ ...@@ -126,10 +148,10 @@ $$Out = round(X/scale * range)$$
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public: public:
FakeQuantizeOp(const std::string& type, FakeQuantizeRangeAbsMaxOp(const std::string& type,
const framework::VariableNameMap& inputs, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
...@@ -141,16 +163,22 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -141,16 +163,22 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput("OutScale"), ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null"); "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
if (ctx->HasInput("InScales")) { if (ctx->HasOutput("OutScales")) {
PADDLE_ENFORCE( int window_size = ctx->Attrs().Get<int>("window_size");
ctx->HasOutput("OutScales"), ctx->SetOutputDim("OutScales", {window_size});
"Output(OutScales) of FakeQuantizeRangeAbsMaxOp should not be null");
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
} }
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1}); ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
}; };
class FakeQuantizeRangeAbsMaxOpMaker class FakeQuantizeRangeAbsMaxOpMaker
...@@ -158,10 +186,8 @@ class FakeQuantizeRangeAbsMaxOpMaker ...@@ -158,10 +186,8 @@ class FakeQuantizeRangeAbsMaxOpMaker
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) Input is float data type."); AddInput("X", "(Tensor) Input is float data type.");
AddInput("InScales", "(Tensor) scale buffer.").AsDispensable(); AddInput("InScale", "Last scale.");
AddInput("InScale", "Last scale.") AddInput("Iter", "Global step iteration.").AsDispensable();
AddInput("Iter", "Global step iteration.")
.AsDispensable();
AddOutput("Out", "(Tensor) Output of quantized low level tensor."); AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
AddOutput("OutScale", " Current scale"); AddOutput("OutScale", " Current scale");
AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable(); AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
...@@ -189,19 +215,16 @@ $$Out = round(X/scale * range)$$ ...@@ -189,19 +215,16 @@ $$Out = round(X/scale * range)$$
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp, REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
ops::FakeQuantizeAbsMaxOpMaker, ops::FakeQuantizeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
fake_quantize_abs_max, ops::FakeQuantizeAbsMaxKernel<CPU, float>);
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeOp, REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
ops::FakeQuantizeRangeAbsMaxOpMaker, ops::FakeQuantizeRangeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
...@@ -24,53 +25,55 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { ...@@ -24,53 +25,55 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x; int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x; int tid = threadIdx.x;
extern __shared__ T shared_max[]; extern __shared__ T shared_max_data[];
if (gridDim.x > 1) { if (gridDim.x > 1) {
shared_max[tid] = T(0); shared_max_data[tid] = T(0);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T tmp = fabs(in[i]); T tmp = fabs(in[i]);
if (tmp > shared_max[tid]) { if (tmp > shared_max_data[tid]) {
shared_max[tid] = tmp; shared_max_data[tid] = tmp;
} }
} }
} else { } else {
if (bid < n) { if (bid < n) {
shared_max[tid] = fabs(in[bid]); shared_max_data[tid] = fabs(in[bid]);
} else { } else {
shared_max[tid] = T(0); shared_max_data[tid] = T(0);
} }
} }
__syncthreads(); __syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) { for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && (shared_max[tid] < shared_max[tid + i])) { if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
shared_max[tid] = shared_max[tid + i]; shared_max_data[tid] = shared_max_data[tid + i];
} }
__syncthreads(); __syncthreads();
} }
if (tid == 0) { if (tid == 0) {
out[blockIdx.x] = shared_max[0]; out[blockIdx.x] = shared_max_data[0];
} }
} }
template <typename T> template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const CUDADeviceContext& ctx, const T* in, const int num, void operator()(const platform::CUDADeviceContext& ctx, const T* in,
T* out) { const int num, T* out) {
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid; grid = (grid > block) ? block : grid;
Tensor max; framework::Tensor max;
T* max_data = T* max_data =
max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace()); max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
FindAbsMaxKernel<T><<<grid, block, block * sizeof(T), ctx.stream()>>>( FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
in_data, num, max_data); in, num, max_data);
FindAbsMaxKernel<T><<<1, block, block * sizeof(T), ctx.stream()>>>( FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>(
max_data, grid, out); max_data, grid, out);
} }
}; };
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale, __global__ void ClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n, T* out) { const int bin_cnt, const int n, T* out) {
...@@ -88,11 +91,25 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, ...@@ -88,11 +91,25 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
} }
template <typename T> template <typename T>
__global__ void FillScaleArray(T* scale_arr, T* out_scale, const int* it, __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
const int window_size, ) { const T* last_scale,
int tid = threadIdx.x; const int64_t* iter,
const int window_size, T* scale_arr,
T* out_scale, int* need_find_max,
int* out_size) {
int it = iter[0];
int idx = it % window_size; int idx = it % window_size;
// scale_arr[idx] = ; T removed = scale_arr[idx];
T cur = cur_scale[0];
scale_arr[idx] = cur;
T max = last_scale[0];
out_scale[0] = max < cur ? cur : max;
if (fabs(removed - max) < 1e-6) {
need_find_max[0] = 1;
out_size[0] = it > window_size ? window_size : it;
} else {
need_find_max[0] = 0;
}
} }
template <typename T> template <typename T>
...@@ -102,46 +119,44 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -102,46 +119,44 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& last_scale, const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) { framework::Tensor* scales_arr, framework::Tensor* out_scale) {
T* scale_arr = scales_arr->mutable_data<T>(cxt.GetPlace());
auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace()); auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int it; T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
memory::Copy(platform::CPUPlace(), &it, gpu_place, iter.data<int>(),
sizeof(int), ctx.stream());
int idx = current_iter % window_size;
T removed;
memory::Copy(platform::CPUPlace(), &removed, gpu_place, scale_arr + idx,
sizeof(T), ctx.stream());
T cur;
memory::Copy(gpu_place, &cur, gpu_place, cur_scale.data<T>(), sizeof(T),
ctx.stream());
T max;
memory::Copy(platform::CPUPlace(), &max, gpu_place, last_scale.data<T>(),
sizeof(T), ctx.stream());
T* out_scale_data = out_scale->mutable_data<T>(gpu_place); T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
if (max < cur) {
max = cur; framework::Tensor need_find_max, out_size;
memory::Copy(gpu_place, out_scale_data, gpu_place, &max, sizeof(T), int* find_max = need_find_max.mutable_data<int>(gpu_place);
ctx.stream()); int* out_size_data = out_size.mutable_data<int>(gpu_place);
} else if (fabs(removed - max) < 1e-6) {
int size = (it > window_size) ? window_size : it; FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size, cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
out_scale_data); window_size, scale_arr, out_scale_data, find_max, out_size_data);
int g_find_max;
memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
sizeof(int), 0);
if (g_find_max) {
int len;
memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
sizeof(int), 0);
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
out_scale_data);
} }
} }
}; };
template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> { struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const CPUDeviceContext& ctx, const framework::Tensor& in, void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor* scale, const int bin_cnt, const framework::Tensor& in, const framework::Tensor& scale,
framework::Tensor* out) { const int bin_cnt, framework::Tensor* out) {
int num = in.numel(); int num = in.numel();
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
T* in_data = in.data<T>(); const T* in_data = in.data<T>();
T* scale_data = scale.data<T>(); const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>( ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
...@@ -149,11 +164,14 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> { ...@@ -149,11 +164,14 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
} }
}; };
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(fake_quantize, namespace ops = paddle::operators;
paddle::operators::FakeQuantizeCUDAKernel< using CUDA = paddle::platform::CUDADeviceContext;
paddle::platform::CUDADeviceContext, float>, REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
paddle::operators::FakeQuantizeCUDAKernel< ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
paddle::platform::CUDADeviceContext, double>); REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
...@@ -17,9 +17,7 @@ limitations under the License. */ ...@@ -17,9 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,51 +30,26 @@ struct FindAbsMaxFunctor { ...@@ -32,51 +30,26 @@ struct FindAbsMaxFunctor {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor { struct ClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor* scale, const int bin_cnt, const framework::Tensor& scale, const int bin_cnt,
framework::Tensor* out); framework::Tensor* out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor { struct FindRangeAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,
const framework::Tensor& cur_scale,
const framework::Tensor& last_scale, const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale, framework::Tensor* scales_arr, framework::Tensor* out_scale);
framework::Tensor* out);
}; };
void FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
framework::Tensor* scale_list, const T last_max_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(scale_list->place());
T remove_tmp;
auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int idx = current_iter % window_size;
memory::Copy(platform::CPUPlace(), &remove_tmp, gpu_place, sl + idx,
sizeof(float), ctx.stream());
memory::Copy(gpu_place, sl + idx, platform::CPUPlace(), &cur_scale, sizeof(T),
ctx.stream());
T max_scale = last_max_scale;
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
}
return max_scale;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> { class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
virtual void Compute(const framework::ExecutionContext& context) const { public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_data = out->mutable_data<T>(context.GetPlace());
T* out_s = out_scale->mutable_data<T>(context.GetPlace()); T* out_s = out_scale->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
...@@ -84,7 +57,7 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -84,7 +57,7 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
const T* in_data = in->data<T>(); const T* in_data = in->data<T>();
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in.numel(), out_s); FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out); bin_cnt, out);
} }
...@@ -92,9 +65,10 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -92,9 +65,10 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> { class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
virtual void Compute(const framework::ExecutionContext& context) const { public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("X"); auto* in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
...@@ -113,19 +87,19 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -113,19 +87,19 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
// training // training
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
auto* in_scales = context.Input<framework::Tensor>("InScales"); auto* out_scales = context.Output<framework::Tensor>("OutScales");
auto* out_scales = context.Input<framework::Tensor>("OutScales");
auto* iter = context.Input<framework::Tensor>("Iter"); auto* iter = context.Input<framework::Tensor>("Iter");
bool window_size = context.Attr<bool>("window_size"); int window_size = context.Attr<int>("window_size");
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
Tensor cur_scale; framework::Tensor cur_scale;
T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace()); T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
cur_scale_data); cur_scale_data);
FindRangeAbsMaxFunctor<DeviceContext, T>()( FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx, cur_scale, *in_scale,
dev_ctx, cur_scale, in_scale, iter, window_size, out_scale, out_scale); *iter, window_size, out_scales,
out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out); bin_cnt, out);
} }
......
...@@ -21,28 +21,41 @@ from op_test import OpTest ...@@ -21,28 +21,41 @@ from op_test import OpTest
class TestFakeQuantizeOp(OpTest): class TestFakeQuantizeOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fake_quantize" self.op_type = "fake_quantize_abs_max"
self.attrs = {'bit_length': 8}
self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
self.outputs = {
'Out': np.round(self.inputs['X'] / scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScale': np.array(scale).astype("float32"),
}
def test_check_output(self):
self.check_output()
class TestFakeQuantizeOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_range_abs_max"
self.attrs = { self.attrs = {
'bit_length': 8, 'bit_length': int(5),
'quantize_type': 'abs_max', 'window_size': int(1),
'window_size': 10000 'is_test': False
} }
self.inputs = { self.inputs = {
'X': np.random.random((10, 10)).astype("float32"), 'X': np.random.random((8, 16, 7, 7)).astype("float32"),
'InScales': np.zeros(self.attrs['window_size']).astype("float32"), 'Iter': np.zeros(1).astype("int64"),
'InCurrentIter': np.zeros(1).astype("float32"), 'InScale': np.zeros(1).astype("float32")
'InMovingScale': np.zeros(1).astype("float32")
}
self.scale = {
'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
} }
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
out_scales = np.zeros(self.attrs['window_size']).astype("float32")
out_scales[0] = scale
self.outputs = { self.outputs = {
'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * ( 'Out': np.round(self.inputs['X'] / scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)), (1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScales': np.zeros(self.attrs['window_size']).astype("float32"), 'OutScale': scale,
'OutMovingScale': 'OutScales': out_scales,
np.array([self.scale['abs_max']]).astype("float32"),
'OutCurrentIter': np.zeros(1).astype("float32")
} }
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册