提交 bf85cded 编写于 作者: D Dang Qingqing

Refine fake_quantize_op.

上级 18dd1294
......@@ -14,19 +14,79 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
namespace operators {
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const CPUDeviceContext& ctx, const T* in, const int num,
T* out) {
ConstEigenVectorArrayMap<T> in_e(in, num);
EigenVectorArrayMap<T> out_e(out, 1);
auto& dev = ctx.eigen_device();
out_e = in_e.abs().maximum();
}
};
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const CPUDeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor* scale, const int bin_cnt,
framework::Tensor* out) {
T s = scale->data<T>()[0];
Transform<DeviceContext> trans;
trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
auto in_e = framework::EigenVector<T>::Flatten(in);
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(dev) = (bin_cnt / s * in_e).round();
}
};
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& cur_scale,
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) {
T* scale_arr = scales_arr->mutable_data<T>(cxt.GetPlace());
int it = iter.data<int>()[0];
int idx = it % window_size;
T removd = scale_arr[idx];
T cur = cur_scale.data<T>()[0];
scale_arr[idx] = cur;
T max = last_scale.data<T>()[0];
if (max < cur) {
max = cur;
} else if (fabs(removed - max) < 1e-6) {
int size = (it > window_size) ? window_size : it;
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
&max);
}
out_scale->mutable_data<T>()[0] = max;
}
};
class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeAbsMaxOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
FakeQuantizeAbsMaxOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -49,7 +109,7 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("OutScale", "(Tensor) Current scale");
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int &bit_length) {
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
......@@ -64,82 +124,62 @@ $$Out = round(X/scale * range)$$
}
};
class FakeQuantizeOp : public framework::OperatorWithKernel {
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
FakeQuantizeOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
"OutMovingScale(Out) of FakeQuantizeOp should not be null");
// if (ctx->HasInput("InMovingScale")) {
ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
"OutScales(Out) of FakeQuantizeOp should not be null");
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE(
ctx->HasOutput("OutScales"),
"Output(OutScales) of FakeQuantizeRangeAbsMaxOp should not be null");
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
class FakeQuantizeRangeAbsMaxOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
.AsDispensable();
AddInput("InMovingScale", "Last scale, used in static quantization.")
.AsDispensable();
AddInput("InCurrentIter",
"Last iteration number, used in static quantization.")
.AsDispensable();
AddInput("X", "(Tensor) Input is float data type.");
AddInput("InScales", "(Tensor) scale buffer.").AsDispensable();
AddInput("InScale", "Last scale.")
AddInput("Iter", "Global step iteration.")
.AsDispensable();
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
AddOutput("OutScales",
"(Tensor) scale buffer, used in static quantization.")
.AsDispensable();
AddOutput("OutMovingScale", " Current scale");
AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
AddAttr<std::string>("quantize_type",
"(string, default abs_max)"
"The scaling tpe of the quantize operator.")
.SetDefault("abs_max");
AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
AddAttr<int>("bit_length", "(int, default 8)")
AddOutput("OutScale", " Current scale");
AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
AddAttr<int>("window_size", "(int, default 10000) window range size.")
.SetDefault(10000);
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8)
.AddCustomChecker([](const int &bit_length) {
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
AddAttr<bool>("is_test", "").SetDefault(false);
AddComment(R"DOC(
FakeQuantize operator
quantize_type = abs_max:
$$scale = max(abs(x))$$
quantize_type = range_abs_max:
$$scale = max(max(abs(x)), history_abs_max)$$
quantize_type = moving_average_abs_max:
$$scale = 0.1*scale+0.9*new_abs_max)$$
FakeQuantize operator is used in static quantization.
$$Out = scale*X$$
$$scale = max(max(abs(x)), history_abs_max)$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC");
}
......@@ -150,9 +190,18 @@ $$Out = scale*X$$
namespace ops = paddle::operators;
REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
ops::FakeQuantizeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fake_quantize_abs_max,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeOp,
ops::FakeQuantizeRangeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fake_quantize,
fake_quantize_range_abs_max,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -20,245 +20,132 @@ namespace paddle {
namespace operators {
template <typename T>
__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
extern __shared__ T shared_max_data[];
extern __shared__ T shared_max[];
if (gridDim.x > 1) {
shared_max_data[tid] = T(0);
shared_max[tid] = T(0);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T tmp = fabs(in[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
if (tmp > shared_max[tid]) {
shared_max[tid] = tmp;
}
}
} else {
if (bid < n) {
shared_max_data[tid] = fabs(in[bid]);
shared_max[tid] = fabs(in[bid]);
} else {
shared_max_data[tid] = T(0);
shared_max[tid] = T(0);
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
shared_max_data[tid] = shared_max_data[tid + i];
if (tid < i && (shared_max[tid] < shared_max[tid + i])) {
shared_max[tid] = shared_max[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[blockIdx.x] = shared_max_data[0];
out[blockIdx.x] = shared_max[0];
}
}
float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
int length) {
float host_max;
int kNumTheads = 1024;
int gridDimx = (kNumTheads - 1 + length) / kNumTheads;
gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
framework::Tensor t;
float* device_max = t.mutable_data<float>(framework::make_ddim({gridDimx}),
platform::CUDAPlace());
FindAbsMaxKernel<float><<<gridDimx, kNumTheads, kNumTheads * sizeof(float),
ctx.stream()>>>(length, array, device_max);
FindAbsMaxKernel<
float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>(
gridDimx, device_max, device_max);
PADDLE_ENFORCE_EQ(
cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
cudaSuccess, "cudaMemcpy failed");
return host_max;
}
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const CUDADeviceContext& ctx, const T* in, const int num,
T* out) {
int block = 1024;
int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
Tensor max;
T* max_data =
max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
FindAbsMaxKernel<T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, max_data);
FindAbsMaxKernel<T><<<1, block, block * sizeof(T), ctx.stream()>>>(
max_data, grid, out);
}
};
template <typename T>
__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
int* num_saturate, const T min,
const T max) {
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
extern __shared__ int shared_count[];
shared_count[tid] = 0;
T s = scale[0];
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
if (in[i] > max) {
out[i] = max;
shared_count[tid] += 1;
} else if (in[i] < min) {
out[i] = min;
shared_count[tid] += 1;
} else {
out[i] = in[i];
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i) {
shared_count[tid] += shared_count[tid + i];
}
__syncthreads();
}
if (tid == 0) {
num_saturate[blockIdx.x] = shared_count[0];
T x = in[bid];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out[bid] = round(v);
}
}
template <typename T>
__global__ void ReduceKernel(const int n, const T* in, T* out) {
__global__ void FillScaleArray(T* scale_arr, T* out_scale, const int* it,
const int window_size, ) {
int tid = threadIdx.x;
extern __shared__ T shared_sum[];
if (tid < n) {
shared_sum[tid] = in[tid];
} else {
shared_sum[tid] = T(0);
}
__syncthreads();
// blockDim.x must >= n
for (int i = (n + 1) / 2; i > 0; i >>= 1) {
if (tid < i) {
shared_sum[tid] += shared_sum[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[0] = shared_sum[0];
}
int idx = it % window_size;
// scale_arr[idx] = ;
}
template <typename T>
int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
const T* in, T* out, const T min, const T max) {
int host_num_saturate;
int kNumTheads = 1024;
int gridDimx = (n + kNumTheads - 1) / kNumTheads;
gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
framework::Tensor t;
int* device_num_saturate = t.mutable_data<int>(
framework::make_ddim({gridDimx}), platform::CUDAPlace());
ApplySaturateKernel<
T><<<gridDimx, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
n, in, out, device_num_saturate, min, max);
ReduceKernel<int><<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
gridDimx, device_num_saturate, device_num_saturate);
PADDLE_ENFORCE_EQ(cudaSuccess,
cudaMemcpy(&host_num_saturate, device_num_saturate,
sizeof(int), cudaMemcpyDeviceToHost),
"cudaMemcpy failed");
return host_num_saturate;
}
template <typename DeviceContext, typename T>
class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
public:
T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& cur_scale,
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) {
T* scale_arr = scales_arr->mutable_data<T>(cxt.GetPlace());
auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int it;
memory::Copy(platform::CPUPlace(), &it, gpu_place, iter.data<int>(),
sizeof(int), ctx.stream());
int idx = current_iter % window_size;
T removed;
memory::Copy(platform::CPUPlace(), &removed, gpu_place, scale_arr + idx,
sizeof(T), ctx.stream());
T cur;
memory::Copy(gpu_place, &cur, gpu_place, cur_scale.data<T>(), sizeof(T),
ctx.stream());
T max;
memory::Copy(platform::CPUPlace(), &max, gpu_place, last_scale.data<T>(),
sizeof(T), ctx.stream());
T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
if (max < cur) {
max = cur;
memory::Copy(gpu_place, out_scale_data, gpu_place, &max, sizeof(T),
ctx.stream());
} else if (fabs(removed - max) < 1e-6) {
int size = (it > window_size) ? window_size : it;
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
out_scale_data);
}
return max_scale;
}
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
}
};
virtual void Compute(const framework::ExecutionContext& context) const {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto& device_ctx = context.cuda_device_context();
auto* tensor = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
context.Output<framework::Tensor>("OutMovingScale")
->mutable_data<T>(
context.Input<framework::Tensor>("InMovingScale")->place());
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
context.Output<framework::Tensor>("OutScales")
->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
context.Output<framework::Tensor>("OutCurrentIter")
->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
T scale = T(1);
int window_size = context.Attr<int>("window_size");
T bin_cnt = (T)((1 << (context.Attr<int>("bit_length") - 1)) - 1);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InMovingScale"));
if (is_test) {
scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
} else {
auto* it = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InCurrentIter"));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
int* last_iter = it->mutable_data<int>(platform::CPUPlace());
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
window_size, current_iter[0]);
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InMovingScale"));
if (is_test) {
scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
} else {
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
}
}
ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
tensor->mutable_data<T>(in->place()), -scale, scale);
scale = bin_cnt / scale;
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (scale * eigen_in).round();
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const CPUDeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor* scale, const int bin_cnt,
framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
T* in_data = in.data<T>();
T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
......
......@@ -25,254 +25,109 @@ namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
public:
T FindAbsMax(framework::Tensor* in, int n) const {
T* p = in->mutable_data<T>(platform::CPUPlace());
T abs_max = (T)0.00000001;
for (int i = 0; i < n; i++) {
T tmp = fabs(p[i]);
if (tmp > abs_max) abs_max = tmp;
}
return T(abs_max);
}
T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMax(scale_list, size));
}
return max_scale;
}
struct FindAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num, T* out);
};
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor* scale, const int bin_cnt,
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& cur_scale,
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale,
framework::Tensor* out);
};
void FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
framework::Tensor* scale_list, const T last_max_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(scale_list->place());
T remove_tmp;
auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int idx = current_iter % window_size;
memory::Copy(platform::CPUPlace(), &remove_tmp, gpu_place, sl + idx,
sizeof(float), ctx.stream());
memory::Copy(gpu_place, sl + idx, platform::CPUPlace(), &cur_scale, sizeof(T),
ctx.stream());
T max_scale = last_max_scale;
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
}
return max_scale;
}
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
virtual void Compute(const framework::ExecutionContext& context) const {
auto* tensor = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
oms_tensor->mutable_data<T>(in->place());
auto* in_scale = context.Input<framework::Tensor>("InScale");
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
oss_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
oci_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_data = out->mutable_data<T>(context.GetPlace());
T* out_s = out_scale->mutable_data<T>(context.GetPlace());
T scale = static_cast<T>(1);
int window_size = context.Attr<int>("window_size");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto raw_in = framework::EigenVector<T>::Flatten(*in);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = scale_out(0);
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* it = context.Input<framework::Tensor>("InCurrentIter");
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
const int* last_iter = it->data<int>();
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
current_iter[0]);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
}
}
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), in->data<T>(),
in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
ClipFunctor<T>(-scale, scale));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
auto& dev_ctx = context.template device_context<DeviceContext>();
const T* in_data = in->data<T>();
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in.numel(), out_s);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out);
}
};
using platform::Transform;
template <typename DeviceContext, typename T>
class FakeQuantizeKernel : public framework::OpKernel<T> {
public:
T FindAbsMax(framework::Tensor* in, int n) const {
T* p = in->mutable_data<T>(platform::CPUPlace());
T abs_max = (T)0.00000001;
for (int i = 0; i < n; i++) {
T tmp = fabs(p[i]);
if (tmp > abs_max) abs_max = tmp;
}
return T(abs_max);
}
T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMax(scale_list, size));
}
return max_scale;
}
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
}
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
virtual void Compute(const framework::ExecutionContext& context) const {
auto* tensor = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
auto* in_scale = context.Input<framework::Tensor>("X");
auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
oms_tensor->mutable_data<T>(in->place());
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
oss_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
oci_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
T scale = static_cast<T>(1);
int window_size = context.Attr<int>("window_size");
bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto raw_in = framework::EigenVector<T>::Flatten(*in);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = scale_out(0);
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* it = context.Input<framework::Tensor>("InCurrentIter");
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
const int* last_iter = it->data<int>();
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
current_iter[0]);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
}
// testing
if (is_test) {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
bin_cnt, out);
return;
}
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), in->data<T>(),
in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
ClipFunctor<T>(-scale, scale));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
// training
auto* out_scale = context.Output<framework::Tensor>("OutScale");
auto* in_scales = context.Input<framework::Tensor>("InScales");
auto* out_scales = context.Input<framework::Tensor>("OutScales");
auto* iter = context.Input<framework::Tensor>("Iter");
bool window_size = context.Attr<bool>("window_size");
out_scale->mutable_data<T>(context.GetPlace());
Tensor cur_scale;
T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
cur_scale_data);
FindRangeAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, cur_scale, in_scale, iter, window_size, out_scale, out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
bin_cnt, out);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册