Commit b60124e8 authored by Zhang, Guoming

Merge branch 'prv-calibration'

@@ -132,6 +132,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::memory> user_src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline;
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
+std::shared_ptr<platform::ConvMKLDNNHandler> handler;
auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p";
@@ -139,6 +141,44 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_src_key = key + "@user_src_mem_p";
auto src_reorder_key = key + "@src_mem_p" + "reorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
+if(conv_p == nullptr){
+if(is_INT8){
+CreateINT8Primitive(ctx, is_test, dev_ctx, mkldnn_engine, input, //filter,
+bias, output,
+strides, paddings,
+dilations, fuse_relu,
+fuse_residual_conn, input_data,
+filter_data, src_tz,
+weights_tz, g,
+dst_tz, key,
+dst_memory_p,
+pipeline,
+key_conv_pd,
+src_memory_p,
+user_src_memory_p,
+conv_p,
+conv_pd,
+handler,
+force_fp32_output);
+}else{
+CreateFP32Primitive(ctx, is_test, dev_ctx, mkldnn_engine, input, //filter,
+bias, output,
+strides, paddings,
+dilations, fuse_relu,
+fuse_residual_conn, input_data,
+filter_data, src_tz,
+weights_tz, g,
+dst_tz, key,
+dst_memory_p,
+pipeline,
+key_conv_pd,
+src_memory_p,
+user_src_memory_p,
+conv_p,
+conv_pd,
+handler);
+}
+} else {
auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_reorder_key));
src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
if(src_memory_reorder_p){
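
Note: the restructured Compute() above follows a create-or-reuse pattern: the convolution primitive is looked up in the device context's blob cache by a string key, and only on a cache miss is it built through CreateINT8Primitive or CreateFP32Primitive. A minimal sketch of that pattern is shown below; the cache class and helper names are illustrative stand-ins, not the actual DeviceContext blob API.

    // Illustrative only: a simplified stand-in for the blob cache used above
    // (dev_ctx.GetBlob/SetBlob). All names here are hypothetical.
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct ConvPrimitive { /* wraps an mkldnn::convolution_forward */ };

    class PrimitiveCache {
     public:
      std::shared_ptr<ConvPrimitive> GetOrCreate(const std::string& key, bool is_int8) {
        auto it = cache_.find(key);
        if (it != cache_.end()) return it->second;          // reuse cached primitive
        auto prim = is_int8 ? CreateInt8() : CreateFp32();  // build only on first use
        cache_[key] = prim;
        return prim;
      }

     private:
      std::shared_ptr<ConvPrimitive> CreateInt8() { return std::make_shared<ConvPrimitive>(); }
      std::shared_ptr<ConvPrimitive> CreateFp32() { return std::make_shared<ConvPrimitive>(); }
      std::unordered_map<std::string, std::shared_ptr<ConvPrimitive>> cache_;
    };
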
@@ -149,14 +189,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
-std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
conv_pd = std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(dev_ctx.GetBlob(key_conv_pd));
-std::shared_ptr<platform::ConvMKLDNNHandler> handler;
if(conv_pd){
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
}
-if (!is_INT8 && dst_memory_p){
+if (!is_INT8){
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
@@ -184,7 +221,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
}
-} else if(is_INT8 && dst_memory_p){
+} else if(is_INT8){
if(fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
@@ -210,8 +247,48 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
}
-if(!is_INT8){
-if(conv_p == nullptr){
+if(src_memory_reorder_p){
+pipeline.push_back(*src_memory_reorder_p);
+}
+pipeline.push_back(*conv_p);
+}
+// push primitive to stream and wait until it's executed
+//pipeline.push_back(*conv_p);
+stream(stream::kind::eager).submit(pipeline).wait();
+if (need_s8_to_u8) {
+output->mutable_data<uint8_t>(ctx.GetPlace());
+}
+output->set_layout(DataLayout::kMKLDNN);
+output->set_format(GetMKLDNNFormat(*dst_memory_p));
+};
+private:
+void CreateFP32Primitive(
+paddle::framework::ExecutionContext ctx, bool is_test,
+const paddle::platform::MKLDNNDeviceContext& dev_ctx,
+const mkldnn::engine& mkldnn_engine,
+const paddle::framework::Tensor* input,// const paddle::framework::Tensor* filter,
+const paddle::framework::Tensor* bias, paddle::framework::Tensor* output,
+std::vector<int> strides, std::vector<int> paddings,
+std::vector<int> dilations, bool fuse_relu,
+bool fuse_residual_conn, const T* input_data,
+const float* filter_data, std::vector<int> src_tz,
+std::vector<int> weights_tz, int g,
+std::vector<int> dst_tz, const std::string key,
+std::shared_ptr<mkldnn::memory> &dst_memory_p,
+std::vector<primitive>& pipeline,
+const std::string &key_conv_pd,
+std::shared_ptr<mkldnn::memory> src_memory_p,
+std::shared_ptr<mkldnn::memory> user_src_memory_p,
+std::shared_ptr<mkldnn::convolution_forward> conv_p,
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
+std::shared_ptr<platform::ConvMKLDNNHandler> handler) const{
+//const T* input_data = input->data<T>();
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md = platform::MKLDNNMemDesc(
@@ -322,28 +399,37 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p);
-stream(stream::kind::eager).submit(pipeline).wait();
-output->set_layout(DataLayout::kMKLDNN);
-output->set_format(GetMKLDNNFormat(*dst_memory_p));
-} else {
-if(src_memory_reorder_p){
-pipeline.push_back(*src_memory_reorder_p);
-}
-pipeline.push_back(*conv_p);
-stream(stream::kind::eager).submit(pipeline).wait();
-output->set_layout(DataLayout::kMKLDNN);
-output->set_format(GetMKLDNNFormat(*dst_memory_p));
-}
-} else{
-if(conv_p == nullptr){
-auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
-auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
-auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input<Tensor>("Scale_weights") : nullptr;
-auto* scale_out = ctx.HasInput("Scale_out")? ctx.Input<Tensor>("Scale_out") : nullptr;
-bool is_multi_channel = (scale_weights->memory_size() > 1) ? true : false;
+};
+void CreateINT8Primitive(
+const paddle::framework::ExecutionContext& ctx, bool is_test,
+const paddle::platform::MKLDNNDeviceContext & dev_ctx,
+const mkldnn::engine & mkldnn_engine,
+const paddle::framework::Tensor* input, //const paddle::framework::Tensor* filter,
+const paddle::framework::Tensor* bias, paddle::framework::Tensor* output,
+std::vector<int> strides, std::vector<int> paddings,
+std::vector<int> dilations, bool fuse_relu,
+bool fuse_residual_conn, const T* input_data,
+const float* filter_data, std::vector<int> src_tz,
+std::vector<int> weights_tz, int g,
+std::vector<int> dst_tz, const std::string key,
+std::shared_ptr<mkldnn::memory>& dst_memory_p,
+std::vector<primitive>& pipeline,
+const std::string &key_conv_pd,
+std::shared_ptr<mkldnn::memory> src_memory_p,
+std::shared_ptr<mkldnn::memory> user_src_memory_p,
+std::shared_ptr<mkldnn::convolution_forward> conv_p,
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
+std::shared_ptr<platform::ConvMKLDNNHandler> handler,
+bool force_fp32_output) const {
+//const T* input_data = input->data<T>();
+bool is_INT8 = true;
+auto scale_in_data = ctx.Attr<float>("Scale_in");
+auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
+auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
+auto scale_out_data = force_fp32_output? 1.0f : ctx.Attr<float>("Scale_out");
+bool is_multi_channel = scale_weights_data.size() > 1 ? true : false;
auto scale_in_key = key + "@scale_in";
auto scale_weights_key = key + "@scale_weights";
@@ -351,38 +437,35 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto output_shift_scale_key = key + "@output_shift_scale";
auto sum_scale_key = key + "@sum_scale";
auto scale_in_eltwise_key = key + "@scale_in_eltwise";
-std::vector<float> scale_in_data;
-std::vector<float> scale_out_data = {1.0f};
-std::vector<float> scale_weights_data;
-std::vector<float> scale_in_eltwise_data;
+//std::vector<float> scale_in_data;
+//std::vector<float> scale_out_data = {1.0f};
+//std::vector<float> scale_weights_data;
+//std::vector<float> scale_in_eltwise_data;
std::vector<float> output_shift_scale;
-std::vector<float> sum_scale = {1.0f};
-std::vector<float> none_scale = {0};
+float sum_scale = 1.0f;
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
-scale_in_data = {*(scale_in->data<float>())};
-scale_weights_data.resize(count);
-#pragma omp parallel for if (count > 1)
-for(int i=0; i<count; i++){
-scale_weights_data[i] =*(scale_weights->data<float>() + i);
-}
-if(!force_fp32_output)
-scale_out_data = {*(scale_out->data<float>())};
+//scale_in_data = {scale_in};
+//scale_weights_data.resize(count);
+//#pragma omp parallel for if (count > 1)
+//for(int i=0; i<count; i++){
+//scale_weights_data[i] =*(scale_weights->data<float>() + i);
+//}
+//if(!force_fp32_output)
+//scale_out_data = {*(scale_out->data<float>())};
output_shift_scale.resize(count);
#pragma omp parallel for if (count > 1)
for(int i=0; i<count; i++){
if(scale_weights_data[i] == 0.0)
-output_shift_scale[i] = scale_out_data[0];
+output_shift_scale[i] = scale_out_data;
else
-output_shift_scale[i] = scale_out_data[0] / (scale_in_data[0] * scale_weights_data[i]);
+output_shift_scale[i] = scale_out_data / (scale_in_data * scale_weights_data[i]);
}
if(fuse_residual_conn){
-scale_in_eltwise_data = {*(scale_in_eltwise->data<float>())};
-sum_scale[0] = scale_out_data[0] / scale_in_eltwise_data[0];
+//scale_in_eltwise_data = {*(scale_in_eltwise->data<float>())};
+sum_scale = scale_out_data / scale_in_eltwise_data;
}
-std::vector<primitive> pipeline;
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format());
auto user_weights_md = platform::MKLDNNMemDesc(
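
Note: the hunk above is the core of the calibration change. The requantization factors are now derived from the scalar attributes: output_shift_scale[i] = Scale_out / (Scale_in * Scale_weights[i]) per output channel, and, when the residual connection is fused, sum_scale = Scale_out / Scale_in_eltwise. A standalone illustration of the same arithmetic follows; the numeric values are made up.

    // Standalone illustration of the scale math above; the sample numbers are made up.
    #include <cstdio>
    #include <vector>

    int main() {
      float scale_in = 64.0f;                               // hypothetical Scale_in
      float scale_out = 32.0f;                              // hypothetical Scale_out
      float scale_in_eltwise = 48.0f;                       // hypothetical Scale_in_eltwise
      std::vector<float> scale_weights = {127.0f, 100.0f};  // per-channel Scale_weights

      std::vector<float> output_shift_scale(scale_weights.size());
      for (size_t i = 0; i < scale_weights.size(); ++i) {
        // Same rule as the kernel: fall back to Scale_out when the weight scale is zero.
        output_shift_scale[i] = scale_weights[i] == 0.0f
                                    ? scale_out
                                    : scale_out / (scale_in * scale_weights[i]);
      }
      float sum_scale = scale_out / scale_in_eltwise;  // only used with fuse_residual_conn

      for (float s : output_shift_scale) std::printf("shift %f\n", s);
      std::printf("sum_scale %f\n", sum_scale);
      return 0;
    }
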
@@ -427,12 +510,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn,
-output_shift_scale, sum_scale[0], is_test);
+output_shift_scale, sum_scale, is_test);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn,
-output_shift_scale, sum_scale[0], is_test);
+output_shift_scale, sum_scale, is_test);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -503,7 +586,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
scale_bias_data.resize(count);
#pragma omp parallel for if (count > 1)
for(int i=0; i<count; i++){
-scale_bias_data[i] = scale_in_data[0] * scale_weights_data[i];
+scale_bias_data[i] = scale_in_data * scale_weights_data[i];
}
bias_memory_p =
handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_test, is_INT8, scale_bias_data, mask_reorder);
@@ -514,34 +597,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p);
}
-// push primitive to stream and wait until it's executed
-pipeline.push_back(*conv_p);
-stream(stream::kind::eager).submit(pipeline).wait();
-if(need_s8_to_u8){
-output->mutable_data<uint8_t>(ctx.GetPlace());
-}
-output->set_layout(DataLayout::kMKLDNN);
-output->set_format(GetMKLDNNFormat(*dst_memory_p));
-} else {
-if(src_memory_reorder_p){
-pipeline.push_back(*src_memory_reorder_p);
-}
-pipeline.push_back(*conv_p);
-stream(stream::kind::eager).submit(pipeline).wait();
-if (need_s8_to_u8) {
-output->mutable_data<uint8_t>(ctx.GetPlace());
-}
-output->set_layout(DataLayout::kMKLDNN);
-output->set_format(GetMKLDNNFormat(*dst_memory_p));
-}
-}
-}
-private:
+// push primitive to stream and wait until it's executed
+pipeline.push_back(*conv_p);
+};
void AppendKey(std::string& key, mkldnn::memory::dims& input_dims, // NOLINT
mkldnn::memory::dims& weights_dims, // NOLINT
std::vector<int>& strides, // NOLINT
......
@@ -131,21 +131,14 @@ void Conv2DOpMaker::Make() {
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN.")
.AsDispensable();
-AddInput("Scale_in",
-"(Tensor) Scale_in to be used for int8 input data."
-"Only used with INT8.")
-.AsDispensable();
-AddInput("Scale_in_eltwise",
-"(Tensor) Scale_in_eltwise to be used for int8 eltwise input data."
-"Only used with MKL-DNN.")
-.AsDispensable();
-AddInput("Scale_weights",
-"(Tensor) Scale_weights to be used for int8 weights data."
-"Only used with MKL-DNN.")
-.AsDispensable();
-AddInput("Scale_out",
-"(Tensor) Scale_out to be used for int8 output data."
-"Only used with MKL-DNN.")
+AddOutput("Output",
+"(Tensor) The output tensor of convolution operator. "
+"The format of output tensor is also NCHW.");
+AddInput("ResidualData",
+"(Tensor) Tensor with residual data "
+"to which convolution output will be added."
+"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
@@ -193,6 +186,22 @@ void Conv2DOpMaker::Make() {
"whenever convolution output is as an input to residual "
"connection.")
.SetDefault(false);
+AddAttr<float>("Scale_in",
+"Scale_in to be used for int8 input data."
+"Only used with INT8.")
+.SetDefault(1.0f);
+AddAttr<float>("Scale_out",
+"Scale_out to be used for int8 output data."
+"Only used with MKL-DNN.")
+.SetDefault(1.0f);
+AddAttr<float>("Scale_in_eltwise",
+"Scale_in_eltwise to be used for int8 eltwise input data."
+"Only used with MKL-DNN.")
+.SetDefault(1.0f);
+AddAttr<std::vector<float>>("Scale_weights",
+"Scale_weights to be used for int8 weights data."
+"Only used with MKL-DNN.")
+.SetDefault({1.0f});
AddAttr<bool>("force_fp32_output", "(bool, default false) Force INT8 kernel output FP32, only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
......
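
Note: the new conv2d attributes carry per-tensor activation scales (Scale_in, Scale_out, Scale_in_eltwise) and a per-output-channel weight scale vector (Scale_weights). The sketch below shows one common way such scales are produced during calibration, s = 127 / max|x|; this is an assumed convention for illustration, not necessarily the exact rule used on this branch.

    // Illustration of a common calibration convention (s = 127 / max|x|).
    // Function names are hypothetical.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Per-tensor scale for activations, e.g. what conv2d's "Scale_in" would hold.
    float TensorScale(const std::vector<float>& samples) {
      float max_abs = 0.0f;
      for (float v : samples) max_abs = std::max(max_abs, std::fabs(v));
      return max_abs > 0.0f ? 127.0f / max_abs : 1.0f;
    }

    // Per-output-channel scales for weights, which is why "Scale_weights"
    // is declared as a std::vector<float> attribute.
    std::vector<float> ChannelScales(const std::vector<std::vector<float>>& per_channel_weights) {
      std::vector<float> scales;
      scales.reserve(per_channel_weights.size());
      for (const auto& channel : per_channel_weights) scales.push_back(TensorScale(channel));
      return scales;
    }
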
@@ -37,7 +37,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
-auto* scale = ctx.Input<Tensor>("Scale");
+auto scale_data = ctx.Attr<float>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
@@ -45,8 +45,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
float* output_data = output->mutable_data<float>(ctx.GetPlace());
-std::vector<float> scale_data = {*(scale->data<float>())};
-std::vector<float> reorder_scale = {1.0f / scale_data[0]};
+std::vector<float> reorder_scale = {1.0f / scale_data};
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
@@ -99,8 +98,8 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut
void DeQuantOpMaker::Make() {
AddInput("Input","input data");
-AddInput("Scale","scale data");
AddOutput("Output","output data");
+AddAttr<float>("Scale","scale data").SetDefault({1.0f});
AddComment(R"DOC(This op will quantize data from INT8 to FP32)DOC");
}
......
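
Note: the dequantize kernel now reads the scale from the float "Scale" attribute and performs the conversion as an MKL-DNN reorder whose output scale is 1.0f/Scale, so each int8 value q maps back to q/Scale. A plain-C++ equivalent of that arithmetic, for illustration only:

    // Plain-C++ equivalent of what the 1.0f/scale reorder computes (illustrative only).
    #include <cstdint>
    #include <vector>

    std::vector<float> Dequantize(const std::vector<int8_t>& q, float scale) {
      std::vector<float> out(q.size());
      for (size_t i = 0; i < q.size(); ++i)
        out[i] = static_cast<float>(q[i]) / scale;  // same as q[i] * (1.0f / scale)
      return out;
    }
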
@@ -35,7 +35,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
-auto* scale = ctx.Input<Tensor>("Scale");
+auto scale_data = ctx.Attr<float>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
@@ -47,11 +47,9 @@ class QuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
-std::vector<T> scale_data = {*(scale->data<T>())};
mkldnn::primitive_attr attri;
int mask = 0;
-attri.set_output_scales(mask, scale_data);
+attri.set_output_scales(mask, {scale_data});
auto src_md = platform::MKLDNNMemDesc(
{src_tz}, memory::data_type::f32, input->format());
@@ -108,11 +106,12 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio
void QuantOpMaker::Make() {
AddInput("Input","input data");
-AddInput("Scale","scale data");
AddOutput("Output","output data");
AddAttr<bool>("is_negative_input",
"(bool, default false) Only used in mkldnn INT8 kernel")
.SetDefault(false);
+AddAttr<float>("Scale","scale data")
+.SetDefault({1.0f});
AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC");
}
......
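
Note: the quantize kernel applies the "Scale" attribute through primitive_attr::set_output_scales, so each FP32 value is multiplied by Scale and then rounded and saturated to the 8-bit range; the is_negative_input attribute presumably selects between signed and unsigned 8-bit output. A rough scalar equivalent, for illustration only:

    // Rough scalar equivalent of the scaled reorder configured above (illustrative only).
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    int8_t QuantizeSigned(float x, float scale) {
      float q = std::round(x * scale);
      return static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, q)));
    }

    uint8_t QuantizeUnsigned(float x, float scale) {  // e.g. non-negative post-ReLU activations
      float q = std::round(x * scale);
      return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
    }
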