Unverified commit fdc06f21 authored by Zhang Ting and committed by GitHub

add Fuse bn add act pass (#28196)

* add fuse_bn_add_act pass
Parent 813b2ade
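For orientation, a minimal usage sketch of the new option, assembled from the BuildStrategy docstring and the unit test added in this change; main_program and loss are hypothetical placeholders for a user-built static-graph program whose forward pass contains batch_norm + elementwise_add + relu:

import paddle
import paddle.fluid as fluid

paddle.enable_static()

# Enable the new fusion pass; it only takes effect on GPU and is skipped
# (with a warning) on other places.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_add_act_ops = True

# main_program and loss are placeholders for a user-built program.
binary = fluid.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)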
...@@ -107,7 +107,7 @@ cc_test(exception_holder_test SRCS exception_holder_test.cc )
set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass fuse_bn_add_act_pass
multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
...
...@@ -164,6 +164,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#else
...@@ -390,6 +391,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "fuse_bn_add_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_add_act_pass is only supported on "
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
...@@ -416,6 +423,7 @@ USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(fuse_bn_act_pass);
USE_PASS(fuse_bn_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
...
...@@ -100,6 +100,7 @@ struct BuildStrategy {
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
bool fuse_bn_act_ops_{false};
bool fuse_bn_add_act_ops_{true};
bool fuse_elewise_add_act_ops_{false};
bool enable_auto_fusion_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
...
...@@ -114,6 +114,7 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
...
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_bn_add_act_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace framework {
namespace ir {
void FuseBatchNormAddActPass::ApplyImpl(ir::Graph *graph) const {
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
// forward
std::unordered_set<std::string> act_types = {"relu"};
graph = FuseBatchNormAddAct(graph, act_types);
// backward
std::unordered_set<std::string> act_grad_types = {"relu_grad"};
graph = FuseBatchNormAddActGrad(graph, act_grad_types);
#endif
#endif
}
// act(bn(x) + z)
ir::Graph *FuseBatchNormAddActPass::FuseBatchNormAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE_NE(
graph, nullptr,
platform::errors::InvalidArgument(
"The input graph of FuseBatchNormAddAct should not be nullptr."));
FusePassBase::Init("bn_add_act", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("bn_add_act/x")
->AsInput()
->assert_is_op_input("batch_norm", "X")
->assert_var_dtype(proto::VarType::FP16);
patterns::BatchNormAddAct bn_add_act_pattern(gpd.mutable_pattern(),
"bn_add_act");
bn_add_act_pattern(x, act_types);
int found_bn_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseBatchNormAddAct fuse";
// BN inputs
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_add_act_pattern);
// BN outputs
GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out,
bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_add_act_pattern);
// Add input
GET_IR_NODE_FROM_SUBGRAPH(elewise_add_in, elewise_add_in,
bn_add_act_pattern);
// Add outputs
GET_IR_NODE_FROM_SUBGRAPH(elewise_add_out, elewise_add_out,
bn_add_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_add_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elewise_add, elewise_add, bn_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_add_act_pattern);
std::string bn_x_n = subgraph.at(x)->Name();
std::string elewise_add_in_n = elewise_add_in->Name();
std::string bn_scale_n = bn_scale->Name();
std::string bn_bias_n = bn_bias->Name();
std::string bn_mean_out_n = bn_mean_out->Name();
std::string bn_variance_out_n = bn_variance_out->Name();
std::string bn_saved_variance_n = bn_saved_variance->Name();
std::string bn_saved_mean_n = bn_saved_mean->Name();
std::string bn_reserve_space_n = bn_reserve_space->Name();
std::string bn_out_n = bn_out->Name();
std::string elewise_add_out_n = elewise_add_out->Name();
std::string act_out_n = act_out->Name();
Node *fused_bn_add_act_node = CreateFusedBatchNormAddActNode(
g, act, elewise_add, batch_norm, bn_x_n, elewise_add_in_n, bn_scale_n,
bn_bias_n, bn_mean_out_n, bn_variance_out_n, bn_saved_variance_n,
bn_saved_mean_n, bn_reserve_space_n, act_out_n);
VLOG(4) << "\n\t " << bn_x_n << ", " << bn_scale_n << ", " << bn_bias_n
<< " -> " << batch_norm->Name() << " -> " << bn_mean_out_n << ", "
<< bn_variance_out_n << ", " << bn_saved_variance_n << ", "
<< bn_saved_mean_n << ", " << bn_reserve_space_n << " and "
<< bn_out_n << "\n"
<< "\t " << bn_out_n << " and " << elewise_add_in_n << " -> "
<< elewise_add->Name() << " -> " << elewise_add_out_n << "\n"
<< "\t " << elewise_add_out_n << " -> " << act->Name() << " -> "
<< act_out_n;
ReLinkNodes(g, batch_norm, elewise_add, act, fused_bn_add_act_node);
found_bn_add_act_count++;
};
gpd(graph, handler);
AddStatis(found_bn_add_act_count);
return graph;
}
Node *FuseBatchNormAddActPass::CreateFusedBatchNormAddActNode(
Graph *g, const Node *act, const Node *elewise_add, const Node *bn,
const std::string &bn_x_n, const std::string &elewise_add_in_n,
const std::string &bn_scale_n, const std::string &bn_bias_n,
const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
const std::string &bn_saved_variance_n, const std::string &bn_saved_mean_n,
const std::string &bn_reserve_space_n, const std::string &act_out_n) const {
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({bn_x_n}));
desc.SetInput("Z", std::vector<std::string>({elewise_add_in_n}));
desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
desc.SetOutput("Y", std::vector<std::string>({act_out_n}));
desc.SetOutput("MeanOut", std::vector<std::string>({bn_mean_out_n}));
desc.SetOutput("VarianceOut", std::vector<std::string>({bn_variance_out_n}));
desc.SetOutput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
desc.SetOutput("SavedVariance",
std::vector<std::string>({bn_saved_variance_n}));
desc.SetOutput("ReserveSpace",
std::vector<std::string>({bn_reserve_space_n}));
desc.SetType("fused_bn_add_activation");
desc.SetAttr("act_type", act->Name());
// Set attrs
for (auto &n : {act->Op(), elewise_add->Op(), bn->Op()}) {
for (auto &m : n->GetAttrMap()) {
desc.SetAttr(m.first, m.second);
}
}
auto fused_bn_add_act_node = g->CreateOpNode(&desc);
return fused_bn_add_act_node;
}
// the backward of act(bn(x) + z)
ir::Graph *FuseBatchNormAddActPass::FuseBatchNormAddActGrad(
ir::Graph *graph,
const std::unordered_set<std::string> &act_grad_types) const {
PADDLE_ENFORCE_NE(
graph, nullptr,
platform::errors::InvalidArgument(
"The input graph of FuseBatchNormAddActGrad should not be nullptr."));
FusePassBase::Init("bn_add_act_grad", graph);
GraphPatternDetector gpd;
auto *d_act_out =
gpd.mutable_pattern()
->NewNode("bn_add_act_grad/x")
->AsInput()
->assert_is_ops_input(act_grad_types, GradVarName("Out"))
->assert_var_dtype(proto::VarType::FP16);
patterns::BatchNormAddActGrad bn_add_act_grad_pattern(gpd.mutable_pattern(),
"bn_add_act_grad");
bn_add_act_grad_pattern(d_act_out, act_grad_types);
int found_bn_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseBatchNormAddActGrad fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elewise_add_grad, elewise_add_grad,
bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(batch_norm_grad, batch_norm_grad,
bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_act_x, d_act_x, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_out, d_bn_out, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_x, bn_x, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean,
bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_x, d_bn_x, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_scale, d_bn_scale, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_bias, d_bn_bias, bn_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_elewise_add_in, d_elewise_add_in,
bn_add_act_grad_pattern);
std::string d_act_out_n = subgraph.at(d_act_out)->Name(); // Y@GRAD
std::string act_out_n = act_out->Name(); // Y
std::string d_act_x_n = d_act_x->Name();
std::string bn_x_n = bn_x->Name();
std::string bn_scale_n = bn_scale->Name();
std::string bn_bias_n = bn_bias->Name();
std::string bn_saved_mean_n = bn_saved_mean->Name();
std::string bn_saved_variance_n = bn_saved_variance->Name();
std::string bn_reserve_space_n = bn_reserve_space->Name();
std::string d_bn_out_n = d_bn_out->Name();
std::string d_bn_x_n = d_bn_x->Name();
std::string d_bn_scale_n = d_bn_scale->Name();
std::string d_bn_bias_n = d_bn_bias->Name();
std::string d_elewise_add_in_n = d_elewise_add_in->Name();
OpDesc desc;
desc.SetType("fused_bn_add_activation_grad");
desc.SetInput("X", {bn_x_n});
desc.SetInput("Y", std::vector<std::string>({act_out_n}));
desc.SetInput(GradVarName("Y"), std::vector<std::string>({d_act_out_n}));
desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
desc.SetInput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
desc.SetInput("SavedVariance",
std::vector<std::string>({bn_saved_variance_n}));
desc.SetInput("ReserveSpace",
std::vector<std::string>({bn_reserve_space_n}));
desc.SetOutput(GradVarName("X"), std::vector<std::string>({d_bn_x_n}));
desc.SetOutput(GradVarName("Z"),
std::vector<std::string>({d_elewise_add_in_n}));
desc.SetOutput(GradVarName("Scale"),
std::vector<std::string>({d_bn_scale_n}));
desc.SetOutput(GradVarName("Bias"),
std::vector<std::string>({d_bn_bias_n}));
std::string act = act_grad->Name();
act = act.substr(0, act.length() - 5); // remove "_grad"
desc.SetAttr("act_type", act);
for (auto &n :
{act_grad->Op(), elewise_add_grad->Op(), batch_norm_grad->Op()}) {
for (auto &m : n->GetAttrMap()) {
desc.SetAttr(m.first, m.second);
}
}
auto fused_node = g->CreateOpNode(&desc);
VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
<< act_grad->Name() << " -> " << d_act_x_n << "\n\t ";
VLOG(4) << d_act_x_n << " -> " << elewise_add_grad->Name() << " -> "
<< d_elewise_add_in_n << "," << d_bn_out_n << "\n\t ";
VLOG(4) << bn_x_n << ", " << d_bn_out_n << ", " << bn_scale_n << ", "
<< bn_bias_n << ", " << bn_saved_mean_n << ", "
<< bn_saved_variance_n << " and " << bn_reserve_space_n << " -> "
<< batch_norm_grad->Name() << " -> " << d_bn_x_n << ", "
<< d_bn_scale_n << " and " << d_bn_bias_n;
ReLinkNodes(g, act_grad, elewise_add_grad, batch_norm_grad, fused_node);
found_bn_add_act_count++;
};
gpd(graph, handler);
AddStatis(found_bn_add_act_count);
return graph;
}
void FuseBatchNormAddActPass::ReLinkNodes(Graph *graph, Node *op_1, Node *op_2,
Node *op_3,
Node *fused_op) const { // delete act
// link inputs of op_1 to fused_op
for (auto &in : op_1->inputs) {
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
}
std::unordered_set<const Node *> nodes2delete;
LinkOutputsToFuseOp(op_1, op_2, fused_op, &nodes2delete);
LinkOutputsToFuseOp(op_2, op_3, fused_op, &nodes2delete);
LinkInputsToFuseOp(op_2, fused_op, &nodes2delete);
LinkInputsToFuseOp(op_3, fused_op, &nodes2delete);
for (auto &out : op_3->outputs) {
IR_OP_VAR_LINK(fused_op, out);
}
nodes2delete.insert(std::move(op_1));
nodes2delete.insert(std::move(op_2));
nodes2delete.insert(std::move(op_3));
GraphSafeRemoveNodes(graph, nodes2delete);
}
void FuseBatchNormAddActPass::LinkOutputsToFuseOp(
Node *op_1, Node *op_2, Node *fused_op,
std::unordered_set<const Node *> *nodes2delete) const {
// if the outputs of op_1 are inputs of op_2, add the outputs to nodes2delete
// otherwise link the outputs to fused_op
for (auto &out : op_1->outputs) {
auto result_iter =
std::find_if(op_2->inputs.begin(), op_2->inputs.end(),
[&out](const Node *node) -> bool { return node == out; });
if (result_iter == op_2->inputs.end()) {
IR_OP_VAR_LINK(fused_op, out);
} else {
nodes2delete->emplace(out);
}
}
}
void FuseBatchNormAddActPass::LinkInputsToFuseOp(
Node *op, Node *fused_op,
std::unordered_set<const Node *> *nodes2delete) const {
// if the inputs of the op are outputs of previous op, which means
// these inputs have been added to nodes2delete before, skip the inputs,
// otherwise link the inputs of the op to fused_op
for (auto &in : op->inputs) {
if (nodes2delete->count(in)) {
continue;
}
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op, fused_op, in->outputs);
}
}
std::vector<Node *> FuseBatchNormAddActPass::ReplaceNode(
Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
bool has_replaced = false;
std::transform(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> Node * {
if (node == cur_node) {
has_replaced = true;
return new_node;
}
return node;
});
PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name()));
return new_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_bn_add_act_pass,
paddle::framework::ir::FuseBatchNormAddActPass);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the BatchNorm, add and activation.
*/
class Graph;
class Node;
class FuseBatchNormAddActPass : public FusePassBase {
public:
virtual ~FuseBatchNormAddActPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FuseBatchNormAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseBatchNormAddActGrad(
ir::Graph *graph,
const std::unordered_set<std::string> &act_grad_types) const;
void LinkOutputsToFuseOp(
Node *op_1, Node *op_2, Node *fused_op,
std::unordered_set<const Node *> *nodes2delete) const;
void LinkInputsToFuseOp(Node *op, Node *fused_op,
std::unordered_set<const Node *> *nodes2delete) const;
std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
const std::vector<Node *> &nodes) const;
void ReLinkNodes(Graph *graph, Node *op_1, Node *op_2, Node *op_3,
Node *fused_op) const;
Node *CreateFusedBatchNormAddActNode(
Graph *g, const Node *act, const Node *add, const Node *bn,
const std::string &bn_x_n, const std::string &add_y_n,
const std::string &bn_scale_n, const std::string &bn_bias_n,
const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
const std::string &bn_saved_variance_n,
const std::string &bn_saved_mean_n, const std::string &bn_reserve_space_n,
const std::string &act_out_n) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -93,6 +93,7 @@ void GraphPatternDetector::operator()(Graph *graph,
auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs);
SortSubgraphs(&subgraphs);
RemoveOverlappedMatch(&subgraphs);
ValidateByNodeRole(&subgraphs);
...@@ -302,6 +303,46 @@ void GraphPatternDetector::UniquePatterns(
*subgraphs = result;
}
void GraphPatternDetector::SortSubgraphs(
std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
if (subgraphs->empty()) return;
bool has_bn_add_act = false;
for (auto &subgraph : *subgraphs) {
for (auto &item : subgraph) {
if (item.first->name().find("bn_add_act") != std::string::npos) {
has_bn_add_act = true;
break;
}
}
}
if (!has_bn_add_act) {
return;
}
std::sort(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t &a,
const GraphPatternDetector::subgraph_t &b) {
for (auto &item : a) {
if (item.first->name().find("bn_add_act") != std::string::npos &&
item.first->name().find("bn_reserve_space") !=
std::string::npos) {
auto it_b = b.find(item.first);
if (it_b != b.end()) {
if (item.second->Name() != it_b->second->Name()) {
return item.second->Name() < it_b->second->Name();
} else {
return false;
}
} else {
return false;
}
}
}
return false;
});
}
void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t> *subgraphs) {
std::vector<subgraph_t> result;
...@@ -1208,6 +1249,151 @@ PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) {
return act_out;
}
PDNode *patterns::BatchNormAddAct::operator()(
paddle::framework::ir::PDNode *bn_x_var,
std::unordered_set<std::string> act_types) {
bn_x_var->assert_is_op_input("batch_norm", "X")
->assert_var_dtype(proto::VarType::FP16);
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->assert_is_op_input("batch_norm", "Scale");
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->assert_is_op_input("batch_norm", "Bias");
auto *bn = pattern->NewNode(batch_norm_repr())
->assert_is_op("batch_norm")
->assert_is_not_op_input("MomentumTensor")
->assert_op_attr<bool>("is_test", false)
->assert_op_attr<bool>("use_global_stats", false)
->assert_op_attr<std::string>("data_layout", "NHWC");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->assert_is_op_output("batch_norm", "MeanOut");
auto *bn_variance_out_var =
pattern->NewNode(bn_variance_out_repr())
->assert_is_op_output("batch_norm", "VarianceOut");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->assert_is_op_output("batch_norm", "SavedVariance");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_output("batch_norm", "SavedMean");
auto *bn_reserve_space =
pattern->NewNode(bn_reserve_space_repr())
->assert_is_op_output("batch_norm", "ReserveSpace");
auto *bn_out_var = pattern->NewNode(bn_out_repr())
->assert_is_op_output("batch_norm", "Y")
->assert_var_dtype(proto::VarType::FP16);
bn_out_var->assert_is_op_input("elementwise_add");
auto *elewise_add =
pattern->NewNode(elewise_add_repr())->assert_is_op("elementwise_add");
auto *elewise_add_in_var = pattern->NewNode(elewise_add_in_repr())
->assert_is_not_ctrl_var()
->assert_is_op_input("elementwise_add")
->assert_var_dtype(proto::VarType::FP16);
auto *elewise_add_out_var =
pattern->NewNode(elewise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_has_n_outputs(1);
elewise_add_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
.LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
bn_saved_mean_var, bn_reserve_space, bn_out_var});
elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
.LinksTo({elewise_add_out_var});
act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});
return act_out_var;
}
PDNode *patterns::BatchNormAddActGrad::operator()(
paddle::framework::ir::PDNode *d_act_out_var,
std::unordered_set<std::string> act_grad_types) {
auto *act_grad =
pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types);
auto *elewise_add_grad = pattern->NewNode(elewise_add_grad_repr())
->assert_is_op("elementwise_add_grad");
auto *bn_grad = pattern->NewNode(batch_norm_grad_repr())
->assert_is_op("batch_norm_grad")
->assert_op_attr<bool>("use_global_stats", false)
->assert_op_attr<std::string>("data_layout", "NHWC");
auto *act_out_var = pattern->NewNode(act_out_repr())
->assert_is_ops_input(act_grad_types, "Out");
auto *d_act_x_var =
pattern->NewNode(d_act_x_repr())
->assert_is_ops_output(act_grad_types, GradVarName("X"))
->assert_has_n_outputs(1); // d_act_x
d_act_x_var->AsIntermediate()->assert_is_op_input("elementwise_add_grad");
auto *d_elewise_add_in_var =
pattern->NewNode(d_elewise_add_in_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("elementwise_add_grad")
->assert_var_dtype(proto::VarType::FP16); // d_add_in_1
auto *d_bn_out_var =
pattern->NewNode(d_bn_out_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("elementwise_add_grad")
->assert_var_dtype(proto::VarType::FP16); // d_add_in_2
d_bn_out_var->assert_is_op_input("batch_norm_grad", GradVarName("Y"));
auto *bn_x_var = pattern->NewNode(bn_x_repr())
->assert_is_op_input("batch_norm_grad", "X")
->assert_var_dtype(proto::VarType::FP16);
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->assert_is_op_input("batch_norm_grad", "Scale");
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->assert_is_op_input("batch_norm_grad", "Bias");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_input("batch_norm_grad", "SavedMean");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->assert_is_op_input("batch_norm_grad", "SavedVariance");
auto *bn_reserve_space =
pattern->NewNode(bn_reserve_space_repr())
->assert_is_op_input("batch_norm_grad", "ReserveSpace");
auto *d_bn_x_var =
pattern->NewNode(d_bn_x_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("X"))
->assert_var_dtype(proto::VarType::FP16);
auto *d_bn_scale_var =
pattern->NewNode(d_bn_scale_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("Scale"));
auto *d_bn_bias_var =
pattern->NewNode(d_bn_bias_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("Bias"));
act_grad->LinksFrom({d_act_out_var, act_out_var}).LinksTo({d_act_x_var});
elewise_add_grad->LinksFrom({d_act_x_var})
.LinksTo({d_elewise_add_in_var, d_bn_out_var});
bn_grad
->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
.LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
return bn_grad;
}
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {
...
...@@ -294,6 +294,12 @@ class GraphPatternDetector {
// Remove duplicate patterns.
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
// Sort the subgraphs by the specified node so that the removed forward
// and backward subgraphs correspond when two subgraphs overlap.
// Note: this function is currently only used for bn_add_act;
// refer to PR 28196 for details.
void SortSubgraphs(std::vector<subgraph_t>* subgraphs);
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
// The intermediate PDNodes will be removed, so they can't be shared by multiple
// patterns.
...@@ -685,6 +691,72 @@ struct BatchNormActOneDNN : public PatternBase {
PATTERN_DECL_NODE(act_out);
};
// The following pattern is used to fuse batch_norm, elewise_add, and act
// formula: act(bn(x) + z)
// op: batch_norm + elewise_add + act
struct BatchNormAddAct : public PatternBase {
BatchNormAddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_add_act") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(batch_norm);
PATTERN_DECL_NODE(elewise_add);
PATTERN_DECL_NODE(act);
// declare variable node's name
// BN inputs
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
// BN outputs
PATTERN_DECL_NODE(bn_mean_out);
PATTERN_DECL_NODE(bn_variance_out);
PATTERN_DECL_NODE(bn_saved_variance);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_reserve_space);
PATTERN_DECL_NODE(bn_out);
// Elewise_Add input
PATTERN_DECL_NODE(elewise_add_in);
// Elewise_Add output
PATTERN_DECL_NODE(elewise_add_out);
// ACT output
PATTERN_DECL_NODE(act_out);
};
// the backward of act(bn(x) + z)
// op: batch_norm_grad + elewise_add_grad + act_grad
struct BatchNormAddActGrad : public PatternBase {
BatchNormAddActGrad(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_add_act_grad") {}
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// elewise_add_grad: in["Out@GRAD"], out["X@GRAD", "Y@GRAD"]
// bn_grad: in["X", "Z", "Y@GRAD", "Scale", "Bias", "SavedMean",
// "SavedVariance",
// "ReserveSpace"],
// out["X@GRAD", "Z@GRAD", "Scale@GRAD", "Bias@GRAD"]
PDNode* operator()(PDNode* x, std::unordered_set<std::string> act_grad_types);
// declare operator node's name
PATTERN_DECL_NODE(act_grad);
PATTERN_DECL_NODE(elewise_add_grad);
PATTERN_DECL_NODE(batch_norm_grad);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(d_act_x);
PATTERN_DECL_NODE(d_elewise_add_in);
PATTERN_DECL_NODE(d_bn_out);
PATTERN_DECL_NODE(bn_x);
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_saved_variance);
PATTERN_DECL_NODE(bn_reserve_space);
PATTERN_DECL_NODE(d_bn_x);
PATTERN_DECL_NODE(d_bn_scale);
PATTERN_DECL_NODE(d_bn_bias);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act
...
...@@ -186,8 +186,6 @@ void FusedBatchNormAddActGradOp::InferShape(
// check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"FusedBatchNormAddActGradOp"); "FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean", OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
......
...@@ -188,7 +188,6 @@ class FusedBatchNormAddActGradKernel<platform::CUDADeviceContext, T>
std::string act_type = ctx.Attr<std::string>("act_type");
const auto *x = ctx.Input<Tensor>("X");
const auto *z = ctx.Input<Tensor>("Z");
const auto *y = ctx.Input<Tensor>("Y"); const auto *y = ctx.Input<Tensor>("Y");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y")); const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
......
...@@ -61,7 +61,6 @@ class FusedBatchNormAddActGradOpMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Z", this->Input("Z"));
op->SetInput("Y", this->Output("Y")); op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
......
...@@ -2500,6 +2500,31 @@ All parameter, weight, gradient are variables in Paddle.
build_strategy = static.BuildStrategy()
build_strategy.fuse_bn_act_ops = True
)DOC")
.def_property(
"fuse_bn_add_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_add_act_ops_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(), true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fuse_bn_add_act_ops_ = b;
},
R"DOC((bool, optional): fuse_bn_add_act_ops indicate whether
to fuse batch_norm, elementwise_add and activation_op,
it may make the execution faster. Default is True
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fuse_bn_add_act_ops = True
)DOC")
.def_property(
"enable_auto_fusion",
[](const BuildStrategy &self) { return self.enable_auto_fusion_; },
...
...@@ -331,6 +331,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_add_act_pass)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while)
list(REMOVE_ITEM TEST_OPS test_conv3d_transpose_op)
...@@ -515,6 +516,7 @@ py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_para
py_test_modules(test_data_norm_op MODULES test_data_norm_op)
py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
py_test_modules(test_fuse_bn_add_act_pass MODULES test_fuse_bn_add_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
# NOTE: These unittests will appear NaN steadily in windows CI. After analysis,
# it is found that windows CI will run all the training unittests with the ON_INFER option turned on,
...
...@@ -21,6 +21,8 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import core
paddle.enable_static()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"Paddle core is not compiled with CUDA")
...@@ -163,12 +165,16 @@ class TestFusedBnAddActAPI(unittest.TestCase):
iters = 5
batch_size = 16
# build_fused_program: turn on fuse_bn_add_act_ops
main_program = fluid.Program()
startup_program = fluid.Program()
x, y, loss = self.build_origin_program(main_program, startup_program,
use_cuda)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
build_strategy_fused = fluid.BuildStrategy()
build_strategy_fused.fuse_bn_add_act_ops = True
binary_fused = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy_fused)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
exe = fluid.Executor(place)
...@@ -178,17 +184,16 @@ class TestFusedBnAddActAPI(unittest.TestCase):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(binary_fused,
feed=feeder.feed(data),
fetch_list=[loss])
loss_vals_fused.append(loss_v[0][0])
# build_origin_program: turn off fuse_bn_add_act_ops
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_add_act_ops = False
binary = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
loss_vals = []
...@@ -197,7 +202,7 @@ class TestFusedBnAddActAPI(unittest.TestCase):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[loss])
loss_vals.append(loss_v[0][0])
...@@ -210,6 +215,25 @@ class TestFusedBnAddActAPI(unittest.TestCase):
place = fluid.CUDAPlace(0)
self.check(place, use_cuda=True)
def test_fuse_bn_add_act_API(self):
# build_fused_program: use fused_bn_add_act python API
main_program = fluid.Program()
startup_program = fluid.Program()
place = fluid.CUDAPlace(0)
x, y, loss = self.build_fused_program(
main_program, startup_program, use_cuda=True)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16)
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup_program)
for _ in range(5):
data = next(train_reader())
loss_v = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])
if __name__ == '__main__':
unittest.main()