Commit 0c39b97b authored by Michał Gallus, committed by Tao Luo

[MKL-DNN] Add Fully Connected Op for inference only (#15226)

* fuse mul and elementwise add to fc

* Reimplement the FC forward operator

* Fix FC MKLDNN integration by transposing weights

* Add FC MKLDNN Pass

test=develop

* FC MKLDNN Pass: change memcpy to std::copy

* Fix MKLDNN FC handling of mismatch input and weights dims

* Lower tolerance for MKL-DNN in resnet50 test

test=develop

* Adjust FC to support MKLDNN Op placement

test=develop

* Adjust Placement Op to set use_mkldnn attribute for graph

test=develop

* MKLDNN FC: fix weights format so that gemm version is called

test=develop

* FC MKLDNN: Remove tolerance decrease from tester_helper

* FC MKL-DNN: Refactor the code, change input reorder to weight reorder

* MKL-DNN FC: Introduce operator caching

test=develop

* FC MKL-DNN: Fix the tensor type in ExpectedKernelType

test=develop

* FC MKL-DNN: fix style changes

test=develop

* FC MKL-DNN: fallback to native on non-supported dim sizes

test=develop

* FC MKLDNN: fix CMake paths

test=develop

* FC MKLDNN: Refine placement pass graph mkldnn attribute

test=develop

* Fix Transpiler error for fuse_conv_eltwise

test=develop

* Fix missing STL includes in files

test=develop

* FC MKL-DNN: Enable new output size computation

Also, refine pass to comply with newest interface.
test=develop

* FC MKL-DNN: enable only when fc_mkldnn_pass is enabled

* FC MKL-DNN: Allow Weights to use oi or io format

* FC MKL-DNN: Adjust UT to work with correct dims

test=develop

* Enable MKL DEBUG for resnet50 analyzer

test=develop

* FC MKL-DNN: Improve Hashing function

test=develop

* FC MKL-DNN: Fix shape for fc weights in transpiler

* FC MKL-DNN: Update input pointer in re-used fc primitive

* Add log for not handling fc fuse for unsupported dims

test=develop

* FC MKL-DNN: Move transpose from pass to Op Kernel

test=develop

* FC MKL-DNN: Disable transpose in unit test

test=develop

* FC MKL-DNN: Remove fc_mkldnn_pass from default list

* Correct Flag for fake data analyzer tests

test=develop

* FC MKL-DNN: Add comment about fc mkldnn pass disablement

test=develop

* FC MKL-DNN: Disable fc in int8 tests

test=develop
Parent 21138eb1
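For orientation, a minimal NumPy sketch of the identity behind the fusion this commit implements (shapes and names below are illustrative, not taken from the patch): replacing a mul followed by an elementwise_add with a single fc op changes which kernel runs (here the MKL-DNN inner_product primitive), not the values computed.

import numpy as np

# Hypothetical sizes; the real ones come from the matched mul/elementwise_add ops.
batch, in_dim, out_dim = 4, 32, 16
x = np.random.rand(batch, in_dim).astype("float32")
w = np.random.rand(in_dim, out_dim).astype("float32")
b = np.random.rand(out_dim).astype("float32")

# Unfused graph: mul (matrix multiply) followed by elementwise_add (bias add).
mul_out = x.dot(w)
unfused = mul_out + b

# Fused graph: one fc op computing Input * W + Bias in a single kernel.
fused = np.dot(x, w) + b

assert np.allclose(unfused, fused)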
@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
       set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
-      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
+      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
       # No unit test should exceed 10 minutes.
       set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
     endif()
......
@@ -88,6 +88,7 @@ if(WITH_MKLDNN)
   pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
+  pass_library(fc_mkldnn_pass inference mkldnn)
   pass_library(cpu_quantize_placement_pass base mkldnn)
   pass_library(cpu_quantize_pass inference mkldnn)
   pass_library(cpu_quantize_squash_pass inference mkldnn)
......
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
+#include <memory>
 #include <string>
 #include <unordered_set>
 #include <vector>
@@ -80,6 +81,7 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
   }
   desc.SetType("fc");
   auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.

   GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
@@ -14,7 +14,10 @@

 #include <algorithm>
 #include <array>
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>

 #include "paddle/fluid/framework/ir/graph_helper.h"
@@ -896,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   }
 }
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Filter
auto *fc_weight_var = pattern->NewNode(weights_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "W");
// Bias
auto *fc_bias_var = pattern->NewNode(bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "Bias");
// Output
auto *fc_out_var = pattern->NewNode(output_repr())
->AsOutput()
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
return fc_out_var;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
  x->assert_is_op_input("lookup_table", "Ids");
  auto *lookup_table_op =
......
@@ -517,6 +517,25 @@ struct FC : public PatternBase {
   PATTERN_DECL_NODE(Out);
 };
// MKL-DNN's FC with bias
// op: fc
// named node:
// fc
// w, bias, output
struct FCMKLDNN : public PatternBase {
FCMKLDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_mkldnn") {}
PDNode* operator()(PDNode* x, bool with_bias);
// declare operator node's name
PATTERN_DECL_NODE(fc);
// declare variable node's name
PATTERN_DECL_NODE(weights);
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(output);
};
// Embedding
struct Embedding : public PatternBase {
  Embedding(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
Init("fc_mkldnn_pass", graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_mkldnn_pass/x")
->AsInput()
->assert_is_op_input("fc", "Input");
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
fc_pattern(x, true /*with bias*/);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Handle FC MKL-DNN pass";
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
VLOG(3) << "do not perform fc fuse";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
OpDesc* desc = fc->Op();
auto in_size = fc->inputs[0]->Var()->GetShape().size();
if (in_size != 2 && in_size != 4) {
VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
return;
}
desc->SetAttr("use_mkldnn", true);
PADDLE_ENFORCE(subgraph.count(x));
found_fc_count++;
};
gpd(graph, handler);
AddStatis(found_fc_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass);
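The handler above turns on use_mkldnn only for fc ops whose input rank is 2 or 4. A hedged NumPy sketch (illustrative shapes, not from the patch) of why a rank-4 NCHW input still maps onto the 2-D inner product the MKL-DNN kernel ultimately runs: everything after the batch dimension is flattened to match weights of shape (C*H*W, OC).

import numpy as np

n, c, h, w, oc = 2, 3, 4, 4, 5  # hypothetical NCHW activation and output size
x = np.random.rand(n, c, h, w).astype("float32")
weights = np.random.rand(c * h * w, oc).astype("float32")

# fc on a 4-D input == flatten the non-batch dims, then a plain 2-D inner product.
out = x.reshape(n, c * h * w).dot(weights)
print(out.shape)  # (2, 5)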
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Transpose weights of FC to comply with MKL-DNN interface
*/
class FCMKLDNNPass : public FusePassBase {
public:
virtual ~FCMKLDNNPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
+#include <memory>
 #include <string>
 #include <unordered_set>
@@ -24,6 +25,9 @@ void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
   const auto& op_types_list =
       Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
+  if (!graph->Has("use_mkldnn")) {
+    graph->Set<bool>("use_mkldnn", new bool(true));
+  }
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
       auto* op = n->Op();
......
@@ -146,16 +146,19 @@ void CpuPassStrategy::EnableMKLDNN() {
   if (!use_mkldnn_) {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
-    for (auto &pass : std::vector<std::string>(
-             {"depthwise_conv_mkldnn_pass",  //
+    for (auto &pass : std::vector<std::string>({
+             "depthwise_conv_mkldnn_pass",  //
              "conv_bn_fuse_pass",             // Execute BN passes again to
              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
              "conv_bias_mkldnn_fuse_pass",    //
              "conv3d_bias_mkldnn_fuse_pass",  //
              "conv_elementwise_add_mkldnn_fuse_pass",
              "conv_concat_relu_mkldnn_fuse_pass",
              "conv_relu_mkldnn_fuse_pass",  //
-             "conv_brelu_mkldnn_fuse_pass"})) {
+             "conv_brelu_mkldnn_fuse_pass",  //
+             // Disabled due to topology-dependent speed-up
+             // "fc_mkldnn_pass"
+         })) {
       passes_.push_back(pass);
     }
   }
......
@@ -33,8 +33,10 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
            --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
            --iterations=2)
endfunction()

-function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
+function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
+  if(mkl_debug)
+    set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
+  endif()
   download_model(${install_dir} ${model_name})
   inference_analysis_test(${target} SRCS ${filename}
        EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
@@ -143,15 +145,15 @@ inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose
# googlenet
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
-   "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz")
+   "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" false)

# resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
-   "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
+   "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" true)

# mobilenet with depthwise_conv op
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
-   "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
+   "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" false)

# int8 image classification tests
if(WITH_MKLDNN)
......
@@ -152,6 +152,7 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     config.EnableMKLDNN();
+    config.pass_builder()->AppendPass("fc_mkldnn_pass");
   }

   std::vector<std::vector<PaddleTensor>> outputs;
......
@@ -200,8 +200,9 @@ void profile(bool use_mkldnn = false) {
     cfg.EnableMKLDNN();
     // Enable all the mkldnn supported ops except conv3d in dam
     std::unordered_set<std::string> op_list = {"softmax", "elementwise_add",
-                                               "relu"};
+                                               "relu", "fc"};
     cfg.SetMKLDNNOp(op_list);
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }

   std::vector<std::vector<PaddleTensor>> outputs;
......
@@ -100,6 +100,7 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }

   std::vector<std::vector<PaddleTensor>> input_slots_all;
......
@@ -48,6 +48,7 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -79,6 +80,7 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   std::vector<std::vector<PaddleTensor>> input_slots_all;
......
@@ -149,6 +149,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   }
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
+    cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
   // time
......
@@ -189,6 +189,7 @@ void profile(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> outputs;
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   std::vector<std::vector<PaddleTensor>> input_slots_all;
......
@@ -85,6 +85,7 @@ void profile(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   // cfg.pass_builder()->TurnOnDebug();
   std::vector<std::vector<PaddleTensor>> outputs;
......
@@ -12,299 +12,266 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <mkldnn/include/mkldnn_types.h>
+#include <memory>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/fc_op.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/variant.h"

 namespace paddle {
 namespace operators {

(The old implementation removed by this hunk consisted of the MKLDNNMD<T> and MKLDNNMemory helper classes, a forward FCMKLDNNOpKernel that cached only the inner_product primitive descriptor in the device context under an "@fc_pd" blob key, the FcFwdPrimitiveDesc helper, and the FCMKLDNNGradOpKernel backward kernel with its FcBwdWeightsPrimitiveDesc/FcBwdDataPrimitiveDesc helpers. The new, inference-only implementation follows.)

using framework::DataLayout;
using framework::Tensor;
using framework::LoDTensor;
using framework::DDim;
using framework::ExecutionContext;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
using mkldnn::memory;
using mkldnn::inner_product_forward;
using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::prop_kind;

template <typename T>
class FCPrimitiveFactory {
 public:
  explicit FCPrimitiveFactory(const mkldnn::engine& engine) : engine_(engine) {}

  inner_product_forward CreateFcPrimitive(const LoDTensor* input,
                                          const Tensor* weights,
                                          const Tensor* bias, LoDTensor* output,
                                          const ExecutionContext& ctx) {
    RecomputeOutputDims(ctx, input, weights, output);
    if (fc_) {
      UpdateDataPointers(ctx, output, input);
      return *fc_;
    }
    auto src_desc = CreateMemDescriptor(input, input->format());
    input_ = CreateMemory(src_desc, input);

    weights_ = TransposeWeights(weights);
    if (src_desc.data.ndims == 4) {
      weights_ = CreateFourDimWeightsMemory(input, weights);
    }

    auto dst_desc = CreateMemDescriptor(output, memory::format::any);

    fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx);
    return *fc_;
  }

 private:
  void UpdateDataPointers(const ExecutionContext& ctx, Tensor* out,
                          const Tensor* in) {
    input_->set_data_handle(const_cast<T*>(in->data<T>()));
    output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
    if (out->format() == memory::format::format_undef) {
      auto output_format = output_->get_primitive_desc().desc().data.format;
      out->set_format((memory::format)output_format);
    }
  }

  memory::format MatchWeightFormat(memory::format fmt) {
    using format = memory::format;
    switch (fmt) {
      case format::nChw16c:
        return format::oIhw16i;
      case format::nChw8c:
        return format::oIhw8i;
      case format::nchw:
        return format::oihw;
      default:
        return format::format_undef;
    }
  }

  mkldnn::memory Reorder(const memory::desc& src_desc,
                         const memory::desc& dst_desc, const void* src_data) {
    auto src_mem = memory({src_desc, engine_}, const_cast<void*>(src_data));
    auto dst_mem = memory({dst_desc, engine_});

    auto reorder = mkldnn::reorder(src_mem, dst_mem);
    stream(stream::kind::eager).submit({reorder}).wait();

    return dst_mem;
  }

  static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims,
                                                  memory::format format) {
    return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(),
                                   format);
  }

  static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
                                                  memory::format format) {
    auto dims = framework::vectorize2int(tensor->dims());
    return CreateMemDescriptor(dims, format);
  }

  mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
                              const Tensor* tensor) {
    return CreateMemory(desc, tensor->data<T>());
  }

  mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
                              const void* data) {
    return memory({desc, engine_}, const_cast<void*>(data));
  }

  mkldnn::memory TransposeWeights(const Tensor* weights) {
    auto dims = framework::vectorize2int(weights->dims());
    std::swap(dims[0], dims[1]);  // Correct output dimensions
    auto src_desc = CreateMemDescriptor(dims, memory::format::io);
    auto dst_desc = CreateMemDescriptor(dims, memory::format::oi);
    return Reorder(src_desc, dst_desc, weights->data<T>());
  }

  inner_product_forward CreateFcPrimitive(const memory& src_memory,
                                          const memory& weights_memory,
                                          const memory::desc& dst_desc,
                                          const Tensor* bias, Tensor* output,
                                          const ExecutionContext& ctx) {
    const auto weights_desc = weights_memory.get_primitive_desc().desc();
    const auto src_desc = src_memory.get_primitive_desc().desc();
    if (bias) {
      auto bias_desc = CreateMemDescriptor(bias, bias->format());
      bias_ = CreateMemory(bias_desc, bias);
      auto fc_prim_desc =
          CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc);

      output_ = CreateDstMemory(fc_prim_desc, ctx, output);

      return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
                                   *bias_, *output_);
    } else {
      auto fc_prim_desc = CreateFcPrimDesc(src_desc, weights_desc, dst_desc);

      output_ = CreateDstMemory(fc_prim_desc, ctx, output);

      return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
                                   *output_);
    }
  }

  mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
      const mkldnn::memory::desc& input_desc,
      const mkldnn::memory::desc& weights_desc,
      const mkldnn::memory::desc& bias_desc,
      const mkldnn::memory::desc& dst_desc) {
    auto fc_desc =
        inner_product_forward::desc(prop_kind::forward_scoring, input_desc,
                                    weights_desc, bias_desc, dst_desc);

    return inner_product_forward::primitive_desc(fc_desc, engine_);
  }

  mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
      const mkldnn::memory::desc& input_desc,
      const mkldnn::memory::desc& weights_desc,
      const mkldnn::memory::desc& dst_desc) {
    auto fc_desc = inner_product_forward::desc(prop_kind::forward, input_desc,
                                               weights_desc, dst_desc);

    return inner_product_forward::primitive_desc(fc_desc, engine_);
  }

  mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input,
                                            const Tensor* weights) {
    auto input_dims = framework::vectorize2int(input->dims());
    auto weight_dims = framework::vectorize2int(weights->dims());
    auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};

    auto dst_format = MatchWeightFormat(input->format());
    auto src_desc = CreateMemDescriptor(dims, memory::format::oihw);
    auto dst_desc = CreateMemDescriptor(dims, dst_format);

    return Reorder(src_desc, dst_desc, weights_->get_data_handle());
  }

  mkldnn::memory CreateDstMemory(
      const mkldnn::inner_product_forward::primitive_desc& fc_prim_desc,
      const ExecutionContext& ctx, Tensor* output) {
    auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
    auto buffer_size = dst_prim_desc.get_size();
    T* output_data = output->mutable_data<T>(
        ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, buffer_size);
    output->set_format((memory::format)dst_prim_desc.desc().data.format);
    return memory(dst_prim_desc, to_void_cast<T>(output_data));
  }

  void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
                           const Tensor* w, LoDTensor* output) {
    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
    std::vector<int64_t> output_dims;
    FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims);
    output->Resize(framework::make_ddim(output_dims));
    output->set_lod(input->lod());
  }

 private:
  const mkldnn::engine& engine_;
  boost::optional<memory> bias_;
  boost::optional<memory> input_;
  boost::optional<memory> output_;
  boost::optional<memory> weights_;
  boost::optional<inner_product_forward> fc_;
};

static std::string GetHash(const Tensor* input, const Tensor* weights,
                           const std::string& suffix) {
  auto dim2str = [](const DDim& operand_dims) {
    std::string str = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      str += std::to_string(operand_dims[i]) + "-";
    }
    return str;
  };
  return std::to_string((unsigned)input->format()) + dim2str(weights->dims()) +
         suffix;
}

template <typename T>
std::shared_ptr<FCPrimitiveFactory<T>> GetPrimitiveFactory(
    const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
    const Tensor* input, const Tensor* weights,
    const mkldnn::engine& mkldnn_engine) {
  const std::string key = GetHash(input, weights, ctx.op().Output("Out"));

  auto prim_creator =
      std::static_pointer_cast<FCPrimitiveFactory<T>>(dev_ctx.GetBlob(key));
  if (prim_creator == nullptr) {
    prim_creator = std::make_shared<FCPrimitiveFactory<T>>(mkldnn_engine);
    dev_ctx.SetBlob(key, prim_creator);
  }

  return prim_creator;
}

template <typename T>
class FCMKLDNNOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto input = ctx.Input<LoDTensor>("Input");
    auto w = ctx.Input<Tensor>("W");
    auto bias = ctx.Input<Tensor>("Bias");
    auto output = ctx.Output<LoDTensor>("Out");

    auto prim_creator =
        GetPrimitiveFactory<T>(dev_ctx, ctx, input, w, mkldnn_engine);
    auto fc = prim_creator->CreateFcPrimitive(input, w, bias, output, ctx);
    stream(stream::kind::eager).submit({fc}).wait();

    output->set_layout(DataLayout::kMKLDNN);
  }
};
}  // namespace operators
@@ -312,6 +279,3 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
                    paddle::operators::FCMKLDNNOpKernel<float>);
-REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::FCMKLDNNGradOpKernel<float>);
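The GetHash/GetPrimitiveFactory pair above keeps one FCPrimitiveFactory per hashed (input format, weight dims, output name) key in the device context, so the primitive is built once and later runs only swap data pointers. A rough Python analogue of that lookup, with hypothetical helper names that are not part of the patch:

_factory_cache = {}

def get_primitive_factory(input_format, weight_dims, output_name, make_factory):
    # Mirror GetHash: stringify the input format and weight dims, then append a suffix.
    key = str(input_format) + "".join(str(d) + "-" for d in weight_dims) + output_name
    factory = _factory_cache.get(key)
    if factory is None:
        factory = make_factory()  # expensive construction happens only once per key
        _factory_cache[key] = factory
    return factory

f1 = get_primitive_factory("nchw", (32, 16), "fc_0.tmp_0", dict)
f2 = get_primitive_factory("nchw", (32, 16), "fc_0.tmp_0", dict)
assert f1 is f2  # the second call reuses the cached factory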
@@ -20,34 +20,30 @@ from paddle.fluid.tests.unittests.op_test import OpTest

 def fully_connected_naive(input, weights, bias_data=None):
-    in_n, in_c, in_h, in_w = input.shape
-    w_h, w_c = weights.shape
-
-    x_data = np.reshape(input, [in_n, in_c * in_h * in_w])
-    # this transpose should be implemented at C code
-    w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w)))
     result = None

     if not bias_data:
-        result = np.dot(x_data, w_data)
+        result = np.dot(input, weights)
     else:
-        result = np.dot(x_data, w_data) + bias_data
+        result = np.dot(input, weights) + bias_data
     return result


 class MatrixGenerate:
     def __init__(self, mb, ic, oc, h, w):
-        self.input = np.random.random((mb, ic, h, w)).astype("float32")
+        self.input = np.random.random((mb, ic * h * w)).astype("float32")
         self.weights = np.random.random((ic * h * w, oc)).astype("float32")


 class TestFCMKLDNNOp(OpTest):
+    def create_data(self):
+        self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
+
     def setUp(self):
         self.op_type = "fc"
         self.use_mkldnn = True
-        self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
+        self.create_data()
         self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}

         self.attrs = {'use_mkldnn': self.use_mkldnn, }
@@ -60,37 +56,16 @@ class TestFCMKLDNNOp(OpTest):
         self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(set(['Input', 'W']), 'Out', max_relative_error=0.9)
+        pass

     def test_check_grad_no_weight(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.5, no_grad_set=set('W'))
+        pass


 class TestFCMKLDNNOp1(TestFCMKLDNNOp):
-    def init_op_type(self):
+    def create_data(self):
         self.matrix = MatrixGenerate(2, 15, 48, 2, 2)


-class TestFCMKLDNNOp2(TestFCMKLDNNOp):
-    def init_op_type(self):
-        self.matrix = MatrixGenerate(2, 32, 40, 1, 1)
-
-
-class TestFCMKLDNNOp3(TestFCMKLDNNOp):
-    def init_op_type(self):
-        self.matrix = MatrixGenerate(2, 2, 4, 1, 1)
-
-
-class TestFCMKLDNNOp4(TestFCMKLDNNOp):
-    def init_op_type(self):
-        self.matrix = MatrixGenerate(2, 32, 48, 2, 2)
-
-
-class TestFCMKLDNNOp4(TestFCMKLDNNOp):
-    def init_op_type(self):
-        self.matrix = MatrixGenerate(2, 32, 1000, 6, 6)
-
 if __name__ == "__main__":
     unittest.main()
@@ -76,6 +76,7 @@ class InferenceTranspiler(object):
             self._fuse_conv_relu_mkldnn(
                 program)  # ResNet residual block merging
             self._fuse_bn_relu_mkldnn(program)
+            self._fuse_mul_add_mkldnn(program)

         self._is_test_pass(program)
@@ -387,6 +388,62 @@ class InferenceTranspiler(object):
         # And a better solution will be considered later.
         program = program.clone()
    def _fuse_mul_add_mkldnn(self, program):
        '''
        Transpile the program by fusing Mul+Add layers to FC layer with the MKL-DNN inner product.
        A mul operator followed by an elementwise_add operator can be replaced by the MKL-DNN FC.
        The elementwise_add's bias input 'Y' has to be added into the
        MKL-DNN-based FC input 'Bias'.
        The operator transformation is:
        - before:
          - mul -> elementwise_add -> any_other_op
        - after:
          - fc -> any_other_op
        The transpile stages are:
        1. insert a new MKL-DNN-based FC operator with the `Bias` input
           taken from the elementwise_add's input 'Y' (bias),
        2. fuse the parameters of the mul and elementwise_add operators,
        3. remove the mul and elementwise_add operators,
        4. make the input of the deleted elementwise_add operator the input of the
           new FC operator,
        5. remove unused variables.

        Args:
            program (Program): program to transpile
        '''
        self.block = program.block(0)

        self.input_map = {}  # store the input names that should be adjusted
        i = 0
        while i < len(self.block.ops):
            # find an elementwise_add op
            if self.block.ops[i].type == 'elementwise_add':
                add_op = self.block.ops[i]
                add_idx = i
                mul_idx = -1
                # find the preceding mul op
                for j in reversed(range(add_idx)):
                    if self.block.ops[j].type == 'mul':
                        mul_out_name = self.block.ops[j].output_arg_names[0]
                        if self.block.ops[j].output_arg_names[
                                0] in add_op.input_arg_names:
                            mul_op = self.block.ops[j]
                            mul_idx = j
                            break
                if mul_idx < 0:
                    i += 1
                    continue
                # create and insert a new fc op
                fc_op_new = self._insert_fc_op(add_idx + 1, mul_op, add_op)
                # remove the old operators
                self.block._remove_op(add_idx)
                self.block._remove_op(mul_idx)
                # restart scanning for elementwise_add from the deleted mul's index
                i = mul_idx
            i += 1
        self._adjust_input()
        self._remove_unused_var()
        program = program.clone()
    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
        '''
@@ -509,6 +566,42 @@ class InferenceTranspiler(object):
                outputs={"Output": out_var},
                attrs=attrs)
    def _insert_fc_op(self, index, mul_op, add_op):
        '''
        Construct a new FC operator by copying the old Mul and adding the
        'Y' input taken from the elementwise_add's input 'Y'.

        :param index: insert location of FC
        :type index: Int
        :param mul_op: mul operator to be copied
        :type mul_op: Operator
        :param add_op: elementwise_add operator to take the bias from
        :type add_op: Operator
        :return: fc_op_new
        :type: Operator
        '''

        def get_op_outputs(op, names):
            result = {}
            for name in names:
                result[name] = self.block.var(op.output(name)[0])
            return result

        fc_inputs = {}
        fc_inputs['Input'] = self.block.var(mul_op.input('X')[0])
        fc_inputs['W'] = self.block.var(mul_op.input('Y')[0])
        fc_inputs['Bias'] = self.block.var(add_op.input('Y')[0])
        fc_outputs = get_op_outputs(add_op, ['Out'])
        fc_attrs = {}
        fc_attrs['use_mkldnn'] = True

        fc_op_new = self.block._insert_op(
            index,
            type='fc',
            inputs=fc_inputs,
            outputs=fc_outputs,
            attrs=fc_attrs)
        return fc_op_new
    def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
        '''
        fuse the conv op with elementwise_add
......
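To see the shape of the transpiler rewrite in isolation, here is a hedged, self-contained sketch that mimics _fuse_mul_add_mkldnn on a toy op list; the (type, inputs, outputs) tuple encoding and the helper name are hypothetical, only the scan/insert/remove order mirrors the pass above.

def fuse_mul_add(ops):
    i = 0
    while i < len(ops):
        if ops[i][0] == 'elementwise_add':
            add_op, add_idx, mul_idx = ops[i], i, -1
            # find the preceding mul whose output feeds this elementwise_add
            for j in reversed(range(add_idx)):
                if ops[j][0] == 'mul' and ops[j][2][0] in add_op[1]:
                    mul_op, mul_idx = ops[j], j
                    break
            if mul_idx < 0:
                i += 1
                continue
            # insert the fused fc, then drop elementwise_add and mul (in that order)
            fc = ('fc', [mul_op[1][0], mul_op[1][1], add_op[1][1]], add_op[2])
            ops.insert(add_idx + 1, fc)
            del ops[add_idx]
            del ops[mul_idx]
            i = mul_idx
        i += 1
    return ops

ops = [('mul', ['x', 'w'], ['tmp']), ('elementwise_add', ['tmp', 'b'], ['out'])]
print(fuse_mul_add(ops))  # [('fc', ['x', 'w', 'b'], ['out'])]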