From 0c39b97b4ee0f732a6a9a349511b3ac3cdc1633c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Gallus?=
Date: Fri, 24 May 2019 16:51:22 +0200
Subject: [PATCH] [MKL-DNN] Add Fully Connected Op for inference only (#15226)

* fuse mul and elementwise add to fc

* Reimplement the FC forward operator

* Fix FC MKLDNN integration by transposing weights

* Add FC MKLDNN Pass

test=develop

* FC MKLDNN Pass: change memcpy to std::copy

* Fix MKLDNN FC handling of mismatched input and weights dims

* Lower tolerance for MKL-DNN in resnet50 test

test=develop

* Adjust FC to support MKLDNN Op placement

test=develop

* Adjust Placement Op to set use_mkldnn attribute for graph

test=develop

* MKLDNN FC: fix weights format so that gemm version is called

test=develop

* FC MKLDNN: Remove tolerance decrease from tester_helper

* FC MKL-DNN: Refactor the code, change input reorder to weight reorder

* MKL-DNN FC: Introduce operator caching

test=develop

* FC MKL-DNN: Fix the tensor type in ExpectedKernelType

test=develop

* FC MKL-DNN: fix style changes

test=develop

* FC MKL-DNN: fall back to native on unsupported dim sizes

test=develop

* FC MKLDNN: fix CMake paths

test=develop

* FC MKLDNN: Refine placement pass graph mkldnn attribute

test=develop

* Fix Transpiler error for fuse_conv_eltwise

test=develop

* Fix missing STL includes in files

test=develop

* FC MKL-DNN: Enable new output size computation

  Also, refine pass to comply with newest interface.

test=develop

* FC MKL-DNN: enable only when fc_mkldnn_pass is enabled

* FC MKL-DNN: Allow Weights to use oi or io format

* FC MKL-DNN: Adjust UT to work with correct dims

test=develop

* Enable MKL DEBUG for resnet50 analyzer

test=develop

* FC MKL-DNN: Improve Hashing function

test=develop

* FC MKL-DNN: Fix shape for fc weights in transpiler

* FC MKL-DNN: Update input pointer in re-used fc primitive

* Add log for not handling fc fuse for unsupported dims

test=develop

* FC MKL-DNN: Move transpose from pass to Op Kernel

test=develop

* FC MKL-DNN: Disable transpose in unit test

test=develop

* FC MKL-DNN: Remove fc_mkldnn_pass from default list

* Correct flag for fake data analyzer tests

test=develop

* FC MKL-DNN: Add comment about fc mkldnn pass disablement

test=develop

* FC MKL-DNN: Disable fc in int8 tests

test=develop
---
 cmake/generic.cmake                           |   2 +-
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 paddle/fluid/framework/ir/fc_fuse_pass.cc     |   2 +
 .../framework/ir/graph_pattern_detector.cc    |  30 ++
 .../framework/ir/graph_pattern_detector.h     |  19 +
 .../framework/ir/mkldnn/fc_mkldnn_pass.cc     |  77 ++++
 .../framework/ir/mkldnn/fc_mkldnn_pass.h      |  38 ++
 .../ir/mkldnn/mkldnn_placement_pass.cc        |   4 +
 .../inference/api/paddle_pass_builder.cc      |  23 +-
 .../fluid/inference/tests/api/CMakeLists.txt  |  12 +-
 .../tests/api/analyzer_bert_tester.cc         |   1 +
 .../tests/api/analyzer_dam_tester.cc          |   3 +-
 .../tests/api/analyzer_mm_dnn_tester.cc       |   1 +
 .../tests/api/analyzer_resnet50_tester.cc     |   2 +
 .../tests/api/analyzer_seq_pool1_tester.cc    |   1 +
 .../tests/api/analyzer_transformer_tester.cc  |   1 +
 .../tests/api/analyzer_vis_tester.cc          |   1 +
 paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 432 ++++++++----------
 .../unittests/mkldnn/test_fc_mkldnn_op.py     |  45 +-
 .../fluid/transpiler/inference_transpiler.py  |  93 ++++
 20 files changed, 502 insertions(+), 286 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc
 create mode 100644 paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h

diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index c5bedf376ba..dfa90a3fe63 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
-      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
+      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
       # No unit test should exceed 10 minutes.
       set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
   endif()
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index d205a788411..3210f3041a1 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -88,6 +88,7 @@ if(WITH_MKLDNN)
   pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
+  pass_library(fc_mkldnn_pass inference mkldnn)
   pass_library(cpu_quantize_placement_pass base mkldnn)
   pass_library(cpu_quantize_pass inference mkldnn)
   pass_library(cpu_quantize_squash_pass inference mkldnn)
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index cd8030519cc..4691b9abfdf 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
+#include <memory>
 #include <string>
 #include <unordered_set>
 #include <vector>
@@ -80,6 +81,7 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
     }
 
     desc.SetType("fc");
+
     auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
     GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index cb7ef41861e..f0d47ad57f9 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -14,7 +14,10 @@
 
 #include <algorithm>
 #include <array>
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "paddle/fluid/framework/ir/graph_helper.h"
@@ -896,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   }
 }
 
+PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
+                                       bool with_bias) {
+  // Create shared nodes.
+ x->assert_is_op_input("fc", "Input"); + + auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc"); + // Create variables + // Filter + auto *fc_weight_var = pattern->NewNode(weights_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("fc", "W"); + // Bias + auto *fc_bias_var = pattern->NewNode(bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("fc", "Bias"); + // Output + auto *fc_out_var = pattern->NewNode(output_repr()) + ->AsOutput() + ->assert_is_op_output("fc", "Out") + ->assert_is_only_output_of_op("fc"); + + fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var}); + return fc_out_var; +} + PDNode *patterns::Embedding::operator()(PDNode *x) { x->assert_is_op_input("lookup_table", "Ids"); auto *lookup_table_op = diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index bb62716ec7f..7df2f5efc45 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -517,6 +517,25 @@ struct FC : public PatternBase { PATTERN_DECL_NODE(Out); }; +// MKL-DNN's FC with bias +// op: fc +// named node: +// fc +// w, bias, output +struct FCMKLDNN : public PatternBase { + FCMKLDNN(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fc_mkldnn") {} + + PDNode* operator()(PDNode* x, bool with_bias); + + // declare operator node's name + PATTERN_DECL_NODE(fc); + // declare variable node's name + PATTERN_DECL_NODE(weights); + PATTERN_DECL_NODE(bias); + PATTERN_DECL_NODE(output); +}; + // Embedding struct Embedding : public PatternBase { Embedding(PDPattern* pattern, const std::string& name_scope) diff --git a/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc new file mode 100644 index 00000000000..9cc2d3da3fc --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); + Init("fc_mkldnn_pass", graph); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode("fc_mkldnn_pass/x") + ->AsInput() + ->assert_is_op_input("fc", "Input"); + patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass"); + fc_pattern(x, true /*with bias*/); + + int found_fc_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Handle FC MKL-DNN pass"; + if (!(graph->Has("use_mkldnn") && graph->Get("use_mkldnn"))) { + VLOG(3) << "do not perform fc fuse"; + return; + } + GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern); + + OpDesc* desc = fc->Op(); + auto in_size = fc->inputs[0]->Var()->GetShape().size(); + if (in_size != 2 && in_size != 4) { + VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4"; + return; + } + desc->SetAttr("use_mkldnn", true); + PADDLE_ENFORCE(subgraph.count(x)); + + found_fc_count++; + }; + + gpd(graph, handler); + + AddStatis(found_fc_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass); diff --git a/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h new file mode 100644 index 00000000000..97c6b242989 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+ * Transpose the weights of FC to comply with the MKL-DNN interface.
+ */
+class FCMKLDNNPass : public FusePassBase {
+ public:
+  virtual ~FCMKLDNNPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc
index 500419e4b78..a2092a5059a 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
+#include <memory>
 #include <string>
 #include <unordered_set>
 
@@ -24,6 +25,9 @@ void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
   const auto& op_types_list =
       Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
+  if (!graph->Has("use_mkldnn")) {
+    graph->Set("use_mkldnn", new bool(true));
+  }
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
       auto* op = n->Op();
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index b39f740ec02..2bad89cdb33 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -146,16 +146,19 @@ void CpuPassStrategy::EnableMKLDNN() {
   if (!use_mkldnn_) {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
-    for (auto &pass : std::vector<std::string>(
-             {"depthwise_conv_mkldnn_pass",    //
-              "conv_bn_fuse_pass",             // Execute BN passes again to
-              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
-              "conv_bias_mkldnn_fuse_pass",    //
-              "conv3d_bias_mkldnn_fuse_pass",  //
-              "conv_elementwise_add_mkldnn_fuse_pass",
-              "conv_concat_relu_mkldnn_fuse_pass",
-              "conv_relu_mkldnn_fuse_pass",    //
-              "conv_brelu_mkldnn_fuse_pass"})) {
+    for (auto &pass : std::vector<std::string>({
+             "depthwise_conv_mkldnn_pass",    //
+             "conv_bn_fuse_pass",             // Execute BN passes again to
+             "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
+             "conv_bias_mkldnn_fuse_pass",    //
+             "conv3d_bias_mkldnn_fuse_pass",  //
+             "conv_elementwise_add_mkldnn_fuse_pass",
+             "conv_concat_relu_mkldnn_fuse_pass",
+             "conv_relu_mkldnn_fuse_pass",    //
+             "conv_brelu_mkldnn_fuse_pass",   //
+             // Disabled due to topology-dependent speed-up
+             // "fc_mkldnn_pass"
+         })) {
       passes_.push_back(pass);
     }
   }
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 367a37eecab..b37e3936d1b 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -33,8 +33,10 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
             --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
             --iterations=2)
 endfunction()
-
-function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
+function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
+  if(mkl_debug)
+    set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
+  endif()
   download_model(${install_dir} ${model_name})
   inference_analysis_test(${target} SRCS ${filename}
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
@@ -143,15 +145,15 @@ inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose
 
 # googlenet
 inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
-    "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz")
+    "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" false)
 
 # resnet50
 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
-    "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
+    "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" true)
 
 # mobilenet with depthwise_conv op
 inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
-    "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
+    "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" false)
 
 # int8 image classification tests
 if(WITH_MKLDNN)
diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
index 9b2e74ec16e..fda6cf358d3 100644
--- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
@@ -152,6 +152,7 @@ void profile(bool use_mkldnn = false) {
 
   if (use_mkldnn) {
     config.EnableMKLDNN();
+    config.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index cfbb3b15461..3efe17c8108 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -200,8 +200,9 @@ void profile(bool use_mkldnn = false) {
     cfg.EnableMKLDNN();
     // Enable all the mkldnn supported ops except conv3d in dam
     std::unordered_set<std::string> op_list = {"softmax", "elementwise_add",
-                                               "relu"};
+                                               "relu", "fc"};
     cfg.SetMKLDNNOp(op_list);
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
diff --git a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
index 2eb347a44b3..245357bfff8 100644
--- a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
@@ -100,6 +100,7 @@ void profile(bool use_mkldnn = false) {
 
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
index e883ad5bfcf..602d59457c0 100644
--- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
@@ -48,6 +48,7 @@ void profile(bool use_mkldnn = false) {
 
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -79,6 +80,7 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
index 3cebf8e9698..e78f04a07c5 100644
--- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
@@ -149,6 +149,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   }
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
+    cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
   // time
diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc
index a23297f29cf..147c7712fb8 100644
--- a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc
@@ -189,6 +189,7 @@ void profile(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> outputs;
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index fb47048cd0c..3f020f3cbb6 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -85,6 +85,7 @@ void profile(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   // cfg.pass_builder()->TurnOnDebug();
   std::vector<std::vector<PaddleTensor>> outputs;
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index 69c0486eb63..764183f085a 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -12,299 +12,266 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include <memory>
+#include <vector>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/fc_op.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/variant.h"
 
 namespace paddle {
 namespace operators {
 
-using paddle::framework::Tensor;
-using paddle::platform::MKLDNNDeviceContext;
+using framework::DataLayout;
+using framework::Tensor;
+using framework::LoDTensor;
+using framework::DDim;
+using framework::ExecutionContext;
+using platform::MKLDNNDeviceContext;
+using platform::to_void_cast;
+using platform::GetMKLDNNFormat;
+using mkldnn::memory;
+using mkldnn::inner_product_forward;
+using mkldnn::primitive;
+using mkldnn::stream;
+using mkldnn::prop_kind;
 
 template <typename T>
-class MKLDNNMD {
+class FCPrimitiveFactory {
  public:
-  explicit MKLDNNMD(const T* in, const T* w, bool bias)
-      : in(paddle::framework::vectorize2int(in->dims())),
-        w(paddle::framework::vectorize2int(w->dims())) {
-    with_bias_ = bias;
-  }
+  explicit FCPrimitiveFactory(const mkldnn::engine& engine) : engine_(engine) {}
+
+  inner_product_forward CreateFcPrimitive(const LoDTensor* input,
+                                          const Tensor* weights,
+                                          const Tensor* bias, LoDTensor* output,
+                                          const ExecutionContext& ctx) {
+    RecomputeOutputDims(ctx, input, weights, output);
+    if (fc_) {
+      UpdateDataPointers(ctx, output, input);
+      return *fc_;
+    }
+    auto src_desc = CreateMemDescriptor(input, input->format());
+    input_ = CreateMemory(src_desc, input);
 
-  mkldnn::memory::desc dst() const {
-    return platform::MKLDNNMemDesc({in[0], w[1]},
-                                   mkldnn::memory::data_type::f32,
-                                   mkldnn::memory::format::nc);
-  }
+    weights_ = TransposeWeights(weights);
+    if (src_desc.data.ndims == 4) {
+      weights_ = CreateFourDimWeightsMemory(input, weights);
+    }
+
+    auto dst_desc = CreateMemDescriptor(output, memory::format::any);
 
-  mkldnn::memory::desc src() const {
-    return is_spatial()
-               ? platform::MKLDNNMemDesc({in[0], in[1], in[2], in[3]},
-                                         mkldnn::memory::data_type::f32,
-                                         mkldnn::memory::format::nchw)
-               : platform::MKLDNNMemDesc({in[0], in[1]},
-                                         mkldnn::memory::data_type::f32,
-                                         mkldnn::memory::format::nc);
+    fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx);
+    return *fc_;
   }
 
-  mkldnn::memory::desc weights() const {
-    return is_spatial()
-               ? platform::MKLDNNMemDesc({w[1], in[1], in[2], in[3]},
-                                         mkldnn::memory::data_type::f32,
-                                         mkldnn::memory::format::oihw)
-               : platform::MKLDNNMemDesc({w[1], in[1]},
-                                         mkldnn::memory::data_type::f32,
-                                         mkldnn::memory::format::oi);
+ private:
+  void UpdateDataPointers(const ExecutionContext& ctx, Tensor* out,
+                          const Tensor* in) {
+    input_->set_data_handle(const_cast<T*>(in->data<T>()));
+    output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
+    if (out->format() == memory::format::format_undef) {
+      auto output_format = output_->get_primitive_desc().desc().data.format;
+      out->set_format((memory::format)output_format);
+    }
   }
 
-  mkldnn::memory::desc bias() const {
-    return with_bias_
-               ? platform::MKLDNNMemDesc({w[1]}, mkldnn::memory::data_type::f32,
-                                         mkldnn::memory::format::format_undef)
-               : platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32,
-                                         mkldnn::memory::format::format_undef);
+  memory::format MatchWeightFormat(memory::format fmt) {
+    using format = memory::format;
+    switch (fmt) {
+      case format::nChw16c:
+        return format::oIhw16i;
+      case format::nChw8c:
+        return format::oIhw8i;
+      case format::nchw:
+        return format::oihw;
+      default:
+        return format::format_undef;
+    }
   }
 
- private:
-  bool is_spatial() const { return in.size() > 1 && w.size() > 1; }
+  mkldnn::memory Reorder(const memory::desc& src_desc,
+                         const memory::desc& dst_desc, const void* src_data) {
+    auto src_mem = memory({src_desc, engine_}, const_cast<void*>(src_data));
+    auto dst_mem = memory({dst_desc, engine_});
 
-  std::vector<int> in;
-  std::vector<int> w;
-  bool with_bias_;
-  bool is_spatial_;
-};
+    auto reorder = mkldnn::reorder(src_mem, dst_mem);
+    stream(stream::kind::eager).submit({reorder}).wait();
 
-class MKLDNNMemory {
- public:
-  MKLDNNMemory(MKLDNNMD<Tensor>* t, const mkldnn::engine& e)
-      : md_(t), engine_(e) {}
-  virtual ~MKLDNNMemory() = default;
-
-  template <typename Output>
-  mkldnn::memory dst(const Output* out) {
-    return mkldnn::memory({md_->dst(), engine_},
-                          static_cast<void*>(const_cast<Output*>(out)));
+    return dst_mem;
   }
 
-  template <typename Output>
-  mkldnn::memory dst(Output* out) {
-    return mkldnn::memory({md_->dst(), engine_}, out);
+  static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims,
+                                                  memory::format format) {
+    return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(),
+                                   format);
   }
 
-  template <typename Input>
-  mkldnn::memory src(const Input* in) {
-    return mkldnn::memory({md_->src(), engine_},
-                          static_cast<void*>(const_cast<Input*>(in)));
+  static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
+                                                  memory::format format) {
+    auto dims = framework::vectorize2int(tensor->dims());
+    return CreateMemDescriptor(dims, format);
   }
 
-  template <typename Weight>
-  mkldnn::memory weights(const Weight* w) {
-    return mkldnn::memory({md_->weights(), engine_},
-                          static_cast<void*>(const_cast<Weight*>(w)));
+  mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
+                              const Tensor* tensor) {
+    return CreateMemory(desc, tensor->data<T>());
   }
 
-  mkldnn::memory bias() {
-    return mkldnn::memory(mkldnn::memory::primitive_desc(md_->bias(), engine_));
+  mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
+                              const void* data) {
+    return memory({desc, engine_}, const_cast<void*>(data));
   }
 
- private:
-  MKLDNNMD<Tensor>* md_;
-  const mkldnn::engine& engine_;
-};
-
-template <typename T>
-class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
+  mkldnn::memory TransposeWeights(const Tensor* weights) {
+    auto dims = framework::vectorize2int(weights->dims());
+    std::swap(dims[0], dims[1]);  // Correct output dimensions
+    auto src_desc = CreateMemDescriptor(dims, memory::format::io);
+    auto dst_desc = CreateMemDescriptor(dims, memory::format::oi);
+    return Reorder(src_desc, dst_desc, weights->data<T>());
  }
 
-    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
+  inner_product_forward CreateFcPrimitive(const memory& src_memory,
+                                          const memory& weights_memory,
+                                          const memory::desc& dst_desc,
+                                          const Tensor* bias, Tensor* output,
+                                          const ExecutionContext& ctx) {
+    const auto weights_desc = weights_memory.get_primitive_desc().desc();
+    const auto src_desc =
+        src_memory.get_primitive_desc().desc();
+    if (bias) {
+      auto bias_desc = CreateMemDescriptor(bias, bias->format());
+      bias_ = CreateMemory(bias_desc, bias);
+      auto fc_prim_desc =
+          CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc);
+
+      output_ = CreateDstMemory(fc_prim_desc, ctx, output);
+
+      return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
+                                   *bias_, *output_);
+    } else {
+      auto fc_prim_desc = CreateFcPrimDesc(src_desc, weights_desc, dst_desc);
+
+      output_ = CreateDstMemory(fc_prim_desc, ctx, output);
+
+      return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
+                                   *output_);
+    }
+  }
 
-    auto input = ctx.Input<Tensor>("Input");
-    auto w = ctx.Input<Tensor>("W");
-    auto bias = ctx.Input<Tensor>("Bias");
+  mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
+      const mkldnn::memory::desc& input_desc,
+      const mkldnn::memory::desc& weights_desc,
+      const mkldnn::memory::desc& bias_desc,
+      const mkldnn::memory::desc& dst_desc) {
+    auto fc_desc =
+        inner_product_forward::desc(prop_kind::forward_scoring, input_desc,
+                                    weights_desc, bias_desc, dst_desc);
 
-    PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
-                   "Input must be with 2 or 4 dimensions, i.e. NCHW");
-    // TODO(intel friends): the native weight format is io,
-    // but the mkldnn weight format is oihw, which may need be transposed.
-    PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
-                   "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
+    return inner_product_forward::primitive_desc(fc_desc, engine_);
+  }
 
-    bool with_bias = bias != nullptr;
-    MKLDNNMD<Tensor> md(input, w, with_bias);
+  mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
+      const mkldnn::memory::desc& input_desc,
+      const mkldnn::memory::desc& weights_desc,
+      const mkldnn::memory::desc& dst_desc) {
+    auto fc_desc = inner_product_forward::desc(prop_kind::forward, input_desc,
+                                               weights_desc, dst_desc);
 
-    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
-        FcFwdPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
-                           with_bias, mkldnn_engine);
+    return inner_product_forward::primitive_desc(fc_desc, engine_);
+  }
 
-    const std::string key = ctx.op().Output("Out");
-    const std::string key_fc_pd = key + "@fc_pd";
+  mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input,
+                                            const Tensor* weights) {
+    auto input_dims = framework::vectorize2int(input->dims());
+    auto weight_dims = framework::vectorize2int(weights->dims());
+    auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};
 
-    dev_ctx.SetBlob(key_fc_pd, pd);
+    auto dst_format = MatchWeightFormat(input->format());
+    auto src_desc = CreateMemDescriptor(dims, memory::format::oihw);
+    auto dst_desc = CreateMemDescriptor(dims, dst_format);
 
-    MKLDNNMemory mem(&md, mkldnn_engine);
+    return Reorder(src_desc, dst_desc, weights_->get_data_handle());
+  }
 
-    const T* input_data = input->data<T>();
-    const T* w_data = w->data<T>();
+  mkldnn::memory CreateDstMemory(
+      const mkldnn::inner_product_forward::primitive_desc& fc_prim_desc,
+      const ExecutionContext& ctx, Tensor* output) {
+    auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
+    auto buffer_size = dst_prim_desc.get_size();
+    T* output_data = output->mutable_data<T>(
+        ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, buffer_size);
+    output->set_format((memory::format)dst_prim_desc.desc().data.format);
+    return memory(dst_prim_desc, to_void_cast<T>(output_data));
+  }
 
-    auto output = ctx.Output<Tensor>("Out");
+  void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
+                           const Tensor* w, LoDTensor* output) {
     int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
     std::vector<int64_t> output_dims;
     FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims);
     output->Resize(framework::make_ddim(output_dims));
     output->set_lod(input->lod());
+  }
 
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
-
-    auto dst_memory = mem.dst(output_data);
-    auto src_memory = mem.src(input_data);
-    auto weights_memory = mem.weights(w_data);
-    // TODO(intel friends): bias memory should also be obtain from bias->data()
-    auto bias_memory = mem.bias();
+ private:
+  const mkldnn::engine& engine_;
+  boost::optional<memory> bias_;
+  boost::optional<memory> input_;
+  boost::optional<memory> output_;
+  boost::optional<memory> weights_;
+  boost::optional<inner_product_forward> fc_;
+};
 
-    auto forward = with_bias ? mkldnn::inner_product_forward(
-                                   *pd, src_memory, weights_memory, bias_memory,
-                                   dst_memory)
-                             : mkldnn::inner_product_forward(
-                                   *pd, src_memory, weights_memory, dst_memory);
+static std::string GetHash(const Tensor* input, const Tensor* weights,
+                           const std::string& suffix) {
+  auto dim2str = [](const DDim& operand_dims) {
+    std::string str = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      str += std::to_string(operand_dims[i]) + "-";
+    }
+    return str;
+  };
+  return std::to_string((unsigned)input->format()) + dim2str(weights->dims()) +
+         suffix;
+}
 
-    std::vector<mkldnn::primitive> pipeline = {forward};
-    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+template <typename T>
+std::shared_ptr<FCPrimitiveFactory<T>> GetPrimitiveFactory(
+    const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
+    const Tensor* input, const Tensor* weights,
+    const mkldnn::engine& mkldnn_engine) {
+  const std::string key = GetHash(input, weights, ctx.op().Output("Out"));
+
+  auto prim_creator =
+      std::static_pointer_cast<FCPrimitiveFactory<T>>(dev_ctx.GetBlob(key));
+  if (prim_creator == nullptr) {
+    prim_creator = std::make_shared<FCPrimitiveFactory<T>>(mkldnn_engine);
+    dev_ctx.SetBlob(key, prim_creator);
   }
 
- private:
-  std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>
-  FcFwdPrimitiveDesc(const mkldnn::memory::desc& src,
-                     const mkldnn::memory::desc& weights,
-                     const mkldnn::memory::desc& dst,
-                     const mkldnn::memory::desc& bias, const bool with_bias,
-                     const mkldnn::engine& engine) const {
-    auto desc = with_bias
-                    ? mkldnn::inner_product_forward::desc(
-                          mkldnn::prop_kind::forward, src, weights, bias, dst)
-                    : mkldnn::inner_product_forward::desc(
-                          mkldnn::prop_kind::forward, src, weights, dst);
-
-    auto pd = new mkldnn::inner_product_forward::primitive_desc(desc, engine);
-    return std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>(pd);
-  }
-};
+  return prim_creator;
+}
 
 template <typename T>
-class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
+class FCMKLDNNOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
 
-    T* input_grad_data = nullptr;
-    T* w_grad_data = nullptr;
-
-    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
-
-    const Tensor* input = ctx.Input<Tensor>("Input");
-    const T* input_data = input->data<T>();
-
-    const Tensor* w = ctx.Input<Tensor>("W");
-    const T* w_data = w->data<T>();
-
-    if (input_grad) {
-      input_grad->Resize(input->dims());
-      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-    }
-    if (w_grad) {
-      w_grad->Resize(w->dims());
-      w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
-    }
-
-    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    const T* out_grad_data = out_grad->data<T>();
-
+    auto input = ctx.Input<LoDTensor>("Input");
+    auto w = ctx.Input<Tensor>("W");
     auto bias = ctx.Input<Tensor>("Bias");
-    bool with_bias = bias != nullptr;
-
-    MKLDNNMD<Tensor> md(input, w, with_bias);
-    MKLDNNMemory mem(&md, mkldnn_engine);
-
-    auto dst_memory = mem.dst(out_grad_data);
-    auto src_memory = mem.src(input_data);
-    auto weights_memory = mem.weights(w_data);
-    auto bias_memory = mem.bias();
+    auto output = ctx.Output<LoDTensor>("Out");
 
-    const std::string key = ctx.op().Input("Out");
-    const std::string key_fc_pd = key + "@fc_pd";
+    auto prim_creator =
+        GetPrimitiveFactory<T>(dev_ctx, ctx, input, w, mkldnn_engine);
+    auto fc = prim_creator->CreateFcPrimitive(input, w, bias, output, ctx);
+    stream(stream::kind::eager).submit({fc}).wait();
 
-    auto pd =
-        std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
-            dev_ctx.GetBlob(key_fc_pd));
-
-    PADDLE_ENFORCE(pd != nullptr, "Fail to find key_fc_pd in device context");
-
-    if (w_grad) {
-      auto weights_grad_memory = mem.weights(w_grad_data);
-
-      mkldnn::inner_product_backward_weights::primitive_desc bwd_weight_pd =
-          FcBwdWeightsPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
-                                    with_bias, *pd, mkldnn_engine);
-
-      auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
-          bwd_weight_pd, src_memory, dst_memory, weights_grad_memory,
-          bias_memory);
-
-      std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
-      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
-    }
-
-    if (input_grad) {
-      auto src_grad_memory = mem.src(input_grad_data);
-
-      mkldnn::inner_product_backward_data::primitive_desc bwd_data_pd =
-          FcBwdDataPrimitiveDesc(md.src(), md.weights(), md.dst(), *pd,
-                                 mkldnn_engine);
-
-      auto bwd_data_prim = mkldnn::inner_product_backward_data(
-          bwd_data_pd, dst_memory, weights_memory, src_grad_memory);
-
-      std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
-      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
-    }
-  }
-
- private:
-  mkldnn::inner_product_backward_weights::primitive_desc
-  FcBwdWeightsPrimitiveDesc(
-      const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
-      const mkldnn::memory::desc& diff_dst, const mkldnn::memory::desc& bias,
mkldnn::memory::desc& bias, - const bool with_bias, - const mkldnn::inner_product_forward::primitive_desc& pd, - const mkldnn::engine& engine) const { - auto bwd_weight_desc = with_bias - ? mkldnn::inner_product_backward_weights::desc( - src, diff_weights, bias, diff_dst) - : mkldnn::inner_product_backward_weights::desc( - src, diff_weights, diff_dst); - - return mkldnn::inner_product_backward_weights::primitive_desc( - bwd_weight_desc, engine, pd); - } - - mkldnn::inner_product_backward_data::primitive_desc FcBwdDataPrimitiveDesc( - const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights, - const mkldnn::memory::desc& diff_dst, - const mkldnn::inner_product_forward::primitive_desc& pd, - const mkldnn::engine& engine) const { - auto bwd_data_desc = - mkldnn::inner_product_backward_data::desc(diff_src, weights, diff_dst); - return mkldnn::inner_product_backward_data::primitive_desc(bwd_data_desc, - engine, pd); + output->set_layout(DataLayout::kMKLDNN); } }; } // namespace operators @@ -312,6 +279,3 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel { REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace, paddle::operators::FCMKLDNNOpKernel); - -REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace, - paddle::operators::FCMKLDNNGradOpKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py index 84229a5cffb..8f0a9898dce 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py @@ -20,34 +20,30 @@ from paddle.fluid.tests.unittests.op_test import OpTest def fully_connected_naive(input, weights, bias_data=None): - in_n, in_c, in_h, in_w = input.shape - w_h, w_c = weights.shape - - x_data = np.reshape(input, [in_n, in_c * in_h * in_w]) - # this transpose should be implemented at C code - w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w))) result = None if not bias_data: - result = np.dot(x_data, w_data) + result = np.dot(input, weights) else: - result = np.dot(x_data, w_data) + bias_data + result = np.dot(input, weights) + bias_data return result class MatrixGenerate: def __init__(self, mb, ic, oc, h, w): - self.input = np.random.random((mb, ic, h, w)).astype("float32") + self.input = np.random.random((mb, ic * h * w)).astype("float32") self.weights = np.random.random((ic * h * w, oc)).astype("float32") class TestFCMKLDNNOp(OpTest): + def create_data(self): + self.matrix = MatrixGenerate(1, 10, 15, 3, 3) + def setUp(self): self.op_type = "fc" self.use_mkldnn = True - self.matrix = MatrixGenerate(1, 10, 15, 3, 3) - + self.create_data() self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} self.attrs = {'use_mkldnn': self.use_mkldnn, } @@ -60,37 +56,16 @@ class TestFCMKLDNNOp(OpTest): self.check_output() def test_check_grad_normal(self): - self.check_grad(set(['Input', 'W']), 'Out', max_relative_error=0.9) + pass def test_check_grad_no_weight(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.5, no_grad_set=set('W')) + pass class TestFCMKLDNNOp1(TestFCMKLDNNOp): - def init_op_type(self): + def create_data(self): self.matrix = MatrixGenerate(2, 15, 48, 2, 2) -class TestFCMKLDNNOp2(TestFCMKLDNNOp): - def init_op_type(self): - self.matrix = MatrixGenerate(2, 32, 40, 1, 1) - - -class TestFCMKLDNNOp3(TestFCMKLDNNOp): - def init_op_type(self): - self.matrix = MatrixGenerate(2, 2, 4, 1, 1) - - -class 
-    def init_op_type(self):
-        self.matrix = MatrixGenerate(2, 32, 48, 2, 2)
-
-
-class TestFCMKLDNNOp4(TestFCMKLDNNOp):
-    def init_op_type(self):
-        self.matrix = MatrixGenerate(2, 32, 1000, 6, 6)
-
-
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py
index 8a527e72fb9..8917fb75128 100644
--- a/python/paddle/fluid/transpiler/inference_transpiler.py
+++ b/python/paddle/fluid/transpiler/inference_transpiler.py
@@ -76,6 +76,7 @@ class InferenceTranspiler(object):
             self._fuse_conv_relu_mkldnn(
                 program)  # ResNet residual block merging
             self._fuse_bn_relu_mkldnn(program)
+            self._fuse_mul_add_mkldnn(program)
 
         self._is_test_pass(program)
 
@@ -387,6 +388,62 @@ class InferenceTranspiler(object):
         # And a better solution will be considered later.
         program = program.clone()
 
+    def _fuse_mul_add_mkldnn(self, program):
+        '''
+        Transpile the program by fusing mul+elementwise_add layers into an FC
+        layer backed by the MKL-DNN inner product.
+        A mul operator followed by an elementwise_add can be replaced by the
+        MKL-DNN FC. The elementwise_add's bias input 'Y' has to be added as
+        the MKL-DNN-based FC input 'Bias'.
+        The operator transformation is:
+        - before:
+          - mul -> elementwise_add -> any_other_op
+        - after:
+          - fc -> any_other_op
+        The transpile stages are:
+        1. insert a new MKL-DNN-based FC operator with the `Bias` input
+           taken from the elementwise_add's input 'Y' (bias),
+        2. fuse the parameters of the mul and the elementwise_add,
+        3. remove the mul and elementwise_add operators,
+        4. make the input of the deleted elementwise_add operator the input
+           of the new FC operator,
+        5. remove unused variables.
+        Args:
+            program (Program): program to transpile
+        '''
+
+        self.block = program.block(0)
+
+        self.input_map = {}  # store the input names that should be adjusted
+        i = 0
+        while i < len(self.block.ops):
+            # find an elementwise_add op
+            if self.block.ops[i].type == 'elementwise_add':
+                add_op = self.block.ops[i]
+                add_idx = i
+                mul_idx = -1
+                # find the preceding mul op
+                for j in reversed(range(add_idx)):
+                    if self.block.ops[j].type == 'mul':
+                        mul_out_name = self.block.ops[j].output_arg_names[0]
+                        if mul_out_name in add_op.input_arg_names:
+                            mul_op = self.block.ops[j]
+                            mul_idx = j
+                            break
+                if mul_idx < 0:
+                    i += 1
+                    continue
+                # create and insert a new fc op
+                fc_op_new = self._insert_fc_op(add_idx + 1, mul_op, add_op)
+                # remove the old operators
+                self.block._remove_op(add_idx)
+                self.block._remove_op(mul_idx)
+                # restart scanning for elementwise_add from the deleted mul's index
+                i = mul_idx
+            i += 1
+        self._adjust_input()
+        self._remove_unused_var()
+        program = program.clone()
+
     # ====================== private transpiler functions =====================
     def _insert_bias_op(self, index, current_op, bn_op):
         '''
@@ -509,6 +566,42 @@ class InferenceTranspiler(object):
             outputs={"Output": out_var},
             attrs=attrs)
 
+    def _insert_fc_op(self, index, mul_op, add_op):
+        '''
+        Construct a new FC operator by copying the old mul and adding the
+        'Bias' input taken from the elementwise_add's input 'Y'.
+ :param index: insert location of FC + :type index: Int + :param mul_op: MUL operator to be copied + :type mul_op: Operator + :param add_op: Elementwise add operator taken bias from + :type add_op: Operator + :return: fc_op_new + :type: Operator + ''' + + def get_op_outputs(op, names): + result = {} + for name in names: + result[name] = self.block.var(op.output(name)[0]) + return result + + fc_inputs = {} + fc_inputs['Input'] = self.block.var(mul_op.input('X')[0]) + fc_inputs['W'] = self.block.var(mul_op.input('Y')[0]) + fc_inputs['Bias'] = self.block.var(add_op.input('Y')[0]) + fc_outputs = get_op_outputs(add_op, ['Out']) + fc_attrs = {} + fc_attrs['use_mkldnn'] = True + + fc_op_new = self.block._insert_op( + index, + type='fc', + inputs=fc_inputs, + outputs=fc_outputs, + attrs=fc_attrs) + return fc_op_new + def _fuse_conv_eltwise(self, index, conv_op, eltwise_op): ''' fuse the conv op with elementwise_add -- GitLab
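Usage sketch (not part of the patch): since fc_mkldnn_pass is deliberately
excluded from the default MKL-DNN pass list, the new FC kernel only runs when
an inference user appends the pass explicitly, exactly as the analyzer testers
above do. A minimal C++ sketch of that call sequence follows; the model path
is a placeholder, not taken from the patch:

    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    int main() {
      paddle::AnalysisConfig cfg;
      cfg.SetModel("path/to/model");  // placeholder model directory
      cfg.EnableMKLDNN();
      // fc ops are only given use_mkldnn=true once this pass has run.
      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
      auto predictor = paddle::CreatePaddlePredictor(cfg);
      // Feed PaddleTensor inputs and call predictor->Run(...) as usual.
      return 0;
    }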