Commit 240d974a authored by Yihua Xu

Clean Code

test=develop
Parent 82eefcea
...
@@ -46,7 +46,6 @@ if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base)
   pass_library(depthwise_conv_mkldnn_pass base)
   pass_library(conv_bias_mkldnn_fuse_pass inference)
-  pass_library(conv3d_bias_mkldnn_fuse_pass inference)
   pass_library(conv_relu_mkldnn_fuse_pass inference)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
 endif()
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h"
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Fuse Conv3D and Elementwise_add into a Conv3DBiasOp.
 */
class Conv3DBiasFusePass : public ConvBiasFusePass {
 public:
  bool is_conv3d() const override { return true; }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
...
@@ -137,3 +137,5 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
 }  // namespace paddle
 REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
               paddle::framework::ir::ConvBiasFusePass);
+REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv3DBiasFusePass);
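Registering the 3-D variant here, next to the existing 2-D pass, is what lets the commit drop the dedicated conv3d_bias_mkldnn_fuse_pass source files and their CMake entry. REGISTER_PASS binds a string name to a pass type so the pass can later be instantiated by name. The following standalone sketch shows that general registration mechanism with a hypothetical mini-registry; it only illustrates the idea and is not Paddle's actual PassRegistry implementation.

// Illustrative sketch of name-based pass registration (hypothetical
// mini-registry; Paddle's real REGISTER_PASS/PassRegistry differ in detail).
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Pass {
  virtual ~Pass() = default;
  virtual std::string Type() const = 0;
};

// Global name -> factory map, filled before main() by static registrars.
using PassFactory = std::function<std::unique_ptr<Pass>()>;
std::map<std::string, PassFactory>& Registry() {
  static std::map<std::string, PassFactory> registry;
  return registry;
}

// What a REGISTER_PASS-style macro conceptually expands to: a static object
// whose constructor inserts a factory for the pass under the given name.
template <typename PassT>
struct PassRegistrar {
  explicit PassRegistrar(const std::string& name) {
    Registry()[name] = [] { return std::make_unique<PassT>(); };
  }
};

struct ConvBiasFusePass : Pass {
  std::string Type() const override { return "conv_bias_mkldnn_fuse_pass"; }
};
struct Conv3DBiasFusePass : ConvBiasFusePass {
  std::string Type() const override { return "conv3d_bias_mkldnn_fuse_pass"; }
};

// Both passes registered from one file, mirroring the change above.
static PassRegistrar<ConvBiasFusePass> reg_2d("conv_bias_mkldnn_fuse_pass");
static PassRegistrar<Conv3DBiasFusePass> reg_3d("conv3d_bias_mkldnn_fuse_pass");

int main() {
  auto pass = Registry().at("conv3d_bias_mkldnn_fuse_pass")();
  std::cout << pass->Type() << "\n";  // prints conv3d_bias_mkldnn_fuse_pass
}

Because registration happens as a side effect of static initialization, the REGISTER_PASS call belongs in a .cc file rather than the header; folding it into conv_bias_mkldnn_fuse_pass.cc keeps that behavior while removing a nearly empty source file.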
...
@@ -32,6 +32,13 @@ class ConvBiasFusePass : public FusePassBase {
   std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
   const std::string name_scope_{"conv_bias_mkldnn_fuse"};
 };
+/*
+ * Fuse Conv3D and Elementwise_add into a Conv3DBiasOp.
+ */
+class Conv3DBiasFusePass : public ConvBiasFusePass {
+ public:
+  bool is_conv3d() const override { return true; }
+};
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
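With this header change the 3-D pass becomes a thin subclass: all matching and rewriting stays in ConvBiasFusePass, and the only difference is the operator type string derived from is_conv3d(). Below is a simplified, self-contained sketch of that dispatch; BuildConvBiasPattern is a hypothetical stand-in for the real patterns::ConvBias / FusePassBase machinery.

// Simplified sketch of the is_conv3d() dispatch used by the fuse pass.
// BuildConvBiasPattern is a hypothetical stand-in for patterns::ConvBias.
#include <iostream>
#include <string>

class ConvBiasFusePass {
 public:
  virtual ~ConvBiasFusePass() = default;
  // Derived classes flip this to reuse all of the fusion logic for conv3d.
  virtual bool is_conv3d() const { return false; }

  void Apply() const {
    // Mirrors patterns::ConvBias::operator(): one type string drives every
    // assertion on the conv op, its Filter input, and its single output.
    const std::string type = is_conv3d() ? "conv3d" : "conv2d";
    BuildConvBiasPattern(type);
  }

 private:
  void BuildConvBiasPattern(const std::string& type) const {
    std::cout << "matching " << type << " + elementwise_add -> " << type
              << " with fused bias\n";
  }
};

class Conv3DBiasFusePass : public ConvBiasFusePass {
 public:
  bool is_conv3d() const override { return true; }
};

int main() {
  ConvBiasFusePass conv2d_pass;
  Conv3DBiasFusePass conv3d_pass;
  conv2d_pass.Apply();  // matching conv2d + elementwise_add ...
  conv3d_pass.Apply();  // matching conv3d + elementwise_add ...
}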
...
@@ -1031,25 +1031,23 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
 PDNode *patterns::ConvBias::operator()(
     paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
+  std::string type = is_conv3d ? "conv3d" : "conv2d";
   // Create Operators
-  conv_input->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())
-                      ->assert_is_op(is_conv3d ? "conv3d" : "conv2d");
+  conv_input->assert_is_op_input(type, "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
   auto *eltiwse_op =
       pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
   // Create variables
   // Filter
-  auto *conv_weight_var =
-      pattern->NewNode(conv_weight_repr())
-          ->AsInput()
-          ->assert_is_persistable_var()
-          ->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Filter");
+  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
+                              ->AsInput()
+                              ->assert_is_persistable_var()
+                              ->assert_is_op_input(type, "Filter");
   // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var =
-      pattern->NewNode(conv_out_repr())
-          ->AsIntermediate()
-          ->assert_is_only_output_of_op(is_conv3d ? "conv3d" : "conv2d")
-          ->assert_is_op_input("elementwise_add");
+  auto *conv_out_var = pattern->NewNode(conv_out_repr())
+                           ->AsIntermediate()
+                           ->assert_is_only_output_of_op(type)
+                           ->assert_is_op_input("elementwise_add");
   // Bias stored in elementwise_add
   auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
                                ->AsInput()
...
...
@@ -100,6 +100,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const T *x_data = x->data<T>();
   T *y_data = y->mutable_data<T>(ctx.GetPlace());
+  PADDLE_ENFORCE(
+      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
+      "Input dim must be 2, 3 or 4");
   std::vector<int> src_tz = framework::vectorize2int(x->dims());
   auto src_format =
...
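The added PADDLE_ENFORCE rejects inputs whose rank the MKL-DNN activation path does not handle; only 2-D, 3-D, and 4-D shapes are supported here. A minimal standalone equivalent of that guard, assuming a plain std::vector<int64_t> in place of Paddle's DDim, might look like this:

// Minimal stand-in for the PADDLE_ENFORCE rank guard (assumes a plain
// std::vector<int64_t> in place of Paddle's DDim).
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

void CheckActivationInputRank(const std::vector<int64_t>& dims) {
  const size_t rank = dims.size();
  if (rank != 2 && rank != 3 && rank != 4) {
    throw std::invalid_argument("Input dim must be 2, 3 or 4");
  }
}

int main() {
  CheckActivationInputRank({8, 16, 32, 32});  // OK: 4-D NCHW-like input
  try {
    CheckActivationInputRank({8});  // rank 1: rejected
  } catch (const std::invalid_argument& e) {
    std::cerr << e.what() << "\n";
  }
}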