提交 9252e8fa 编写于 作者: S Sylwester Fraczek 提交者: Tao Luo

add int8 mkldnn prior_box (#17242)

add prior_box quantization code

add scale algo rules for prior box

test=develop
上级 5fd68ac1
...@@ -1265,6 +1265,31 @@ PDNode *patterns::ConvConcatReLU::operator()() { ...@@ -1265,6 +1265,31 @@ PDNode *patterns::ConvConcatReLU::operator()() {
return relu_out; return relu_out;
} }
// Describes a single prior_box operator together with its two input variables
// ("Input", "Image") and its two output variables ("Boxes", "Variances").
// Returns the pattern node created for the "Boxes" output.
PDNode *patterns::PriorBox::operator()() {
  auto *op_node =
      pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");

  // Variables consumed by the operator.
  auto *feature_in = pattern->NewNode(prior_box_input_repr())
                         ->AsInput()
                         ->assert_is_op_input("prior_box", "Input");
  auto *image_in = pattern->NewNode(prior_box_image_repr())
                       ->AsInput()
                       ->assert_is_op_input("prior_box", "Image");

  // Variables produced by the operator.
  auto *boxes_out = pattern->NewNode(prior_box_boxes_repr())
                        ->AsOutput()
                        ->assert_is_op_output("prior_box", "Boxes");
  auto *variances_out = pattern->NewNode(prior_box_variances_repr())
                            ->AsOutput()
                            ->assert_is_op_output("prior_box", "Variances");

  // Connect the variables to the operator node.
  op_node->LinksFrom({feature_in, image_in});
  op_node->LinksTo({boxes_out, variances_out});

  return boxes_out;
}
std::unordered_set<std::string> conv_act_set({"identity", "relu"}); std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
......
...@@ -793,6 +793,23 @@ struct ConvConcatReLU : public PatternBase { ...@@ -793,6 +793,23 @@ struct ConvConcatReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out); PATTERN_DECL_NODE(relu_out);
}; };
// Pattern for a single prior_box operator and the variables attached to it.
//   operator node: prior_box_op
//   input nodes:   prior_box_input, prior_box_image
//   output nodes:  prior_box_boxes, prior_box_variances
struct PriorBox : public PatternBase {
  PriorBox(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "PriorBox") {}

  // The operator itself.
  PATTERN_DECL_NODE(prior_box_op);

  // Variables read by the operator.
  PATTERN_DECL_NODE(prior_box_input);
  PATTERN_DECL_NODE(prior_box_image);

  // Variables written by the operator.
  PATTERN_DECL_NODE(prior_box_boxes);
  PATTERN_DECL_NODE(prior_box_variances);

  // Creates the nodes declared above in the pattern given at construction.
  PDNode* operator()();
};
// Conv + ElementwiseAdd + an activation // Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv related ops after the conv+bn fusion. // This pattern can further fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase { struct ConvElementwiseaddAct : public PatternBase {
......
...@@ -306,6 +306,45 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { ...@@ -306,6 +306,45 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
PrettyLogDetail("--- quantized %d concat ops", quantize_concat_count); PrettyLogDetail("--- quantized %d concat ops", quantize_concat_count);
} }
// Quantizes the "Input" variable of every prior_box op matched in `graph`
// whose "use_quantizer" attribute is present and true. The scale and the
// signedness flag are read from the pass attribute "quant_var_scales", keyed
// by the input variable name. The number of handled ops is passed to
// AddStatis and logged.
void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::PriorBox prior_box_pattern{pattern, name_scope_};
  prior_box_pattern();

  int quantize_prior_box_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize prior_box op";
    GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern);
    auto* prior_box_op_desc = prior_box_op->Op();

    // skip if should not be quantized
    if (!prior_box_op_desc->HasAttr("use_quantizer") ||
        !boost::get<bool>(prior_box_op_desc->GetAttr("use_quantizer")))
      return;

    GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input,
                              prior_box_pattern);

    // get scales calculated after warmup, they scale variables to MAX=1.0
    // The map is bound by reference and searched once: copying it for every
    // match and indexing it with operator[] would default-insert an empty
    // entry for an unknown variable instead of reporting the missing scale.
    const auto& scales = Get<VarQuantScale>("quant_var_scales");
    const auto scale_it = scales.find(prior_box_input->Name());
    PADDLE_ENFORCE(scale_it != scales.end(),
                   "Quantization scale for prior_box input %s not found.",
                   prior_box_input->Name());

    const bool is_input_unsigned = scale_it->second.first;
    const double input_scale = scale_it->second.second.data<double>()[0];
    QuantizeInput(g, prior_box_op, prior_box_input, "Input", input_scale,
                  is_input_unsigned);

    ++quantize_prior_box_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_prior_box_count);

  PrettyLogDetail("--- quantized %d prior_box ops",
                  quantize_prior_box_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
...@@ -317,6 +356,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -317,6 +356,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeConv(graph, true /* with_residual_data */); QuantizeConv(graph, true /* with_residual_data */);
QuantizePool(graph); QuantizePool(graph);
QuantizeConcat(graph); QuantizeConcat(graph);
QuantizePriorBox(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -50,6 +50,8 @@ class CPUQuantizePass : public FusePassBase { ...@@ -50,6 +50,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeConcat(Graph* graph) const; void QuantizeConcat(Graph* graph) const;
void QuantizePriorBox(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned, double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const; std::string scale_attr_name = "") const;
......
...@@ -29,6 +29,11 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -29,6 +29,11 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["concat"]["X"] = ScaleAlgo::KL; rules_["concat"]["X"] = ScaleAlgo::KL;
rules_["concat"]["Out"] = ScaleAlgo::KL; rules_["concat"]["Out"] = ScaleAlgo::KL;
rules_["prior_box"]["Input"] = ScaleAlgo::KL;
rules_["prior_box"]["Image"] = ScaleAlgo::NONE;
rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE;
rules_["prior_box"]["Variances"] = ScaleAlgo::NONE;
} }
ScaleAlgo MkldnnQuantizerConfig::scale_algo( ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
...@@ -14,6 +14,10 @@ limitations under the License. */ ...@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/detection/prior_box_op.h" #include "paddle/fluid/operators/detection/prior_box_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -71,8 +75,30 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -71,8 +75,30 @@ class PriorBoxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( auto input_input_type = ctx.Input<framework::Tensor>("Input")->type();
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_image_type == framework::DataTypeTrait<float>::DataType) {
customized_type_value = kPriorBoxFLOAT;
} else if (input_image_type ==
framework::DataTypeTrait<double>::DataType) {
customized_type_value = kPriorBoxDOUBLE;
}
return framework::OpKernelType(input_input_type, ctx.GetPlace(), layout_,
library_, customized_type_value);
}
#endif
return framework::OpKernelType(input_input_type, ctx.GetPlace(), layout_,
library_);
} }
}; };
...@@ -155,6 +181,15 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -155,6 +181,15 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Please note, this order affects the weights order of convolution layer" "Please note, this order affects the weights order of convolution layer"
"followed by and does not affect the final detection results.") "followed by and does not affect the final detection results.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Prior box operator Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
...@@ -176,5 +211,35 @@ namespace ops = paddle::operators; ...@@ -176,5 +211,35 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker, REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel<float>, REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel<float, float>,
ops::PriorBoxOpKernel<double>); ops::PriorBoxOpKernel<double, double>);
// prior_box kernels registered for the MKLDNN library on CPUPlace. Each
// registration pairs one PriorBoxOpKernel<T, K> instantiation with a
// customized type value (ops::kPriorBoxFLOAT or ops::kPriorBoxDOUBLE).
// NOTE(review): the fourth macro argument (FF, DD, U8F, ...) looks like a
// unique tag naming the <T, K> pair - confirm against the macro definition.

// T = float, K = float
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN,
::paddle::platform::CPUPlace, FF,
ops::kPriorBoxFLOAT,
ops::PriorBoxOpKernel<float, float>);
// T = double, K = double
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN,
::paddle::platform::CPUPlace, DD,
ops::kPriorBoxDOUBLE,
ops::PriorBoxOpKernel<double, double>);
// T = uint8_t, K = float
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN,
::paddle::platform::CPUPlace, U8F,
ops::kPriorBoxFLOAT,
ops::PriorBoxOpKernel<uint8_t, float>);
// T = int8_t, K = float
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN,
::paddle::platform::CPUPlace, S8F,
ops::kPriorBoxFLOAT,
ops::PriorBoxOpKernel<int8_t, float>);
// T = uint8_t, K = double
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN,
::paddle::platform::CPUPlace, U8D,
ops::kPriorBoxDOUBLE,
ops::PriorBoxOpKernel<uint8_t, double>);
// T = int8_t, K = double
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, MKLDNN,
::paddle::platform::CPUPlace, S8D,
ops::kPriorBoxDOUBLE,
ops::PriorBoxOpKernel<int8_t, double>);
...@@ -22,6 +22,9 @@ limitations under the License. */ ...@@ -22,6 +22,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Customized kernel type values for prior_box. PriorBoxOp selects
// kPriorBoxFLOAT when the "Image" input is float and kPriorBoxDOUBLE when it
// is double, and the MKLDNN kernel registrations use the same values.
constexpr int kPriorBoxFLOAT = 1;
constexpr int kPriorBoxDOUBLE = 2;
inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior, inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
bool flip, bool flip,
std::vector<float>* output_aspect_ratior) { std::vector<float>* output_aspect_ratior) {
...@@ -46,7 +49,7 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior, ...@@ -46,7 +49,7 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
} }
} }
template <typename T> template <typename T, typename K>
class PriorBoxOpKernel : public framework::OpKernel<T> { class PriorBoxOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -67,9 +70,9 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -67,9 +70,9 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
std::vector<float> aspect_ratios; std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
T step_w = static_cast<T>(ctx.Attr<float>("step_w")); K step_w = static_cast<K>(ctx.Attr<float>("step_w"));
T step_h = static_cast<T>(ctx.Attr<float>("step_h")); K step_h = static_cast<K>(ctx.Attr<float>("step_h"));
T offset = static_cast<T>(ctx.Attr<float>("offset")); K offset = static_cast<K>(ctx.Attr<float>("offset"));
auto img_width = image->dims()[3]; auto img_width = image->dims()[3];
auto img_height = image->dims()[2]; auto img_height = image->dims()[2];
...@@ -77,10 +80,10 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -77,10 +80,10 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
auto feature_width = input->dims()[3]; auto feature_width = input->dims()[3];
auto feature_height = input->dims()[2]; auto feature_height = input->dims()[2];
T step_width, step_height; K step_width, step_height;
if (step_w == 0 || step_h == 0) { if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(img_width) / feature_width; step_width = static_cast<K>(img_width) / feature_width;
step_height = static_cast<T>(img_height) / feature_height; step_height = static_cast<K>(img_height) / feature_height;
} else { } else {
step_width = step_w; step_width = step_w;
step_height = step_h; step_height = step_h;
...@@ -91,15 +94,15 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -91,15 +94,15 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
num_priors += max_sizes.size(); num_priors += max_sizes.size();
} }
boxes->mutable_data<T>(ctx.GetPlace()); boxes->mutable_data<K>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace()); vars->mutable_data<K>(ctx.GetPlace());
T* b_t = boxes->data<T>(); K* b_t = boxes->data<K>();
for (int h = 0; h < feature_height; ++h) { for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) { for (int w = 0; w < feature_width; ++w) {
T center_x = (w + offset) * step_width; K center_x = (w + offset) * step_width;
T center_y = (h + offset) * step_height; K center_y = (h + offset) * step_height;
T box_width, box_height; K box_width, box_height;
for (size_t s = 0; s < min_sizes.size(); ++s) { for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s]; auto min_size = min_sizes[s];
if (min_max_aspect_ratios_order) { if (min_max_aspect_ratios_order) {
...@@ -161,17 +164,17 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -161,17 +164,17 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
} }
if (clip) { if (clip) {
T* dt = boxes->data<T>(); K* dt = boxes->data<K>();
std::transform(dt, dt + boxes->numel(), dt, [](T v) -> T { std::transform(dt, dt + boxes->numel(), dt, [](K v) -> K {
return std::min<T>(std::max<T>(v, 0.), 1.); return std::min<K>(std::max<K>(v, 0.), 1.);
}); });
} }
framework::Tensor var_t; framework::Tensor var_t;
var_t.mutable_data<T>( var_t.mutable_data<K>(
framework::make_ddim({1, static_cast<int>(variances.size())}), framework::make_ddim({1, static_cast<int>(variances.size())}),
ctx.GetPlace()); ctx.GetPlace());
auto var_et = framework::EigenTensor<T, 2>::From(var_t); auto var_et = framework::EigenTensor<K, 2>::From(var_t);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
...@@ -184,7 +187,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -184,7 +187,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
auto var_dim = vars->dims(); auto var_dim = vars->dims();
vars->Resize({box_num, static_cast<int>(variances.size())}); vars->Resize({box_num, static_cast<int>(variances.size())});
auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars); auto e_vars = framework::EigenMatrix<K, Eigen::RowMajor>::From(*vars);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册