Commit 6d8075ec authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] conv_transpose mkldnn bias pass (#17644)

* - Changes to graph detector
  - Changes to pass
  - Added unit test for the new pass
  - use_pass
  - Added pass to MKL-DNN passes
  - Fix to registration
  - Improved verbose messaging for conv bias passes
  - Lint fixes

  test=develop

* - Lint fixes

  test=develop
Parent 41f1186c
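For context, the rewrite this pass performs (sketched below, not part of the patch itself): the elementwise_add that applies a bias right after a convolution is folded into the convolution op, and conv2d_transpose now gets the same treatment as conv2d and conv3d.

// Shape of the fusion, as a sketch:
//   before:  Input -> conv2d_transpose -> conv_out -> elementwise_add(Y = bias) -> out
//   after:   Input -> conv2d_transpose(Bias = bias) -> out
// conv_out is an intermediate variable and is removed from the IR after the fuse.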
@@ -1092,12 +1092,12 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
   return ele_add_grad;
 }
 
+// conv_type: conv2d, conv3d, conv2d_transpose
 PDNode *patterns::ConvBias::operator()(
-    paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
-  std::string type = is_conv3d ? "conv3d" : "conv2d";
+    paddle::framework::ir::PDNode *conv_input, std::string conv_type) {
   // Create Operators
-  conv_input->assert_is_op_input(type, "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
+  conv_input->assert_is_op_input(conv_type, "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
   auto *eltiwse_op =
       pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
   // Create variables
@@ -1105,11 +1105,11 @@ PDNode *patterns::ConvBias::operator()(
   auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
                               ->AsInput()
                               ->assert_is_persistable_var()
-                              ->assert_is_op_input(type, "Filter");
+                              ->assert_is_op_input(conv_type, "Filter");
   // intermediate variable, will be removed in the IR after fuse.
   auto *conv_out_var = pattern->NewNode(conv_out_repr())
                            ->AsIntermediate()
-                           ->assert_is_only_output_of_op(type)
+                           ->assert_is_only_output_of_op(conv_type)
                            ->assert_is_op_input("elementwise_add");
   // Bias stored in elementwise_add
   auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
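For orientation, a hedged sketch of how a fuse pass consumes this pattern: the detector calls a handler for every match, and the handler pulls the matched nodes back out with the GET_IR_NODE_FROM_SUBGRAPH macro (node names follow the PATTERN_DECL_NODEs of ConvBias; illustrative only, not part of the patch):

// Illustrative handler for patterns::ConvBias matches.
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);          // conv op node
  GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);    // elementwise_add node
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // intermediate output
  GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);  // bias var
  // ...create the fused op desc and relink the graph here...
};
gpd(graph, handler);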
......
@@ -669,7 +669,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
 struct ConvBias : public PatternBase {
   ConvBias(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_bias") {}
-  PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
+  PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d");
   // declare operator node's name
   PATTERN_DECL_NODE(conv);
   PATTERN_DECL_NODE(eltwise);
......
@@ -45,16 +45,14 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
 
-  std::string type = is_conv3d() ? "conv3d" : "conv2d";
-
   GraphPatternDetector gpd;
   auto* conv_input =
       gpd.mutable_pattern()
           ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
           ->AsInput()
-          ->assert_is_op_input(type, "Input");
+          ->assert_is_op_input(type(), "Input");
   patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
-  conv_bias_pattern(conv_input, is_conv3d());
+  conv_bias_pattern(conv_input, type());
   int found_conv_bias_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -75,7 +73,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
     // check if fuse can be done and if MKL-DNN should be used
     FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
     if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
-      VLOG(3) << "do not perform conv+bias fuse";
+      VLOG(3) << "do not perform " + type() + "+bias fuse";
       return;
     }
@@ -110,7 +108,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
     desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
     desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
-    desc.SetType(type);
+    desc.SetType(type());
     for (auto& attr : conv->Op()->GetAttrMap()) {
       desc.SetAttr(attr.first, attr.second);
@@ -135,5 +133,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
 }  // namespace paddle
 REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
               paddle::framework::ir::ConvBiasFusePass);
+REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DTransposeBiasFusePass);
 REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
               paddle::framework::ir::Conv3DBiasFusePass);
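Once registered under these names, a pass can be fetched from the registry and applied to a graph; the fuse-pass unit tests use this shape (a sketch, assuming a std::unique_ptr<ir::Graph> named graph):

// Sketch: apply the newly registered pass by name.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "conv_transpose_bias_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));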
@@ -26,7 +26,7 @@ namespace ir {
 class ConvBiasFusePass : public FusePassBase {
  public:
   virtual ~ConvBiasFusePass() {}
-  virtual bool is_conv3d() const { return false; }
+  virtual std::string type() const { return "conv2d"; }
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
@@ -35,9 +35,14 @@ class ConvBiasFusePass : public FusePassBase {
 /*
  * Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp.
  */
+class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
+ public:
+  std::string type() const override { return "conv2d_transpose"; }
+};
+
 class Conv3DBiasFusePass : public ConvBiasFusePass {
  public:
-  bool is_conv3d() const override { return true; }
+  std::string type() const override { return "conv3d"; }
 };
 }  // namespace ir
 }  // namespace framework
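Replacing the bool is_conv3d() hook with a virtual type() string is what makes the new variant a four-line subclass; any further flavour would follow the same shape. A hypothetical example (not in this patch):

// Hypothetical further variant: only type() needs overriding.
class Conv3DTransposeBiasFusePass : public ConvBiasFusePass {
 public:
  std::string type() const override { return "conv3d_transpose"; }
};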
......
@@ -141,7 +141,12 @@ TEST(ConvBiasFusePass, conv_with_existing_bias) { MainTest(true); }
 
 TEST(ConvBiasFusePass, conv3d) {
   Conv3DBiasFusePass pass;
-  ASSERT_TRUE(pass.is_conv3d());
+  ASSERT_EQ(pass.type(), std::string("conv3d"));
+}
+
+TEST(ConvBiasFusePass, conv2d_transpose) {
+  Conv2DTransposeBiasFusePass pass;
+  ASSERT_EQ(pass.type(), std::string("conv2d_transpose"));
 }
 
 }  // namespace ir
......
@@ -169,6 +169,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "conv_bn_fuse_pass",             // Execute BN passes again to
              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
              "conv_bias_mkldnn_fuse_pass",    //
+             "conv_transpose_bias_mkldnn_fuse_pass",
              "conv3d_bias_mkldnn_fuse_pass",  //
              "conv_elementwise_add_mkldnn_fuse_pass",
              "conv_concat_relu_mkldnn_fuse_pass",
......