From 240d974ac5777b5e9d6c950bec91fe41e0aed093 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Fri, 7 Dec 2018 14:14:46 +0800 Subject: [PATCH] Clean Code test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 - .../ir/conv3d_bias_mkldnn_fuse_pass.cc | 18 ------------ .../ir/conv3d_bias_mkldnn_fuse_pass.h | 29 ------------------- .../ir/conv_bias_mkldnn_fuse_pass.cc | 2 ++ .../framework/ir/conv_bias_mkldnn_fuse_pass.h | 7 +++++ .../framework/ir/graph_pattern_detector.cc | 24 +++++++-------- .../fluid/operators/activation_mkldnn_op.cc | 4 +++ 7 files changed, 24 insertions(+), 61 deletions(-) delete mode 100644 paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc delete mode 100644 paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0bbfe3c0e..883575e41 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -46,7 +46,6 @@ if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base) pass_library(depthwise_conv_mkldnn_pass base) pass_library(conv_bias_mkldnn_fuse_pass inference) - pass_library(conv3d_bias_mkldnn_fuse_pass inference) pass_library(conv_relu_mkldnn_fuse_pass inference) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference) endif() diff --git a/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc deleted file mode 100644 index e2968ddf6..000000000 --- a/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h" - -REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, - paddle::framework::ir::Conv3DBiasFusePass); diff --git a/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h deleted file mode 100644 index 5afe2cc61..000000000 --- a/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h" - -namespace paddle { -namespace framework { -namespace ir { -/* -* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp. -*/ -class Conv3DBiasFusePass : public ConvBiasFusePass { - public: - bool is_conv3d() const override { return true; } -}; -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc index c3ad3a0f1..d4a701e0b 100644 --- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc @@ -137,3 +137,5 @@ std::unique_ptr ConvBiasFusePass::ApplyImpl( } // namespace paddle REGISTER_PASS(conv_bias_mkldnn_fuse_pass, paddle::framework::ir::ConvBiasFusePass); +REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, + paddle::framework::ir::Conv3DBiasFusePass); diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h index c3b58bf58..f3ad9f1c2 100644 --- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h @@ -32,6 +32,13 @@ class ConvBiasFusePass : public FusePassBase { std::unique_ptr ApplyImpl(std::unique_ptr graph) const; const std::string name_scope_{"conv_bias_mkldnn_fuse"}; }; +/* +* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp. +*/ +class Conv3DBiasFusePass : public ConvBiasFusePass { + public: + bool is_conv3d() const override { return true; } +}; } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ed99d9883..0118019df 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1031,25 +1031,23 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( PDNode *patterns::ConvBias::operator()( paddle::framework::ir::PDNode *conv_input, bool is_conv3d) { + std::string type = is_conv3d ? "conv3d" : "conv2d"; // Create Operators - conv_input->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Input"); - auto *conv_op = pattern->NewNode(conv_repr()) - ->assert_is_op(is_conv3d ? "conv3d" : "conv2d"); + conv_input->assert_is_op_input(type, "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type); auto *eltiwse_op = pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); // Create variables // Filter - auto *conv_weight_var = - pattern->NewNode(conv_weight_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Filter"); + auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input(type, "Filter"); // intermediate variable, will be removed in the IR after fuse. - auto *conv_out_var = - pattern->NewNode(conv_out_repr()) - ->AsIntermediate() - ->assert_is_only_output_of_op(is_conv3d ? "conv3d" : "conv2d") - ->assert_is_op_input("elementwise_add"); + auto *conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op(type) + ->assert_is_op_input("elementwise_add"); // Bias stored in elementwise_add auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) ->AsInput() diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc index 7fa81e185..e16b6f78d 100644 --- a/paddle/fluid/operators/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/activation_mkldnn_op.cc @@ -100,6 +100,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx, const T *x_data = x->data(); T *y_data = y->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE( + x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4, + "Input dim must be with 2, 3 or 4"); + std::vector src_tz = framework::vectorize2int(x->dims()); auto src_format = -- GitLab