提交 0c39b97b 编写于 作者: M Michał Gallus 提交者: Tao Luo

[MKL-DNN] Add Fully Connected Op for inference only(#15226)

* fuse mul and elementwise add to fc

* Reimplement the FC forward operator

* Fix FC MKLDNN integration by transposing weights

* Add FC MKLDNN Pass

test=develop

* FC MKLDNN Pass: change memcpy to std::copy

* Fix MKLDNN FC handling of mismatch input and weights dims

* Lower tolerance for MKL-DNN in resnet50 test

test=develop

* Adjust FC to support MKLDNN Op placement

test=develop

* Adjust Placement Op to set use_mkldnn attribute for graph

test=develop

* MKLDNN FC: fix weights format so that gemm version is called

test=develop

* FC MKLDNN: Remove tolerance decrease from tester_helper

* FC MKL-DNN: Refactor the code, change input reorder to weight reorder

* MKL-DNN FC: Introduce operator caching

test=develop

* FC MKL-DNN: Fix the tensor type in ExpectedKernelType

test=develop

* FC MKL-DNN: fix style changes

test=develop

* FC MKL-DNN: fallback to native on non-supported dim sizes

test=develop

* FC MKLDNN: fix CMake paths

test=develop

* FC MKLDNN: Refine placement pass graph mkldnn attribute

test=develop

* Fix Transpiler error for fuse_conv_eltwise

test=develop

* Fix missing STL includes in files

test=develop

* FC MKL-DNN: Enable new output size computation

Also, refine pass to comply with newest interface.
test=develop

* FC MKL-DNN: enable only when fc_mkldnn_pass is enabled

* FC MKL-DNN: Allow Weights to use oi or io format

* FC MKL-DNN: Adjust UT to work with correct dims

test=develop

* Enable MKL DEBUG for resnet50 analyzer

test=develop

* FC MKL-DNN: Improve Hashing function

test=develop

* FC MKL-DNN: Fix shape for fc weights in transpiler

* FC MKL-DNN: Update input pointer in re-used fc primitive

* Add log for not handling fc fuse for unsupported dims

test=develop

* FC MKL-DNN: Move transpose from pass to Op Kernel

test=develop

* FC MKL-DNN: Disable transpose in unit test

test=develop

* FC MKL-DNN: Remove fc_mkldnn_pass from default list

* Correct Flag for fake data analyzer tests

test=develop

* FC MKL-DNN: Add comment about fc mkldnn pass disablement

test=develop

* FC MKL-DNN: Disable fc in int8 tests

test=develop
上级 21138eb1
...@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME) ...@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
# No unit test should exceed 10 minutes. # No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
......
...@@ -88,6 +88,7 @@ if(WITH_MKLDNN) ...@@ -88,6 +88,7 @@ if(WITH_MKLDNN)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn) pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn) pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn) pass_library(cpu_quantize_squash_pass inference mkldnn)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <memory>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -80,6 +81,7 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -80,6 +81,7 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
} }
desc.SetType("fc"); desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out}); GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
...@@ -896,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, ...@@ -896,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
} }
} }
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Filter
auto *fc_weight_var = pattern->NewNode(weights_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "W");
// Bias
auto *fc_bias_var = pattern->NewNode(bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "Bias");
// Output
auto *fc_out_var = pattern->NewNode(output_repr())
->AsOutput()
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
return fc_out_var;
}
PDNode *patterns::Embedding::operator()(PDNode *x) { PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids"); x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op = auto *lookup_table_op =
......
...@@ -517,6 +517,25 @@ struct FC : public PatternBase { ...@@ -517,6 +517,25 @@ struct FC : public PatternBase {
PATTERN_DECL_NODE(Out); PATTERN_DECL_NODE(Out);
}; };
// MKL-DNN's FC with bias
// op: fc
// named node:
// fc
// w, bias, output
struct FCMKLDNN : public PatternBase {
FCMKLDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_mkldnn") {}
PDNode* operator()(PDNode* x, bool with_bias);
// declare operator node's name
PATTERN_DECL_NODE(fc);
// declare variable node's name
PATTERN_DECL_NODE(weights);
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(output);
};
// Embedding // Embedding
struct Embedding : public PatternBase { struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope) Embedding(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
Init("fc_mkldnn_pass", graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_mkldnn_pass/x")
->AsInput()
->assert_is_op_input("fc", "Input");
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
fc_pattern(x, true /*with bias*/);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Handle FC MKL-DNN pass";
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
VLOG(3) << "do not perform fc fuse";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
OpDesc* desc = fc->Op();
auto in_size = fc->inputs[0]->Var()->GetShape().size();
if (in_size != 2 && in_size != 4) {
VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
return;
}
desc->SetAttr("use_mkldnn", true);
PADDLE_ENFORCE(subgraph.count(x));
found_fc_count++;
};
gpd(graph, handler);
AddStatis(found_fc_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Transpose weights of FC to comply with MKL-DNN interface
*/
class FCMKLDNNPass : public FusePassBase {
public:
virtual ~FCMKLDNNPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include <memory>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
...@@ -24,6 +25,9 @@ void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -24,6 +25,9 @@ void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies MKL-DNN placement strategy."; VLOG(3) << "Applies MKL-DNN placement strategy.";
const auto& op_types_list = const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types"); Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
if (!graph->Has("use_mkldnn")) {
graph->Set<bool>("use_mkldnn", new bool(true));
}
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
......
...@@ -146,8 +146,8 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -146,8 +146,8 @@ void CpuPassStrategy::EnableMKLDNN() {
if (!use_mkldnn_) { if (!use_mkldnn_) {
passes_.insert(passes_.begin(), "mkldnn_placement_pass"); passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>( for (auto &pass : std::vector<std::string>({
{"depthwise_conv_mkldnn_pass", // "depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to "conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_bias_mkldnn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", //
...@@ -155,7 +155,10 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -155,7 +155,10 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_elementwise_add_mkldnn_fuse_pass", "conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", // "conv_relu_mkldnn_fuse_pass", //
"conv_brelu_mkldnn_fuse_pass"})) { "conv_brelu_mkldnn_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass"
})) {
passes_.push_back(pass); passes_.push_back(pass);
} }
} }
......
...@@ -33,8 +33,10 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename) ...@@ -33,8 +33,10 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI} --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=2) --iterations=2)
endfunction() endfunction()
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name) if(mkl_debug)
set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
endif()
download_model(${install_dir} ${model_name}) download_model(${install_dir} ${model_name})
inference_analysis_test(${target} SRCS ${filename} inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
...@@ -143,15 +145,15 @@ inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ...@@ -143,15 +145,15 @@ inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose
# googlenet # googlenet
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz") "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" false)
# resnet50 # resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz") "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" true)
# mobilenet with depthwise_conv op # mobilenet with depthwise_conv op
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
"${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz") "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" false)
# int8 image classification tests # int8 image classification tests
if(WITH_MKLDNN) if(WITH_MKLDNN)
......
...@@ -152,6 +152,7 @@ void profile(bool use_mkldnn = false) { ...@@ -152,6 +152,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) { if (use_mkldnn) {
config.EnableMKLDNN(); config.EnableMKLDNN();
config.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
......
...@@ -200,8 +200,9 @@ void profile(bool use_mkldnn = false) { ...@@ -200,8 +200,9 @@ void profile(bool use_mkldnn = false) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
// Enable all the mkldnn supported ops except conv3d in dam // Enable all the mkldnn supported ops except conv3d in dam
std::unordered_set<std::string> op_list = {"softmax", "elementwise_add", std::unordered_set<std::string> op_list = {"softmax", "elementwise_add",
"relu"}; "relu", "fc"};
cfg.SetMKLDNNOp(op_list); cfg.SetMKLDNNOp(op_list);
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
......
...@@ -100,6 +100,7 @@ void profile(bool use_mkldnn = false) { ...@@ -100,6 +100,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -48,6 +48,7 @@ void profile(bool use_mkldnn = false) { ...@@ -48,6 +48,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
...@@ -79,6 +80,7 @@ void compare(bool use_mkldnn = false) { ...@@ -79,6 +80,7 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -149,6 +149,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) { ...@@ -149,6 +149,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
} }
if (use_mkldnn) { if (use_mkldnn) {
cfg->EnableMKLDNN(); cfg->EnableMKLDNN();
cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
} }
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// time // time
......
...@@ -189,6 +189,7 @@ void profile(bool use_mkldnn = false) { ...@@ -189,6 +189,7 @@ void profile(bool use_mkldnn = false) {
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -85,6 +85,7 @@ void profile(bool use_mkldnn = false) { ...@@ -85,6 +85,7 @@ void profile(bool use_mkldnn = false) {
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
// cfg.pass_builder()->TurnOnDebug(); // cfg.pass_builder()->TurnOnDebug();
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
......
...@@ -20,34 +20,30 @@ from paddle.fluid.tests.unittests.op_test import OpTest ...@@ -20,34 +20,30 @@ from paddle.fluid.tests.unittests.op_test import OpTest
def fully_connected_naive(input, weights, bias_data=None): def fully_connected_naive(input, weights, bias_data=None):
in_n, in_c, in_h, in_w = input.shape
w_h, w_c = weights.shape
x_data = np.reshape(input, [in_n, in_c * in_h * in_w])
# this transpose should be implemented at C code
w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w)))
result = None result = None
if not bias_data: if not bias_data:
result = np.dot(x_data, w_data) result = np.dot(input, weights)
else: else:
result = np.dot(x_data, w_data) + bias_data result = np.dot(input, weights) + bias_data
return result return result
class MatrixGenerate: class MatrixGenerate:
def __init__(self, mb, ic, oc, h, w): def __init__(self, mb, ic, oc, h, w):
self.input = np.random.random((mb, ic, h, w)).astype("float32") self.input = np.random.random((mb, ic * h * w)).astype("float32")
self.weights = np.random.random((ic * h * w, oc)).astype("float32") self.weights = np.random.random((ic * h * w, oc)).astype("float32")
class TestFCMKLDNNOp(OpTest): class TestFCMKLDNNOp(OpTest):
def create_data(self):
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
def setUp(self): def setUp(self):
self.op_type = "fc" self.op_type = "fc"
self.use_mkldnn = True self.use_mkldnn = True
self.matrix = MatrixGenerate(1, 10, 15, 3, 3) self.create_data()
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
self.attrs = {'use_mkldnn': self.use_mkldnn, } self.attrs = {'use_mkldnn': self.use_mkldnn, }
...@@ -60,37 +56,16 @@ class TestFCMKLDNNOp(OpTest): ...@@ -60,37 +56,16 @@ class TestFCMKLDNNOp(OpTest):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(set(['Input', 'W']), 'Out', max_relative_error=0.9) pass
def test_check_grad_no_weight(self): def test_check_grad_no_weight(self):
self.check_grad( pass
['Input'], 'Out', max_relative_error=0.5, no_grad_set=set('W'))
class TestFCMKLDNNOp1(TestFCMKLDNNOp): class TestFCMKLDNNOp1(TestFCMKLDNNOp):
def init_op_type(self): def create_data(self):
self.matrix = MatrixGenerate(2, 15, 48, 2, 2) self.matrix = MatrixGenerate(2, 15, 48, 2, 2)
class TestFCMKLDNNOp2(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 32, 40, 1, 1)
class TestFCMKLDNNOp3(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 2, 4, 1, 1)
class TestFCMKLDNNOp4(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 32, 48, 2, 2)
class TestFCMKLDNNOp4(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 32, 1000, 6, 6)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -76,6 +76,7 @@ class InferenceTranspiler(object): ...@@ -76,6 +76,7 @@ class InferenceTranspiler(object):
self._fuse_conv_relu_mkldnn( self._fuse_conv_relu_mkldnn(
program) # ResNet residual block merging program) # ResNet residual block merging
self._fuse_bn_relu_mkldnn(program) self._fuse_bn_relu_mkldnn(program)
self._fuse_mul_add_mkldnn(program)
self._is_test_pass(program) self._is_test_pass(program)
...@@ -387,6 +388,62 @@ class InferenceTranspiler(object): ...@@ -387,6 +388,62 @@ class InferenceTranspiler(object):
# And a better solution will be considered later. # And a better solution will be considered later.
program = program.clone() program = program.clone()
def _fuse_mul_add_mkldnn(self, program):
'''
Transpile the program by fusing Mul+Add layers to FC layer with the MKL-DNN inner product.
The MUL following a Elementwise_add layer can be replaced by the MKL-DNN FC.
The Elementwise add's bias input 'Y' has to be added into the
MKL-DNN-based FC input 'Bias'.
The operator transformation is:
- before:
- MUL->elementwise_add -> any_other_op
- after:
- FC -> any_other_op
The transpile stages are:
1. insert a new MKL-DNN-based FC operator with `Bias` input
taken from the Elementwise add's input 'Y' (bias),
2. fuse the parameters of MUL and Elemenwise add,
3. remove the MUL, elementwise_add operators,
4. make the input of the deleted Elementwise add operator to be the input of the
new FC operator,
5. remove unused variables,
Args:
program (Program): program to transpile
'''
self.block = program.block(0)
self.input_map = {} # store the input names should be adjusted
i = 0
while i < len(self.block.ops):
# find a elementwise add op
if self.block.ops[i].type == 'elementwise_add':
add_op = self.block.ops[i]
add_idx = i
mul_idx = -1
# find the preceding mul op
for j in reversed(range(add_idx)):
if self.block.ops[j].type == 'mul':
mul_out_name = self.block.ops[j].output_arg_names[0]
if self.block.ops[j].output_arg_names[
0] in add_op.input_arg_names:
mul_op = self.block.ops[j]
mul_idx = j
break
if mul_idx < 0:
i += 1
continue
# create and insert a new fc op
fc_op_new = self._insert_fc_op(add_idx + 1, mul_op, add_op)
# remove the old operators
self.block._remove_op(add_idx)
self.block._remove_op(mul_idx)
# restart scanning for elementwise add from the deleted mul's index
i = mul_idx
i += 1
self._adjust_input()
self._remove_unused_var()
program = program.clone()
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
def _insert_bias_op(self, index, current_op, bn_op): def _insert_bias_op(self, index, current_op, bn_op):
''' '''
...@@ -509,6 +566,42 @@ class InferenceTranspiler(object): ...@@ -509,6 +566,42 @@ class InferenceTranspiler(object):
outputs={"Output": out_var}, outputs={"Output": out_var},
attrs=attrs) attrs=attrs)
def _insert_fc_op(self, index, mul_op, add_op):
'''
Construct a new FC operator by copying the old Mul and adding the
'Y' input taken from the Elementwise add's input 'Y'.
:param index: insert location of FC
:type index: Int
:param mul_op: MUL operator to be copied
:type mul_op: Operator
:param add_op: Elementwise add operator taken bias from
:type add_op: Operator
:return: fc_op_new
:type: Operator
'''
def get_op_outputs(op, names):
result = {}
for name in names:
result[name] = self.block.var(op.output(name)[0])
return result
fc_inputs = {}
fc_inputs['Input'] = self.block.var(mul_op.input('X')[0])
fc_inputs['W'] = self.block.var(mul_op.input('Y')[0])
fc_inputs['Bias'] = self.block.var(add_op.input('Y')[0])
fc_outputs = get_op_outputs(add_op, ['Out'])
fc_attrs = {}
fc_attrs['use_mkldnn'] = True
fc_op_new = self.block._insert_op(
index,
type='fc',
inputs=fc_inputs,
outputs=fc_outputs,
attrs=fc_attrs)
return fc_op_new
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op): def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
''' '''
fuse the conv op with elementwise_add fuse the conv op with elementwise_add
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册