From 61403f879038c0f59a1ddecee191dc6850e02106 Mon Sep 17 00:00:00 2001 From: wozna Date: Mon, 26 Aug 2019 09:51:09 +0000 Subject: [PATCH] Add transpose2 INT8 for mkl-dnn test=develop --- cmake/operators.cmake | 5 +- .../framework/ir/graph_pattern_detector.cc | 21 +++ .../framework/ir/graph_pattern_detector.h | 15 ++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 60 +++++++ .../framework/ir/mkldnn/cpu_quantize_pass.h | 2 + .../ir/mkldnn/cpu_quantize_pass_tester.cc | 152 ++++++++++++------ .../inference/api/mkldnn_quantizer_config.cc | 3 + .../analyzer_int8_object_detection_tester.cc | 2 +- .../operators/mkldnn/transpose_mkldnn_op.cc | 45 ++++-- paddle/fluid/operators/transpose_op.cc | 31 +++- paddle/fluid/operators/transpose_op.h | 2 + paddle/fluid/platform/mkldnn_reuse.h | 30 ++-- 12 files changed, 297 insertions(+), 71 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 134c894392..e5b11dfed5 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -174,7 +174,10 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n") - + elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n") else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") endif() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index c54e805e26..66a0ce2555 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1155,6 +1155,27 @@ PDNode *patterns::Conv::operator()() { return output_var; } +PDNode *patterns::Transpose::operator()() { + auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); + + auto transpose_op = + pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + + auto transpose_in = pattern->NewNode(transpose_in_repr()) + ->AsInput() + ->assert_is_op_input("transpose2"); + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->AsOutput() + ->assert_is_op_output("transpose2", "Out"); + + auto next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + prev_op->LinksTo({transpose_in}); + transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out}); + next_op->LinksFrom({transpose_out}); + return transpose_out; +} + PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index c53e4e5e25..d33d0da3db 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -750,6 +750,21 @@ struct ElementwiseAdd : public PatternBase { PATTERN_DECL_NODE(elementwise_add_out); }; +// Transpose op +// Forward pass for transpose. +// transpose_out is a result of the operator. +struct Transpose : public PatternBase { + Transpose(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "transpose2") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(prev_op); + PATTERN_DECL_NODE(transpose_in); + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(next_op); +}; + // Concat op // Forward pass for concat. // concat_out is a result of the operator. diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 89f51bfa2a..47430379ff 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -343,6 +343,65 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { quantize_prior_box_count); } +void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + patterns::Transpose transpose_pattern{pattern, name_scope_}; + transpose_pattern(); + + int quantize_transpose_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize transpose op"; + GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, transpose_pattern); + auto* transpose_op_desc = transpose_op->Op(); + + if (!transpose_op_desc->HasAttr("use_quantizer")) { + return; + } + // skip if should not be quantized + if (!transpose_op_desc->HasAttr("use_quantizer") || + !boost::get(transpose_op_desc->GetAttr("use_quantizer"))) { + return; + } + GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, transpose_pattern); + + // skip if prev op is not quantized + // in future we should checked if next_op is quantized + // transpose INT8 schould be used only between INT8 operators + if (!(prev_op->Op()->Type() == "dequantize" || + (prev_op->Op()->HasAttr("use_quantizer") && + boost::get(prev_op->Op()->GetAttr("use_quantizer"))))) { + return; + } + + GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern); + + // get scales calculated after warmup, they scale variables to MAX=1.0 + auto scales = Get("quant_var_scales"); + + auto input_scale = scales[transpose_in->Name()].second.data()[0]; + bool is_input_unsigned = scales[transpose_in->Name()].first; + QuantizeInput(g, transpose_op, transpose_in, "X", input_scale, + is_input_unsigned); + + auto output_scale = scales[transpose_out->Name()].second.data()[0]; + bool is_output_unsigned = scales[transpose_out->Name()].first; + DequantizeOutput(g, transpose_op, transpose_out, "Out", output_scale, + is_output_unsigned); + + ++quantize_transpose_count; + }; + + gpd(graph, handler); + AddStatis(quantize_transpose_count); + + PrettyLogDetail("--- quantized %d transpose ops", + quantize_transpose_count); +} + void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; PADDLE_ENFORCE(graph); @@ -355,6 +414,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizePool(graph); QuantizeConcat(graph); QuantizePriorBox(graph); + QuantizeTranspose(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index ec4db66240..d1b23227b6 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -52,6 +52,8 @@ class CPUQuantizePass : public FusePassBase { void QuantizePriorBox(Graph* graph) const; + void QuantizeTranspose(Graph* graph) const; + void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, double scale_to_one, bool is_unsigned, std::string scale_attr_name = "") const; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 0a68944186..13efaf52d4 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -48,7 +48,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetAttr("Scale_in", 1.0f); op->SetAttr("Scale_out", 1.0f); op->SetAttr("Scale_weights", std::vector{1.0f}); - } else if (type == "pool2d") { + } else if (type == "pool2d" || type == "transpose2") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); op->SetAttr("use_quantizer", use_quantizer); @@ -113,19 +113,14 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, tensor->mutable_data(place, proto::VarType::FP32, 1); } -void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, - int quant_count, int dequant_count, int added_nodes_count, - float scale) { - std::unique_ptr graph(new ir::Graph(prog)); - - // Init scope, as it is used in pass +void PreparePass(std::unique_ptr* graph, const ProgramDesc& prog, + const std::initializer_list variable_names, + int* original_nodes_num, int* current_nodes_num) { auto place = paddle::platform::CPUPlace(); NaiveExecutor exe{place}; Scope scope; exe.CreateVariables(prog, 0, true, &scope); - auto* scales = new VarQuantScale(); - for (auto& v : variable_names) { InitTensorHolder(&scope, place, v.c_str()); LoDTensor tensor; @@ -136,16 +131,23 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, (*scales)[v] = std::make_pair(false, std::move(tensor)); } - graph->SetNotOwned(kParamScopeAttr, &scope); - - auto pass = PassRegistry::Instance().Get("cpu_quantize_pass"); + (*graph)->SetNotOwned(kParamScopeAttr, &scope); + std::unique_ptr pass = + PassRegistry::Instance().Get("cpu_quantize_pass"); pass->Set("quant_var_scales", scales); - int original_nodes_num = graph->Nodes().size(); - - graph.reset(pass->Apply(graph.release())); + *original_nodes_num = (*graph)->Nodes().size(); + (*graph).reset(pass->Apply((*graph).release())); + *current_nodes_num = (*graph)->Nodes().size(); +} - int current_nodes_num = graph->Nodes().size(); +void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, + int quant_count, int dequant_count, int added_nodes_count, + float scale) { + std::unique_ptr graph(new ir::Graph(prog)); + int original_nodes_num, current_nodes_num; + PreparePass(&graph, prog, variable_names, &original_nodes_num, + ¤t_nodes_num); int quantize_nodes_count = 0; int dequantize_nodes_count = 0; @@ -232,35 +234,9 @@ ProgramDesc BuildProgramDescConcat() { void MainTestConcat(const ProgramDesc& prog, int pool_count, int concat_count, int quant_count, int dequant_count, int added_nodes_count) { std::unique_ptr graph(new ir::Graph(prog)); - - // Init scope, as it is used in pass - auto place = paddle::platform::CPUPlace(); - NaiveExecutor exe{place}; - Scope scope; - exe.CreateVariables(prog, 0, true, &scope); - - auto* scales = new VarQuantScale(); - - for (auto& v : variable_names_concat) { - InitTensorHolder(&scope, place, v.c_str()); - LoDTensor tensor; - tensor.Resize({1}); - auto* ptr = tensor.mutable_data(place); - ptr[0] = 2.0; - - (*scales)[v] = std::make_pair(false, std::move(tensor)); - } - - graph->SetNotOwned(kParamScopeAttr, &scope); - - auto pass = PassRegistry::Instance().Get("cpu_quantize_pass"); - pass->Set("quant_var_scales", scales); - - int original_nodes_num = graph->Nodes().size(); - - graph.reset(pass->Apply(graph.release())); - - int current_nodes_num = graph->Nodes().size(); + int original_nodes_num, current_nodes_num; + PreparePass(&graph, prog, variable_names_concat, &original_nodes_num, + ¤t_nodes_num); int quantize_nodes_count = 0; int dequantize_nodes_count = 0; @@ -300,9 +276,93 @@ TEST(CpuQuantizePass, concat) { MainTestConcat(BuildProgramDescConcat(), pool_count, concat_count, quant_count, dequant_count, added_nodes_count); } - } // namespace +namespace { +static const std::initializer_list variable_names_transpose = { + "a", "w1", "b", "c", "w2", "d", "e", "f"}; + +// a->Conv1->b +// b->Transpose1->c +// c->Conv2->d +// d->Transpose2->e +// e->Dropout->f +ProgramDesc BuildProgramDescTranspose() { + ProgramDesc prog; + for (auto& v : variable_names_transpose) { + auto* var = prog.MutableBlock(0)->Var(v); + if (v.find("w") == 0) { + var->SetPersistable(true); + } + } + + SetOp(&prog, "conv2d", "Conv1", {"a", "w1"}, {"b"}, true, true); + SetOp(&prog, "transpose2", "Transpose1", {"b"}, {"c"}, true, true); + SetOp(&prog, "conv2d", "Conv1", {"c", "w2"}, {"d"}, true, true); + SetOp(&prog, "transpose2", "Transpose2", {"d"}, {"e"}, true, true); + SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, false); + + return prog; +} + +void MainTestTranspose(const ProgramDesc& prog, int conv_count, + int transpose_count, int quant_count, int dequant_count, + int added_nodes_count, float scale) { + std::unique_ptr graph(new ir::Graph(prog)); + int original_nodes_num, current_nodes_num; + PreparePass(&graph, prog, variable_names_transpose, &original_nodes_num, + ¤t_nodes_num); + + int quantize_nodes_count = 0; + int dequantize_nodes_count = 0; + int transpose_nodes_count = 0; + int conv_nodes_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->Type() == "transpose2") { + transpose_nodes_count++; + } else if (op->Type() == "conv2d") { + conv_nodes_count++; + auto op_name = boost::get(op->GetAttr("name")); + EXPECT_EQ(boost::get(op->GetAttr("Scale_in")), scale) + << "Scale_in for node '" + op_name + "'."; + EXPECT_EQ(boost::get(op->GetAttr("Scale_out")), scale) + << "Scale_out for node '" + op_name + "'."; + EXPECT_EQ( + boost::get>(op->GetAttr("Scale_weights"))[0], + scale) + << "Scale_weights for node '" + op_name + "'."; + } else if (op->Type() == "quantize") { + quantize_nodes_count++; + } else if (op->Type() == "dequantize") { + dequantize_nodes_count++; + } + } + } + EXPECT_EQ(transpose_nodes_count, transpose_count); + EXPECT_EQ(conv_nodes_count, conv_count); + EXPECT_EQ(quantize_nodes_count, quant_count); + EXPECT_EQ(dequantize_nodes_count, dequant_count); + EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); +} + +TEST(CpuQuantizePass, transpose) { + // a1->Quant->a2->Conv1->b1->Dequant->b2 + // b2->Quant->b3->Transpose->c1->Dequant->c2 + // c2->Quant->c3->Conv2->d1->Dequant->d2 + // d2->Quant->d3->Transpose->e1->Dequant->e2 + // e2->Dropout->f + int conv_count = 2; + int transpose_count = 2; + int quant_count = 4; + int dequant_count = 4; + // 4 Quant + 4 IN + 4 DeQuant + 4 OUT + int added_nodes_count = 16; + MainTestTranspose(BuildProgramDescTranspose(), conv_count, transpose_count, + quant_count, dequant_count, added_nodes_count, 2.0f * 127); +} +} // namespace } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc index c2b2ba0b60..9adc1150b5 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc @@ -34,6 +34,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { rules_["prior_box"]["Image"] = ScaleAlgo::NONE; rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE; rules_["prior_box"]["Variances"] = ScaleAlgo::NONE; + + rules_["transpose"]["X"] = ScaleAlgo::KL; + rules_["transpose"]["Out"] = ScaleAlgo::KL; } ScaleAlgo MkldnnQuantizerConfig::scale_algo( diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc index ccb50d4043..8cc4db3443 100644 --- a/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc @@ -268,7 +268,7 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) { q_cfg.EnableMkldnnQuantizer(); q_cfg.mkldnn_quantizer_config(); std::unordered_set quantize_operators( - {"conv2d", "depthwise_conv2d", "prior_box"}); + {"conv2d", "depthwise_conv2d", "prior_box", "transpose2"}); q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators); q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size); diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 88b96dcde8..c58195930d 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { @@ -29,6 +30,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { void Compute(const paddle::framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); + mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType(); auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); @@ -49,8 +51,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { nchw_tz, axis, ctx.op().Output("Out") + std::to_string(input->format())); - platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx, - mkldnn_engine, key); + platform::TransposeMKLDNNHandler handler( + nchw_tz, axis, input->type(), in_type, dev_ctx, mkldnn_engine, key); auto transpose_src_memory_p = handler.AcquireSrcMemory( input->format(), platform::to_void_cast(input_data)); @@ -78,7 +80,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { ctx.Input(framework::GradVarName("Out")); auto* x_grad = ctx.Output(framework::GradVarName("X")); if (!x_grad) return; - + mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType(); auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); @@ -103,7 +105,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { const std::string key = platform::TransposeMKLDNNHandler::GetHash( nchw_tz, axis, ctx.op().Output(framework::GradVarName("X"))); - platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, + platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, + x_grad->type(), in_type, dev_ctx, mkldnn_engine, key); auto transpose_src_memory_p = handler.AcquireSrcMemory( @@ -124,11 +127,35 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace, - ops::TransposeMKLDNNOpKernel); - -REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, - ops::TransposeMKLDNNOpKernel); +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, + ::paddle::platform::CPUPlace, FP32, + ops::kTransposeMKLDNNFP32, + ops::TransposeMKLDNNOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, + ::paddle::platform::CPUPlace, U8, + ops::kTransposeMKLDNNINT8, + ops::TransposeMKLDNNOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, + ::paddle::platform::CPUPlace, S8, + ops::kTransposeMKLDNNINT8, + ops::TransposeMKLDNNOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN, + ::paddle::platform::CPUPlace, FP32, + ops::kTransposeMKLDNNFP32, + ops::TransposeMKLDNNOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN, + ::paddle::platform::CPUPlace, U8, + ops::kTransposeMKLDNNINT8, + ops::TransposeMKLDNNOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN, + ::paddle::platform::CPUPlace, S8, + ops::kTransposeMKLDNNINT8, + ops::TransposeMKLDNNOpKernel); REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 47840d71a3..39bac99765 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -65,15 +65,23 @@ class TransposeOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + auto input_data_type = ctx.Input("X")->type(); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; + using framework::proto::VarType; + customized_type_value = (input_data_type == VarType::INT8 || + input_data_type == VarType::UINT8) + ? kTransposeMKLDNNINT8 + : kTransposeMKLDNNFP32; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, + library_, customized_type_value); } }; @@ -99,6 +107,13 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") .SetDefault("AnyLayout"); + /* int8 parameters */ + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. " + "Only used on CPU.") + .SetDefault(false); AddComment(R"DOC( Transpose Operator. @@ -196,16 +211,24 @@ class Transpose2Op : public TransposeOp { const framework::ExecutionContext &ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + auto input_data_type = ctx.Input("X")->type(); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; + using framework::proto::VarType; + customized_type_value = (input_data_type == VarType::INT8 || + input_data_type == VarType::UINT8) + ? kTransposeMKLDNNINT8 + : kTransposeMKLDNNFP32; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, + library_, customized_type_value); } }; diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/fluid/operators/transpose_op.h index 895d1ce2cc..9ed76d066f 100644 --- a/paddle/fluid/operators/transpose_op.h +++ b/paddle/fluid/operators/transpose_op.h @@ -21,6 +21,8 @@ limitations under the License. */ namespace paddle { namespace operators { +enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 }; + template inline void TransCompute(const int dim, const DeviceContext& dev_ctx, const framework::Tensor& in, framework::Tensor* out, diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 935c4f734f..23cdaecc69 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -828,12 +828,16 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { public: TransposeMKLDNNHandler(std::vector& dims, // NOLINT std::vector& axis, // NOLINT + framework::proto::VarType::Type vtype, + mkldnn::memory::data_type dtype, const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, const std::string& base_key) : platform::MKLDNNHandler(dev_ctx, engine, base_key), dims_(dims), axis_(axis), - logical_axis_(dims.size(), 0) {} + logical_axis_(dims.size(), 0), + vtype_(vtype), + dtype_(dtype) {} std::shared_ptr AcquireSrcMemory( const mkldnn::memory::format& fmt, void* ptr) { @@ -847,9 +851,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { logical_axis_[i] = i; } auto src_md = fmt != mkldnn::memory::format::nchw - ? platform::MKLDNNMemDesc( - dims_, platform::MKLDNNGetDataType(), fmt) - : Axis2MemoryDesc(dims_, logical_axis_); + ? platform::MKLDNNMemDesc(dims_, dtype_, fmt) + : Axis2MemoryDesc(dims_, logical_axis_, dtype_); mem_p = std::make_shared( mkldnn::memory::primitive_desc{src_md, engine_}, ptr); dev_ctx_.SetBlob(local_key, mem_p); @@ -866,14 +869,14 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); if (mem_p == nullptr) { auto dst_mdp = mkldnn::memory::primitive_desc{ - Axis2MemoryDesc(dims_, axis_), engine_}; + Axis2MemoryDesc(dims_, axis_, dtype_), engine_}; - auto dst_data = output->mutable_data(place, dst_mdp.get_size()); + auto dst_data = output->mutable_data(place, vtype_); mem_p = std::make_shared(dst_mdp, dst_data); dev_ctx_.SetBlob(local_key, mem_p); } else { - auto dst_data = output->mutable_data(place); + auto dst_data = output->mutable_data(place, vtype_); mem_p->set_data_handle(dst_data); } return mem_p; @@ -901,8 +904,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { protected: mkldnn_memory_desc_t Axis2MemoryDesc(std::vector& nchw_tz, // NOLINT - std::vector& axis // NOLINT - ) { + std::vector& axis, // NOLINT + mkldnn::memory::data_type dtype) { mkldnn_memory_desc_t mem_fmt; mem_fmt.primitive_kind = mkldnn_memory; @@ -911,7 +914,12 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format, // regardless physical layout) } - mem_fmt.data_type = mkldnn_f32; + if (dtype == mkldnn::memory::data_type::s8) + mem_fmt.data_type = mkldnn_s8; + else if (dtype == mkldnn::memory::data_type::u8) + mem_fmt.data_type = mkldnn_u8; + else + mem_fmt.data_type = mkldnn_f32; mem_fmt.format = mkldnn_blocked; unsigned int total_stride = 1; @@ -933,6 +941,8 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { std::vector dims_; std::vector axis_; std::vector logical_axis_; + framework::proto::VarType::Type vtype_; + mkldnn::memory::data_type dtype_; }; class ReorderMKLDNNHandler : public MKLDNNHandler { -- GitLab