提交 7faa3e95 编写于 作者: A Adam 提交者: Tao Luo

Add ConvTranspose + BatchNorm fuse pass (#20161)

* Add ConvTranspose + BatchNorm fuse pass
test=develop

* Add tests for conv+bn and conv_transpose+bn passes
test=develop
上级 27d1ef60
......@@ -126,6 +126,7 @@ cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.c
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
if(WITH_GPU)
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
endif()
......
......@@ -51,7 +51,7 @@ void recompute_bias_and_weights(const Scope* scope,
const ir::Node& bn_mean, //
const ir::Node& bn_variance, //
LoDTensor* eltwise_y_in_tensor, //
float epsilon) {
float epsilon, const std::string& conv_type) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
......@@ -92,13 +92,26 @@ void recompute_bias_and_weights(const Scope* scope,
// Re-compute weight of conv2d from BN
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
auto weights_data = weights->mutable_data<float>(platform::CPUPlace());
// ConvTranspose weights are in IOHW format
if (conv_type == "conv2d_transpose") {
int kernel_size = weights_shape[2] * weights_shape[3];
for (int i = 0; i < weights->numel();) {
for (int j = 0; j < weights_shape[1]; ++j) {
for (int k = 0; k < kernel_size; ++k, ++i) {
weights_data[i] *= variance_array[j];
}
}
}
} else {
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= variance_array;
weights_array_2d.colwise() *= variance_array;
}
}
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
......@@ -113,14 +126,14 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(conv_type(), "Input");
patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
conv_bn_pattern(conv_input, false /*with_eltwise_add*/);
conv_bn_pattern(conv_input, conv_type(), false /*with_eltwise_add*/);
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBN fuse";
VLOG(4) << "handle " + conv_type() + "BN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
......@@ -132,7 +145,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+bn fuse";
VLOG(3) << "do not perform " + conv_type() + " bn fuse";
return;
}
......@@ -160,7 +173,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon);
epsilon, conv_type());
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
......@@ -187,7 +200,6 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
}
conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
......@@ -233,14 +245,14 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(conv_type(), "Input");
patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
conv_bn_pattern(conv_input, true /*with_eltwise_add*/);
conv_bn_pattern(conv_input, conv_type(), true /*with_eltwise_add*/);
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBN fuse";
VLOG(4) << "handle " + conv_type() + "BN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
......@@ -266,7 +278,7 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon);
epsilon, conv_type());
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
......@@ -294,3 +306,7 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvEltwiseAddBNFusePass);
REGISTER_PASS(conv_transpose_bn_fuse_pass,
paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
......@@ -29,6 +29,7 @@ namespace ir {
class ConvBNFusePass : public FusePassBase {
public:
virtual ~ConvBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
protected:
void ApplyImpl(ir::Graph* graph) const override;
......@@ -38,12 +39,23 @@ class ConvBNFusePass : public FusePassBase {
class ConvEltwiseAddBNFusePass : public FusePassBase {
public:
virtual ~ConvEltwiseAddBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
};
class ConvTransposeBNFusePass : public ConvBNFusePass {
public:
std::string conv_type() const { return "conv2d_transpose"; }
};
class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public:
std::string conv_type() const { return "conv2d_transpose"; }
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "bias_1", {3});
AddVarToScope(param_scope, "scale", {3});
AddVarToScope(param_scope, "mean", {3});
AddVarToScope(param_scope, "variance", {3});
AddVarToScope(param_scope, "filters", {3, 3, 2, 2});
return param_scope;
}
void TestMain(const std::string& conv_type) {
// inputs operator output
// ------------------------------------------------------------------
// (in, filters, bias_0) conv -> conv_out
// (conv_out, scale,
// bias_1, mean, varaince) batch_norm -> (...)
Layers layers;
auto* in = layers.data("in", {1, 3, 20, 20});
auto* filters = layers.data("filters", {3, 3, 2, 2}, true);
auto* bias_0 = layers.data("bias_0", {3}, true);
VarDesc* conv_out;
if (conv_type == "conv_transpose") {
conv_out = layers.conv2d_transpose(in, filters, bias_0);
} else {
conv_out = layers.conv2d(in, filters, bias_0);
}
conv_out->SetShape({1, 3, 20, 20});
auto* scale = layers.data("scale", {3}, true);
auto* bias_1 = layers.data("bias_1", {3}, true);
auto* mean = layers.data("mean", {3}, true);
auto* variance = layers.data("variance", {3}, true);
layers.batch_norm(conv_out, scale, bias_1, mean, variance);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(conv_type + "_bn_fuse_pass");
int num_bn_nodes_before = GetNumOpNodes(graph, "batch_norm");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_bn_nodes_after = GetNumOpNodes(graph, "batch_norm");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_bn_nodes_before, 1);
PADDLE_ENFORCE_EQ(num_bn_nodes_after, 0);
}
TEST(ConvBNFusePass, conv2d) { TestMain("conv"); }
TEST(ConvBNFusePass, conv2d_tranpose) { TestMain("conv_transpose"); }
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_bn_fuse_pass);
......@@ -666,10 +666,11 @@ bool VarLinksFromOp(Node *node, const std::string &op_type) {
}
PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
const std::string &conv_type,
bool with_eltwise_add) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
PDNode *eltwise_op = nullptr;
if (with_eltwise_add) {
......@@ -683,11 +684,11 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(conv_type, "Filter");
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d");
->assert_is_only_output_of_op(conv_type);
PDNode *eltwise_y_in_var = nullptr;
PDNode *eltwise_out_var = nullptr;
......
......@@ -404,7 +404,8 @@ struct ConvBN : public PatternBase {
ConvBN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bn") {}
PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
PDNode* operator()(PDNode* conv_input, const std::string& conv_type,
bool with_eltwise_add);
// declare operator node's name
PATTERN_DECL_NODE(conv);
......
......@@ -48,6 +48,19 @@ struct Layers {
return out;
}
VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d_transpose");
op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* depthwise_conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
bool use_cudnn) {
VarDesc* out = lod_tensor(unique_name());
......@@ -162,6 +175,33 @@ struct Layers {
return outs;
}
std::vector<VarDesc*> batch_norm(VarDesc* x, VarDesc* scale, VarDesc* bias,
VarDesc* mean, VarDesc* variance) {
VarDesc* y = lod_tensor(unique_name());
VarDesc* mean_out = lod_tensor(unique_name());
VarDesc* variance_out = lod_tensor(unique_name());
VarDesc* saved_mean = lod_tensor(unique_name());
VarDesc* saved_variance = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("batch_norm");
op->SetInput("X", {x->Name()});
op->SetInput("Scale", {scale->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetInput("Mean", {mean->Name()});
op->SetInput("Variance", {variance->Name()});
op->SetOutput("Y", {y->Name()});
op->SetOutput("MeanOut", {mean_out->Name()});
op->SetOutput("VarianceOut", {variance_out->Name()});
op->SetOutput("SavedMean", {saved_mean->Name()});
op->SetOutput("SavedVariance", {saved_variance->Name()});
op->SetAttr("epsilon", static_cast<float>(1e-5));
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
std::vector<VarDesc*> outs = {y, mean_out, variance_out, saved_mean,
saved_variance};
return outs;
}
private:
VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {},
bool is_persistable = false) {
......
......@@ -155,17 +155,19 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// "seqpool_concat_fuse_pass", //
"seqpool_cvm_concat_fuse_pass", //
// "embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
// following pass should be located in the last, since
// it will work on all fused ops.
"runtime_context_cache_pass"});
......@@ -185,7 +187,9 @@ void CpuPassStrategy::EnableMKLDNN() {
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass",
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册