From 9252e8fa0886ee56bf5ce2ef6506aa309db58cb5 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Thu, 27 Jun 2019 04:12:49 +0200 Subject: [PATCH] add int8 mkldnn prior_box (#17242) add prior_box quantization code add scale algo rules for prior box test=develop --- .../framework/ir/graph_pattern_detector.cc | 25 +++++++ .../framework/ir/graph_pattern_detector.h | 17 +++++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 40 ++++++++++ .../framework/ir/mkldnn/cpu_quantize_pass.h | 2 + .../inference/api/mkldnn_quantizer_config.cc | 5 ++ .../fluid/operators/detection/prior_box_op.cc | 73 ++++++++++++++++++- .../fluid/operators/detection/prior_box_op.h | 41 ++++++----- 7 files changed, 180 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 15b3429ef17..1ed07efc883 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1265,6 +1265,31 @@ PDNode *patterns::ConvConcatReLU::operator()() { return relu_out; } +PDNode *patterns::PriorBox::operator()() { + auto prior_box_op = + pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); + + auto input_var = pattern->NewNode(prior_box_input_repr()) + ->AsInput() + ->assert_is_op_input("prior_box", "Input"); + + auto image_var = pattern->NewNode(prior_box_image_repr()) + ->AsInput() + ->assert_is_op_input("prior_box", "Image"); + + auto boxes_var = pattern->NewNode(prior_box_boxes_repr()) + ->AsOutput() + ->assert_is_op_output("prior_box", "Boxes"); + + auto variances_var = pattern->NewNode(prior_box_variances_repr()) + ->AsOutput() + ->assert_is_op_output("prior_box", "Variances"); + + prior_box_op->LinksFrom({input_var, image_var}) + .LinksTo({boxes_var, variances_var}); + return boxes_var; +} + std::unordered_set conv_act_set({"identity", "relu"}); PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1c53b910522..c9f12ddf4a0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -793,6 +793,23 @@ struct ConvConcatReLU : public PatternBase { PATTERN_DECL_NODE(relu_out); }; +// PriorBox operator +// operator: prior_box_op +// inputs: prior_box_input, prior_box_image +// outputs: prior_box_boxes, prior_box_variances +struct PriorBox : public PatternBase { + PriorBox(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "PriorBox") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(prior_box_op); + PATTERN_DECL_NODE(prior_box_input); + PATTERN_DECL_NODE(prior_box_image); + PATTERN_DECL_NODE(prior_box_boxes); + PATTERN_DECL_NODE(prior_box_variances); +}; + // Conv + ElementwiseAdd + an activation // This pattern can futher fuse the conv related ops after the conv+bn fusion. struct ConvElementwiseaddAct : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index dd3ee50e040..c7eb0587014 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -306,6 +306,45 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { PrettyLogDetail("--- quantized %d concat ops", quantize_concat_count); } +void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + patterns::PriorBox prior_box_pattern{pattern, name_scope_}; + prior_box_pattern(); + + int quantize_prior_box_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize prior_box op"; + GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern); + auto* prior_box_op_desc = prior_box_op->Op(); + + // skip if should not be quantized + if (!prior_box_op_desc->HasAttr("use_quantizer") || + !boost::get(prior_box_op_desc->GetAttr("use_quantizer"))) + return; + + GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input, + prior_box_pattern); + + // get scales calculated after warmup, they scale variables to MAX=1.0 + auto scales = Get("quant_var_scales"); + + auto input_scale = scales[prior_box_input->Name()].second.data()[0]; + bool is_input_unsigned = scales[prior_box_input->Name()].first; + QuantizeInput(g, prior_box_op, prior_box_input, "Input", input_scale, + is_input_unsigned); + + ++quantize_prior_box_count; + }; + + gpd(graph, handler); + AddStatis(quantize_prior_box_count); + + PrettyLogDetail("--- quantized %d prior_box ops", + quantize_prior_box_count); +} + void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; PADDLE_ENFORCE(graph); @@ -317,6 +356,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizeConv(graph, true /* with_residual_data */); QuantizePool(graph); QuantizeConcat(graph); + QuantizePriorBox(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index 61a28fd3131..ec4db66240c 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -50,6 +50,8 @@ class CPUQuantizePass : public FusePassBase { void QuantizeConcat(Graph* graph) const; + void QuantizePriorBox(Graph* graph) const; + void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, double scale_to_one, bool is_unsigned, std::string scale_attr_name = "") const; diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc index a7cb785fe95..c2b2ba0b60a 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc @@ -29,6 +29,11 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { rules_["concat"]["X"] = ScaleAlgo::KL; rules_["concat"]["Out"] = ScaleAlgo::KL; + + rules_["prior_box"]["Input"] = ScaleAlgo::KL; + rules_["prior_box"]["Image"] = ScaleAlgo::NONE; + rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE; + rules_["prior_box"]["Variances"] = ScaleAlgo::NONE; } ScaleAlgo MkldnnQuantizerConfig::scale_algo( diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 3e75c0394f9..34ddacb6f54 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -14,6 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/detection/prior_box_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + namespace paddle { namespace operators { @@ -71,8 +75,30 @@ class PriorBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + auto input_input_type = ctx.Input("Input")->type(); + + framework::LibraryType library_{framework::LibraryType::kPlain}; + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + auto input_image_type = ctx.Input("Image")->type(); + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + if (input_image_type == framework::DataTypeTrait::DataType) { + customized_type_value = kPriorBoxFLOAT; + } else if (input_image_type == + framework::DataTypeTrait::DataType) { + customized_type_value = kPriorBoxDOUBLE; + } + return framework::OpKernelType(input_input_type, ctx.GetPlace(), layout_, + library_, customized_type_value); + } +#endif + return framework::OpKernelType(input_input_type, ctx.GetPlace(), layout_, + library_); } }; @@ -155,6 +181,15 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "Please note, this order affects the weights order of convolution layer" "followed by and does not affect the final detection results.") .SetDefault(false); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. " + "Only used on CPU.") + .SetDefault(false); AddComment(R"DOC( Prior box operator Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. @@ -176,5 +211,35 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel, - ops::PriorBoxOpKernel); +REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN, + ::paddle::platform::CPUPlace, FF, + ops::kPriorBoxFLOAT, + ops::PriorBoxOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN, + ::paddle::platform::CPUPlace, DD, + ops::kPriorBoxDOUBLE, + ops::PriorBoxOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN, + ::paddle::platform::CPUPlace, U8F, + ops::kPriorBoxFLOAT, + ops::PriorBoxOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN, + ::paddle::platform::CPUPlace, S8F, + ops::kPriorBoxFLOAT, + ops::PriorBoxOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN, + ::paddle::platform::CPUPlace, U8D, + ops::kPriorBoxDOUBLE, + ops::PriorBoxOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN, + ::paddle::platform::CPUPlace, S8D, + ops::kPriorBoxDOUBLE, + ops::PriorBoxOpKernel); diff --git a/paddle/fluid/operators/detection/prior_box_op.h b/paddle/fluid/operators/detection/prior_box_op.h index d3e26256b50..71c67b44eaf 100644 --- a/paddle/fluid/operators/detection/prior_box_op.h +++ b/paddle/fluid/operators/detection/prior_box_op.h @@ -22,6 +22,9 @@ limitations under the License. */ namespace paddle { namespace operators { +constexpr int kPriorBoxFLOAT = 1; +constexpr int kPriorBoxDOUBLE = 2; + inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, bool flip, std::vector* output_aspect_ratior) { @@ -46,7 +49,7 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, } } -template +template class PriorBoxOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -67,9 +70,9 @@ class PriorBoxOpKernel : public framework::OpKernel { std::vector aspect_ratios; ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); - T step_w = static_cast(ctx.Attr("step_w")); - T step_h = static_cast(ctx.Attr("step_h")); - T offset = static_cast(ctx.Attr("offset")); + K step_w = static_cast(ctx.Attr("step_w")); + K step_h = static_cast(ctx.Attr("step_h")); + K offset = static_cast(ctx.Attr("offset")); auto img_width = image->dims()[3]; auto img_height = image->dims()[2]; @@ -77,10 +80,10 @@ class PriorBoxOpKernel : public framework::OpKernel { auto feature_width = input->dims()[3]; auto feature_height = input->dims()[2]; - T step_width, step_height; + K step_width, step_height; if (step_w == 0 || step_h == 0) { - step_width = static_cast(img_width) / feature_width; - step_height = static_cast(img_height) / feature_height; + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; } else { step_width = step_w; step_height = step_h; @@ -91,15 +94,15 @@ class PriorBoxOpKernel : public framework::OpKernel { num_priors += max_sizes.size(); } - boxes->mutable_data(ctx.GetPlace()); - vars->mutable_data(ctx.GetPlace()); + boxes->mutable_data(ctx.GetPlace()); + vars->mutable_data(ctx.GetPlace()); - T* b_t = boxes->data(); + K* b_t = boxes->data(); for (int h = 0; h < feature_height; ++h) { for (int w = 0; w < feature_width; ++w) { - T center_x = (w + offset) * step_width; - T center_y = (h + offset) * step_height; - T box_width, box_height; + K center_x = (w + offset) * step_width; + K center_y = (h + offset) * step_height; + K box_width, box_height; for (size_t s = 0; s < min_sizes.size(); ++s) { auto min_size = min_sizes[s]; if (min_max_aspect_ratios_order) { @@ -161,17 +164,17 @@ class PriorBoxOpKernel : public framework::OpKernel { } if (clip) { - T* dt = boxes->data(); - std::transform(dt, dt + boxes->numel(), dt, [](T v) -> T { - return std::min(std::max(v, 0.), 1.); + K* dt = boxes->data(); + std::transform(dt, dt + boxes->numel(), dt, [](K v) -> K { + return std::min(std::max(v, 0.), 1.); }); } framework::Tensor var_t; - var_t.mutable_data( + var_t.mutable_data( framework::make_ddim({1, static_cast(variances.size())}), ctx.GetPlace()); - auto var_et = framework::EigenTensor::From(var_t); + auto var_et = framework::EigenTensor::From(var_t); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for @@ -184,7 +187,7 @@ class PriorBoxOpKernel : public framework::OpKernel { auto var_dim = vars->dims(); vars->Resize({box_num, static_cast(variances.size())}); - auto e_vars = framework::EigenMatrix::From(*vars); + auto e_vars = framework::EigenMatrix::From(*vars); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(2) -- GitLab