未验证 提交 a13490a0 编写于 作者: M Michał Gallus 提交者: GitHub

[DNNL] Fix accuracy in INT8 FC (#22404) (#22410)

test=release/1.7

* Enable quantize to reorder to nchw as well

* Correct FC MKL-DNN input dim requirements to accept 3D

* Improve DNNL FC format, error and 3D input handling

* Improve error checking in FC

* Improve PADDLE_ENFORCE messages in fc-related files

* Remove data layout attribute from obligatory pass args

* Fix message in fc_mkldnn_pass to be logically correct
上级 dcdd18ae
...@@ -66,6 +66,9 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, ...@@ -66,6 +66,9 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
std::vector<std::string>({quantize_out_node->Name()})); std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", scale); q_desc.SetAttr("Scale", scale);
q_desc.SetAttr("is_negative_input", !is_unsigned); q_desc.SetAttr("is_negative_input", !is_unsigned);
q_desc.SetAttr("output_format",
Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
// update op's input // update op's input
......
...@@ -56,14 +56,14 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { ...@@ -56,14 +56,14 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
OpDesc* desc = fc->Op(); OpDesc* desc = fc->Op();
auto dims = fc->inputs[0]->Var()->GetShape(); auto dims = fc->inputs[0]->Var()->GetShape();
auto dim_num = dims.size(); auto dim_num = dims.size();
bool are_dims_supported = dim_num == 2 || dim_num == 4; bool are_dims_supported = dim_num >= 2 && dim_num <= 4;
constexpr size_t height_axis = 2; constexpr size_t height_axis = 2;
constexpr size_t width_axis = 3; constexpr size_t width_axis = 3;
bool is_size_supported = bool is_size_supported =
dim_num == 4 ? (dims[width_axis] == 1 && dims[height_axis] == 1) : true; dim_num == 4 ? (dims[width_axis] == 1 && dims[height_axis] == 1) : true;
if (!are_dims_supported || !is_size_supported) { if (!are_dims_supported || !is_size_supported) {
VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4"; VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than"
VLOG(3) << "Or when width and height are different than one"; "2, 3 & 4, or when width or height is different than one.";
return; return;
} }
desc->SetAttr("use_mkldnn", true); desc->SetAttr("use_mkldnn", true);
......
...@@ -69,11 +69,13 @@ class FCOp : public framework::OperatorWithKernel { ...@@ -69,11 +69,13 @@ class FCOp : public framework::OperatorWithKernel {
activation_type.c_str()); activation_type.c_str());
} }
if (ctx->Attrs().Get<bool>("use_mkldnn")) { if (ctx->Attrs().Get<bool>("use_mkldnn")) {
PADDLE_ENFORCE_EQ(in_dims.size() == 2 || in_dims.size() == 4, true, PADDLE_ENFORCE_EQ(
"Fully Connected input should be 2-D or 4-D tensor."); in_dims.size() >= 2 && in_dims.size() <= 4, true,
platform::errors::Unimplemented(
"Fully Connected input should be 2D, 3D or 4D tensor."));
} }
PADDLE_ENFORCE_EQ(w_dims.size(), 2, PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"Fully Connected input should be 2-D tensor."); "Fully Connected weights should be 2-D tensor.");
int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims"); int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
in_dims.size(), in_num_col_dims, in_dims.size(), in_num_col_dims,
......
...@@ -54,6 +54,25 @@ class FCPrimitiveFactory { ...@@ -54,6 +54,25 @@ class FCPrimitiveFactory {
return; return;
} // Otherwise, create a new one. } // Otherwise, create a new one.
auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
PADDLE_ENFORCE_LE(in_col_dims, 2,
platform::errors::Unimplemented(
"DNNL FC doesn't support in_num_col_dims paramter to "
"be higher than "
"2."));
if (in_col_dims == 2) {
PADDLE_ENFORCE_EQ(
input->dims().size(), 3,
platform::errors::Unimplemented(
"DNNL FC only supports in_num_col_dims equal to 2 when "
"3 dim input is provided."));
PADDLE_ENFORCE_EQ(
input->format(), MKLDNNMemoryFormat::ncw,
platform::errors::Unimplemented(
"DNNL FC only supports in_num_col_dims equal to 2 when "
"input format is equal to ncw."));
}
// Transform weights to default MKL-DNN format // Transform weights to default MKL-DNN format
weights_ = TransposeWeights(weights); weights_ = TransposeWeights(weights);
// Since MKL-DNN has a lot of limitations on what the input/weights/output // Since MKL-DNN has a lot of limitations on what the input/weights/output
...@@ -121,6 +140,33 @@ class FCPrimitiveFactory { ...@@ -121,6 +140,33 @@ class FCPrimitiveFactory {
} }
private: private:
// DNNL always returns 2-dimensional data block as a result of computing
// inner product. Hence the format 'nc' is always set for its output
// primitive. Therefore, function SetOutputFormat is needed to choose
// an appropriate format based on the number of input dimensions and
// format of an input tensor.
void SetOutputFormat(MKLDNNMemoryFormat in_format, Tensor* out) {
int dim_num = out->dims().size();
// In case of 2 dims, we set the only possible format, nc
if (dim_num == 2) {
out->set_format(MKLDNNMemoryFormat::nc);
// In case of 3 dims, we generate a format that is based on number
// of output dims and the layout of input format (nchw or nhwc).
} else if (dim_num == 3) {
if (in_format == MKLDNNMemoryFormat::nwc ||
in_format == MKLDNNMemoryFormat::nhwc) {
out->set_format(
platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nhwc));
} else {
out->set_format(
platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nchw));
}
// In any other case we overwrite the output format with the input one.
} else {
out->set_format(in_format);
}
}
void UpdateDataPointers(const ExecutionContext& ctx, Tensor* out, void UpdateDataPointers(const ExecutionContext& ctx, Tensor* out,
const Tensor* in) { const Tensor* in) {
input_->set_data_handle(to_void_cast(in->data<T_in>())); input_->set_data_handle(to_void_cast(in->data<T_in>()));
...@@ -129,17 +175,7 @@ class FCPrimitiveFactory { ...@@ -129,17 +175,7 @@ class FCPrimitiveFactory {
// variable, update its format to what has been determined in first // variable, update its format to what has been determined in first
// call to CreateFcPrimitive method. // call to CreateFcPrimitive method.
if (out->format() == MKLDNNMemoryFormat::undef) { if (out->format() == MKLDNNMemoryFormat::undef) {
MKLDNNMemoryFormat format; SetOutputFormat(in->format(), out);
auto data_type = input_->get_desc().data.data_type;
if (data_type == mkldnn_f32)
format = MKLDNNMemoryFormat::nchw;
else
format = MKLDNNMemoryFormat::nhwc;
MKLDNNMemoryFormat selected = platform::MKLDNNFormatForSize(
framework::vectorize<int>(out->dims()).size(), format);
out->set_format(selected);
} }
} }
...@@ -168,8 +204,8 @@ class FCPrimitiveFactory { ...@@ -168,8 +204,8 @@ class FCPrimitiveFactory {
const LoDTensor* input, const Tensor* weights, const Tensor* bias, const LoDTensor* input, const Tensor* weights, const Tensor* bias,
LoDTensor* output, const ExecutionContext& ctx) { LoDTensor* output, const ExecutionContext& ctx) {
auto input_dims = framework::vectorize(input->dims()); auto input_dims = framework::vectorize(input->dims());
std::vector<int64_t> new_input_dims = {input_dims[0] * input_dims[1], 1, std::vector<int64_t> new_input_dims = {input_dims[0] * input_dims[1],
input_dims[2]}; input_dims[2], 1};
auto src_desc = CreateMemDescriptor<T_in>(new_input_dims, input->format()); auto src_desc = CreateMemDescriptor<T_in>(new_input_dims, input->format());
auto weight_dims = Get3DWeightDimsForDNNL(weights); auto weight_dims = Get3DWeightDimsForDNNL(weights);
...@@ -187,7 +223,7 @@ class FCPrimitiveFactory { ...@@ -187,7 +223,7 @@ class FCPrimitiveFactory {
std::vector<int64_t> Get3DWeightDimsForDNNL(const Tensor* weights) { std::vector<int64_t> Get3DWeightDimsForDNNL(const Tensor* weights) {
auto paddle_w_dims = framework::vectorize(weights->dims()); auto paddle_w_dims = framework::vectorize(weights->dims());
return {paddle_w_dims[1], 1, paddle_w_dims[0]}; return {paddle_w_dims[1], paddle_w_dims[0], 1};
} }
memory::desc Create3DUserWeightsDesc(const Tensor* weights) { memory::desc Create3DUserWeightsDesc(const Tensor* weights) {
...@@ -405,18 +441,8 @@ class FCPrimitiveFactory { ...@@ -405,18 +441,8 @@ class FCPrimitiveFactory {
T_out* output_data = T_out* output_data =
output->mutable_data<T_out>(ctx.GetPlace(), buffer_size); output->mutable_data<T_out>(ctx.GetPlace(), buffer_size);
memory dst_mem(dst_desc, engine_, to_void_cast<T_out>(output_data)); memory dst_mem(dst_desc, engine_, to_void_cast<T_out>(output_data));
SetOutputFormat(ctx.Input<LoDTensor>("Input")->format(), output);
MKLDNNMemoryFormat format;
auto data_type = input_->get_desc().data.data_type;
if (data_type == mkldnn_f32)
format = MKLDNNMemoryFormat::nchw;
else
format = MKLDNNMemoryFormat::nhwc;
MKLDNNMemoryFormat selected = platform::MKLDNNFormatForSize(
framework::vectorize<int>(output->dims()).size(), format);
output->set_format(selected);
return dst_mem; return dst_mem;
} }
......
...@@ -60,6 +60,9 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -60,6 +60,9 @@ class QuantOpKernel : public framework::OpKernel<T> {
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim)); reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
if (reorder_p == nullptr) { if (reorder_p == nullptr) {
std::string out_layout = ctx.Attr<std::string>("output_format");
MKLDNNMemoryFormat out_format =
platform::data_format_to_memory_format(out_layout);
mkldnn::primitive_attr attri; mkldnn::primitive_attr attri;
int mask = 0; int mask = 0;
attri.set_output_scales(mask, {scale_data}); attri.set_output_scales(mask, {scale_data});
...@@ -72,10 +75,10 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -72,10 +75,10 @@ class QuantOpKernel : public framework::OpKernel<T> {
std::shared_ptr<mkldnn::memory::desc> dst_md; std::shared_ptr<mkldnn::memory::desc> dst_md;
if (is_negative) { if (is_negative) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine, platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory); dst_md, dst_memory, out_format);
} else { } else {
platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine, platform::SetDstMemoryQuantized<uint8_t>(
dst_md, dst_memory); ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
} }
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri)); new reorder::primitive_desc(*src_memory, *dst_memory, attri));
......
...@@ -37,6 +37,9 @@ void QuantOpMaker::Make() { ...@@ -37,6 +37,9 @@ void QuantOpMaker::Make() {
"(bool, default false) Only used in mkldnn INT8 kernel") "(bool, default false) Only used in mkldnn INT8 kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<float>("Scale", "scale data").SetDefault({1.0f}); AddAttr<float>("Scale", "scale data").SetDefault({1.0f});
AddAttr<std::string>("output_format",
"Convert format to NHWC or NCHW during quantization.")
.SetDefault("NHWC");
AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC"); AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC");
} }
......
...@@ -1143,13 +1143,14 @@ static void SetDstMemoryQuantized( ...@@ -1143,13 +1143,14 @@ static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
std::vector<int64_t> dst_tz, const mkldnn::engine& engine, std::vector<int64_t> dst_tz, const mkldnn::engine& engine,
std::shared_ptr<mkldnn::memory::desc>& dst_md, // NOLINT std::shared_ptr<mkldnn::memory::desc>& dst_md, // NOLINT
std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT std::shared_ptr<mkldnn::memory>& dst_memory, // NOLINT
MKLDNNMemoryFormat output_format) {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size(); const size_t dst_dims = dst_tz.size();
MKLDNNMemoryFormat dst_fmt; MKLDNNMemoryFormat dst_fmt;
PADDLE_ENFORCE_LE(dst_dims, 5, PADDLE_ENFORCE_LE(dst_dims, 5,
"Dst memory for quantization can not have dims > 5"); "Dst memory for quantization can not have dims > 5");
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, MKLDNNMemoryFormat::nhwc); dst_fmt = platform::MKLDNNFormatForSize(dst_dims, output_format);
auto tmp_dst_md = platform::MKLDNNMemDesc( auto tmp_dst_md = platform::MKLDNNMemDesc(
{dst_tz}, paddle::framework::ToMKLDNNDataType( {dst_tz}, paddle::framework::ToMKLDNNDataType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册