提交 956cf921 编写于 作者: H hjchen2

Fix conv_elementwise_add2_act pass

test=develop
上级 05d1121b
......@@ -40,18 +40,20 @@ framework::proto::OpDesc PrepareOpDesc(
const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetType("conv2d_fusion");
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {bias1});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
desc.Flush();
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
const std::string pattern_name = "conv_elementwise_add2_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
......@@ -76,22 +78,23 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
graph->CreateOpNode(&new_op_desc);
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, conv_op); // Input
IR_NODE_LINK_TO(conv_filter, conv_op); // Filter
IR_NODE_LINK_TO(conv_op, conv_out); // Output
IR_NODE_LINK_TO(elementwise_add_in_y, conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, conv_op); // Bias
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, new_conv_op); // Bias
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
GraphSafeRemoveNodes(
graph.get(),
{conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
};
gpd(graph.get(), handler);
return graph;
......
......@@ -1101,9 +1101,7 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var;
}
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput();
......@@ -1169,13 +1167,13 @@ PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add", "X")
->assert_is_op_input("elementwise_add", "Y")
->AsIntermediate();
auto elementwise_add_op_1 = pattern->NewNode(elementwise_add_op_1_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y_1 = pattern->NewNode(elementwise_add_in_y_1_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_op_input("elementwise_add", "X")
->AsInput();
auto elementwise_add_out_1 = pattern->NewNode(elementwise_add_out_1_repr())
->assert_is_op_output("elementwise_add")
......@@ -1203,8 +1201,8 @@ PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
conv_op->LinksFrom({conv_in, conv_filter}).LinksTo({conv_out});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
elementwise_add_op_1->LinksFrom(
{elementwise_add_out, elementwise_add_in_y_1});
elementwise_add_op_1->LinksFrom({elementwise_add_out, elementwise_add_in_y_1})
.LinksTo({elementwise_add_out_1});
act_op->LinksFrom({elementwise_add_out_1}).LinksTo({act_out});
return act_out;
}
......
......@@ -22,7 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search);
namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 7001
#if CUDNN_VERSION >= 7100
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
......@@ -204,7 +204,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
#if CUDNN_VERSION >= 7001
#if CUDNN_VERSION >= 7100
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册