“9511ee3855deb2b7596f67bef57cc613ef966f47”上不存在“PaddleNLP/text_matching_on_quora/train_and_evaluate.py”
提交 c63a63d5 编写于 作者: Z Zhen Wang 提交者: Zeng Jinle

support the fusion of batch_norm and relu for AMP. test=release/1.7 (#22210)

上级 fedb609d
...@@ -118,7 +118,7 @@ function(op_library TARGET) ...@@ -118,7 +118,7 @@ function(op_library TARGET)
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op") "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
......
...@@ -100,7 +100,7 @@ endif() ...@@ -100,7 +100,7 @@ endif()
cc_library(build_strategy SRCS build_strategy.cc DEPS cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass fuse_relu_depthwise_conv_pass
lock_free_optimize_pass lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
......
...@@ -167,6 +167,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -167,6 +167,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"fuse_relu_depthwise_conv_pass"); "fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_, AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass"); "fuse_elewise_add_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
// for single card training, fuse_all_reduce_ops is unnecessary. // for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass. // coalesce_grad_tensor_pass should be before of MultiDevPass.
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_, AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
...@@ -369,6 +370,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -369,6 +370,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped."; "GPU, skipped.";
continue; continue;
} }
} else if (pass->Type() == "fuse_bn_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_act_pass is only supported on "
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "mkldnn_placement_pass") { } else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types", pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_)); new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
...@@ -394,6 +401,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -394,6 +401,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
USE_PASS(sync_batch_norm_pass); USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass); USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass); USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(fuse_bn_act_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass); USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass); USE_PASS(reduce_mode_multi_devices_pass);
......
...@@ -87,6 +87,7 @@ struct BuildStrategy { ...@@ -87,6 +87,7 @@ struct BuildStrategy {
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle. // cycle.
bool fuse_elewise_add_act_ops_{false}; bool fuse_elewise_add_act_ops_{false};
bool fuse_bn_act_ops_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types // should not be sparse types
boost::optional<bool> fuse_all_optimizer_ops_{false}; boost::optional<bool> fuse_all_optimizer_ops_{false};
......
...@@ -106,6 +106,7 @@ if(WITH_NGRAPH) ...@@ -106,6 +106,7 @@ if(WITH_NGRAPH)
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "") set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
endif() endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_bn_act_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace framework {
namespace ir {
void FuseBatchNormActPass::ApplyImpl(ir::Graph *graph) const {
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
// forward
std::unordered_set<std::string> act_types = {"relu"};
graph = FuseBatchNormAct(graph, act_types);
// backward
std::unordered_set<std::string> act_grad_types = {"relu_grad"};
graph = FuseBatchNormActGrad(graph, act_grad_types);
#endif
#endif
}
// act(bn(x))
ir::Graph *FuseBatchNormActPass::FuseBatchNormAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument(
"The input graph of FuseBatchNormAct should not be nullptr."));
FusePassBase::Init("bn_act", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("bn_act/x")
->AsInput()
->assert_is_op_input("batch_norm", "X")
->assert_var_dtype(proto::VarType::FP16);
patterns::BatchNormAct bn_act_pattern(gpd.mutable_pattern(), "bn_act");
bn_act_pattern(x, act_types);
int found_bn_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseBatchNormAct fuse";
// BN inputs
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, bn_act_pattern);
// BN outputs
GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);
std::string bn_x_n = subgraph.at(x)->Name();
std::string bn_scale_n = bn_scale->Name();
std::string bn_bias_n = bn_bias->Name();
std::string bn_variance_n = bn_variance->Name();
std::string bn_mean_n = bn_mean->Name();
std::string bn_mean_out_n = bn_mean_out->Name();
std::string bn_variance_out_n = bn_variance_out->Name();
std::string bn_saved_variance_n = bn_saved_variance->Name();
std::string bn_saved_mean_n = bn_saved_mean->Name();
std::string bn_reserve_space_n = bn_reserve_space->Name();
std::string bn_out_n = bn_out->Name();
std::string act_out_n = act_out->Name();
Node *fused_bn_act_node = CreateFusedBatchNormActNode(
g, act, batch_norm, bn_x_n, bn_scale_n, bn_bias_n, bn_variance_n,
bn_mean_n, bn_mean_out_n, bn_variance_out_n, bn_saved_variance_n,
bn_saved_mean_n, bn_reserve_space_n, act_out_n);
VLOG(4) << "\n\t " << bn_x_n << ", " << bn_scale_n << ", " << bn_bias_n
<< ", " << bn_variance_n << " and " << bn_mean_n << " -> "
<< batch_norm->Name() << " -> " << bn_mean_out_n << ", "
<< bn_variance_out_n << ", " << bn_saved_variance_n << ", "
<< bn_saved_mean_n << ", " << bn_reserve_space_n << " and "
<< bn_out_n << "\n"
<< "\t " << bn_out_n << " -> " << act->Name() << " -> "
<< act_out_n;
ReLinkNodes(g, bn_out, batch_norm, act, fused_bn_act_node);
found_bn_act_count++;
};
gpd(graph, handler);
AddStatis(found_bn_act_count);
return graph;
}
Node *FuseBatchNormActPass::CreateFusedBatchNormActNode(
Graph *g, const Node *act, const Node *bn, const std::string &bn_x_n,
const std::string &bn_scale_n, const std::string &bn_bias_n,
const std::string &bn_variance_n, const std::string &bn_mean_n,
const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
const std::string &bn_saved_variance_n, const std::string &bn_saved_mean_n,
const std::string &bn_reserve_space_n, const std::string &act_out_n) const {
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({bn_x_n}));
desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
desc.SetInput("Mean", std::vector<std::string>({bn_mean_n}));
desc.SetInput("Variance", std::vector<std::string>({bn_variance_n}));
desc.SetOutput("Y", std::vector<std::string>({act_out_n}));
desc.SetOutput("MeanOut", std::vector<std::string>({bn_mean_out_n}));
desc.SetOutput("VarianceOut", std::vector<std::string>({bn_variance_out_n}));
desc.SetOutput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
desc.SetOutput("SavedVariance",
std::vector<std::string>({bn_saved_variance_n}));
desc.SetOutput("ReserveSpace",
std::vector<std::string>({bn_reserve_space_n}));
desc.SetType("fused_batch_norm_act");
desc.SetAttr("act_type", act->Name());
// Set attrs
for (auto &n : {act->Op(), bn->Op()}) {
for (auto &m : n->GetAttrMap()) {
desc.SetAttr(m.first, m.second);
}
}
auto fused_bn_act_node = g->CreateOpNode(&desc);
return fused_bn_act_node;
}
// the backward of act(bn(x))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// bn_grad: in["X", "Y@GRAD", "Scale", "Bias", "SavedMean", "SavedVariance",
// "ReserveSpace"],
// out["X@GRAD", "Scale@GRAD", "Bias@GRAD"]
ir::Graph *FuseBatchNormActPass::FuseBatchNormActGrad(
ir::Graph *graph,
const std::unordered_set<std::string> &act_grad_types) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::InvalidArgument(
"The input graph of FuseBatchNormActGrad should not be nullptr."));
FusePassBase::Init("bn_act_grad", graph);
GraphPatternDetector gpd;
auto *d_act_out =
gpd.mutable_pattern()
->NewNode("bn_act_grad/x")
->AsInput()
->assert_is_ops_input(act_grad_types, GradVarName("Out"));
patterns::BatchNormActGrad bn_act_grad_pattern(gpd.mutable_pattern(),
"bn_act_grad");
bn_act_grad_pattern(d_act_out, act_grad_types);
int found_bn_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseBatchNormActGrad fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(batch_norm_grad, batch_norm_grad,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_x, bn_x, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_x, d_bn_x, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_scale, d_bn_scale, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_bias, d_bn_bias, bn_act_grad_pattern);
std::string d_act_out_n = subgraph.at(d_act_out)->Name(); // Y@GRAD
std::string act_out_n = act_out->Name(); // Y
std::string d_itermediate_out_n = d_itermediate_out->Name();
std::string bn_x_n = bn_x->Name();
std::string bn_scale_n = bn_scale->Name();
std::string bn_bias_n = bn_bias->Name();
std::string bn_saved_mean_n = bn_saved_mean->Name();
std::string bn_saved_variance_n = bn_saved_variance->Name();
std::string bn_reserve_space_n = bn_reserve_space->Name();
std::string d_bn_x_n = d_bn_x->Name();
std::string d_bn_scale_n = d_bn_scale->Name();
std::string d_bn_bias_n = d_bn_bias->Name();
OpDesc desc;
desc.SetType("fused_batch_norm_act_grad");
desc.SetInput("X", {bn_x_n});
desc.SetInput("Y", std::vector<std::string>({act_out_n}));
desc.SetInput(GradVarName("Y"), std::vector<std::string>({d_act_out_n}));
desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
desc.SetInput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
desc.SetInput("SavedVariance",
std::vector<std::string>({bn_saved_variance_n}));
desc.SetInput("ReserveSpace",
std::vector<std::string>({bn_reserve_space_n}));
desc.SetOutput(GradVarName("X"), std::vector<std::string>({d_bn_x_n}));
desc.SetOutput(GradVarName("Scale"),
std::vector<std::string>({d_bn_scale_n}));
desc.SetOutput(GradVarName("Bias"),
std::vector<std::string>({d_bn_bias_n}));
std::string act = act_grad->Name();
act = act.substr(0, act.length() - 5); // remove "_grad"
desc.SetAttr("act_type", act);
for (auto &n : {act_grad->Op(), batch_norm_grad->Op()}) {
for (auto &m : n->GetAttrMap()) {
desc.SetAttr(m.first, m.second);
}
}
auto fused_node = g->CreateOpNode(&desc);
VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
<< act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
<< bn_x_n << ", " << d_itermediate_out_n << ", " << bn_scale_n
<< ", " << bn_bias_n << ", " << bn_saved_mean_n << ", "
<< bn_saved_variance_n << " and " << bn_reserve_space_n << " -> "
<< batch_norm_grad->Name() << " -> " << d_bn_x_n << ", "
<< d_bn_scale_n << " and " << d_bn_bias_n;
ReLinkNodes(g, d_itermediate_out, act_grad, batch_norm_grad, fused_node);
found_bn_act_count++;
};
gpd(graph, handler);
AddStatis(found_bn_act_count);
return graph;
}
void FuseBatchNormActPass::ReLinkNodes(Graph *graph,
const Node *intermediate_out, Node *op_1,
Node *op_2,
Node *fused_op) const { // delete act
for (auto &in : op_1->inputs) {
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
}
std::unordered_set<const Node *> nodes2delete;
for (auto &out : op_1->outputs) {
// intermediate_out or ctr_var
auto result_iter =
std::find_if(op_2->inputs.begin(), op_2->inputs.end(),
[&out](const Node *node) -> bool { return node == out; });
if (result_iter == op_2->inputs.end()) {
IR_OP_VAR_LINK(fused_op, out);
} else {
nodes2delete.emplace(out);
}
}
for (auto &in : op_2->inputs) {
if (in == intermediate_out || nodes2delete.count(in)) {
continue;
}
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs);
}
for (auto &out : op_2->outputs) {
IR_OP_VAR_LINK(fused_op, out);
}
nodes2delete.insert(std::move(op_1));
nodes2delete.insert(std::move(op_2));
GraphSafeRemoveNodes(graph, nodes2delete);
}
std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
bool has_replaced = false;
std::transform(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> Node * {
if (node == cur_node) {
has_replaced = true;
return new_node;
}
return node;
});
PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not find %s in the node list.",
cur_node->Name()));
return new_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_bn_act_pass, paddle::framework::ir::FuseBatchNormActPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the BatchNorm and activation.
*/
class FuseBatchNormActPass : public FusePassBase {
public:
virtual ~FuseBatchNormActPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FuseBatchNormAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseBatchNormActGrad(
ir::Graph *graph,
const std::unordered_set<std::string> &act_grad_types) const;
std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
const std::vector<Node *> &nodes) const;
void ReLinkNodes(Graph *graph, const Node *intermediate_out, Node *op_1,
Node *op_2, Node *fused_op) const;
Node *CreateFusedBatchNormActNode(
Graph *g, const Node *act, const Node *bn, const std::string &bn_x_n,
const std::string &bn_scale_n, const std::string &bn_bias_n,
const std::string &bn_variance_n, const std::string &bn_mean_n,
const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
const std::string &bn_saved_variance_n,
const std::string &bn_saved_mean_n, const std::string &bn_reserve_space_n,
const std::string &act_out_n) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -383,6 +383,13 @@ PDNode *PDNode::assert_is_var() { ...@@ -383,6 +383,13 @@ PDNode *PDNode::assert_is_var() {
return this; return this;
} }
PDNode *PDNode::assert_var_dtype(proto::VarType::Type dtype) {
assert_is_var();
asserts_.emplace_back(
[dtype](Node *x) { return x->Var()->GetDataType() == dtype; });
return this;
}
PDNode *PDNode::assert_is_not_ctrl_var() { PDNode *PDNode::assert_is_not_ctrl_var() {
asserts_.emplace_back([](Node *x) { return x && !x->IsCtrlVar(); }); asserts_.emplace_back([](Node *x) { return x && !x->IsCtrlVar(); });
return this; return this;
...@@ -476,6 +483,7 @@ PDNode *PDNode::assert_is_op_output(const std::string &op_type, ...@@ -476,6 +483,7 @@ PDNode *PDNode::assert_is_op_output(const std::string &op_type,
assert_is_op_nth_output(op_type, argument, 0); assert_is_op_nth_output(op_type, argument, 0);
return this; return this;
} }
PDNode *PDNode::assert_is_op_input(const std::string &op_type) { PDNode *PDNode::assert_is_op_input(const std::string &op_type) {
assert_is_var(); assert_is_var();
asserts_.emplace_back([=](Node *x) { asserts_.emplace_back([=](Node *x) {
...@@ -489,6 +497,16 @@ PDNode *PDNode::assert_is_op_input(const std::string &op_type) { ...@@ -489,6 +497,16 @@ PDNode *PDNode::assert_is_op_input(const std::string &op_type) {
return this; return this;
} }
PDNode *PDNode::assert_is_not_op_input(const std::string &argument) {
assert_is_op();
asserts_.emplace_back([=](Node *x) {
auto &ins = x->Op()->Inputs();
auto iter = ins.find(argument);
return iter == ins.end() || iter->second.empty();
});
return this;
}
PDNode *PDNode::assert_is_op_input(const std::string &op_type, PDNode *PDNode::assert_is_op_input(const std::string &op_type,
const std::string &argument) { const std::string &argument) {
assert_is_var(); assert_is_var();
...@@ -1048,6 +1066,117 @@ PDNode *patterns::ActElewiseAdd::operator()( ...@@ -1048,6 +1066,117 @@ PDNode *patterns::ActElewiseAdd::operator()(
return elewise_add_out; return elewise_add_out;
} }
PDNode *patterns::BatchNormAct::operator()(
paddle::framework::ir::PDNode *bn_x_var,
std::unordered_set<std::string> act_types) {
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->assert_is_op_input("batch_norm", "Scale");
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->assert_is_op_input("batch_norm", "Bias");
auto *bn_variance_var = pattern->NewNode(bn_variance_repr())
->assert_is_op_input("batch_norm", "Variance");
auto *bn_mean_var = pattern->NewNode(bn_mean_repr())
->assert_is_op_input("batch_norm", "Mean");
auto *bn = pattern->NewNode(batch_norm_repr())
->assert_is_op("batch_norm")
->assert_is_not_op_input("MomentumTensor")
->assert_op_attr<bool>("is_test", false)
->assert_op_attr<bool>("use_global_stats", false)
->assert_op_attr<std::string>("data_layout", "NHWC");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->assert_is_op_output("batch_norm", "MeanOut");
auto *bn_variance_out_var =
pattern->NewNode(bn_variance_out_repr())
->assert_is_op_output("batch_norm", "VarianceOut");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->assert_is_op_output("batch_norm", "SavedVariance");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_output("batch_norm", "SavedMean");
auto *bn_reserve_space =
pattern->NewNode(bn_reserve_space_repr())
->assert_is_op_output("batch_norm", "ReserveSpace");
auto *bn_out_var = pattern->NewNode(bn_out_repr())
->assert_is_op_output("batch_norm", "Y")
->assert_has_n_outputs(1);
bn_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
bn->LinksFrom(
{bn_x_var, bn_scale_var, bn_bias_var, bn_variance_var, bn_mean_var})
.LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
bn_saved_mean_var, bn_reserve_space, bn_out_var});
act->LinksFrom({bn_out_var}).LinksTo({act_out_var});
return act_out_var;
}
PDNode *patterns::BatchNormActGrad::operator()(
paddle::framework::ir::PDNode *d_act_out_var,
std::unordered_set<std::string> act_grad_types) {
auto *act_grad =
pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types);
auto *bn_grad = pattern->NewNode(batch_norm_grad_repr())
->assert_is_op("batch_norm_grad")
->assert_op_attr<bool>("use_global_stats", false)
->assert_op_attr<std::string>("data_layout", "NHWC");
auto *act_out_var = pattern->NewNode(act_out_repr())
->assert_is_ops_input(act_grad_types, "Out");
auto *d_intermediate_var =
pattern->NewNode(d_itermediate_out_repr())
->assert_is_ops_output(act_grad_types, GradVarName("X"))
->assert_has_n_outputs(1);
auto *bn_x_var = pattern->NewNode(bn_x_repr())
->assert_is_op_input("batch_norm_grad", "X")
->assert_var_dtype(proto::VarType::FP16);
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->assert_is_op_input("batch_norm_grad", "Scale");
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->assert_is_op_input("batch_norm_grad", "Bias");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_input("batch_norm_grad", "SavedMean");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->assert_is_op_input("batch_norm_grad", "SavedVariance");
// ReserveSpace as the output is equal to:
// data_layout == 'NHWC' && FLAGS_cudnn_batchnorm_spatial_persistent == true
auto *bn_reserve_space =
pattern->NewNode(bn_reserve_space_repr())
->assert_is_op_input("batch_norm_grad", "ReserveSpace");
auto *d_bn_x_var =
pattern->NewNode(d_bn_x_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("X"));
auto *d_bn_scale_var =
pattern->NewNode(d_bn_scale_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("Scale"));
auto *d_bn_bias_var =
pattern->NewNode(d_bn_bias_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("Bias"));
act_grad->LinksFrom({d_act_out_var, act_out_var})
.LinksTo({d_intermediate_var});
bn_grad
->LinksFrom({bn_x_var, d_intermediate_var, bn_scale_var, bn_bias_var,
bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
.LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
return bn_grad;
}
PDNode *patterns::ElewiseAddAct::operator()( PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var, paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) { std::unordered_set<std::string> act_types) {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
...@@ -100,6 +101,7 @@ struct PDNode { ...@@ -100,6 +101,7 @@ struct PDNode {
PDNode* assert_is_op(); PDNode* assert_is_op();
PDNode* assert_is_op(const std::string& op_type); PDNode* assert_is_op(const std::string& op_type);
PDNode* assert_is_var(); PDNode* assert_is_var();
PDNode* assert_var_dtype(proto::VarType::Type dtype);
PDNode* assert_is_not_ctrl_var(); PDNode* assert_is_not_ctrl_var();
PDNode* assert_var_not_persistable(); PDNode* assert_var_not_persistable();
PDNode* assert_is_persistable_var(); PDNode* assert_is_persistable_var();
...@@ -111,6 +113,7 @@ struct PDNode { ...@@ -111,6 +113,7 @@ struct PDNode {
const std::string& argument); const std::string& argument);
PDNode* assert_is_op_nth_input(const std::string& op_type, PDNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth); const std::string& argument, int nth);
PDNode* assert_is_not_op_input(const std::string& argument);
PDNode* assert_is_op_nth_output(const std::string& op_type, PDNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth); const std::string& argument, int nth);
PDNode* assert_is_only_input_of_op(const std::string& op_type); PDNode* assert_is_only_input_of_op(const std::string& op_type);
...@@ -590,6 +593,64 @@ struct GRU : public PatternBase { ...@@ -590,6 +593,64 @@ struct GRU : public PatternBase {
PATTERN_DECL_NODE(Hidden); PATTERN_DECL_NODE(Hidden);
}; };
// The following pattern is used to fuse batch_norm and act
// formula: act(bn(x))
// op: batch_norm + act
struct BatchNormAct : public PatternBase {
BatchNormAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_act") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(batch_norm);
PATTERN_DECL_NODE(act);
// declare variable node's name
// BN inputs
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
PATTERN_DECL_NODE(bn_variance);
PATTERN_DECL_NODE(bn_mean);
// BN outputs
PATTERN_DECL_NODE(bn_mean_out);
PATTERN_DECL_NODE(bn_variance_out);
PATTERN_DECL_NODE(bn_saved_variance);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_reserve_space);
PATTERN_DECL_NODE(bn_out);
// ACT output
PATTERN_DECL_NODE(act_out);
};
// the backward of act(bn(x))
// op: batch_norm_grad + act_grad
struct BatchNormActGrad : public PatternBase {
BatchNormActGrad(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_act_grad") {}
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// bn_grad: in["X", "Y@GRAD", "Scale", "Bias", "SavedMean", "SavedVariance",
// "ReserveSpace"],
// out["X@GRAD", "Scale@GRAD", "Bias@GRAD"]
PDNode* operator()(PDNode* x, std::unordered_set<std::string> act_grad_types);
// declare operator node's name
PATTERN_DECL_NODE(act_grad);
PATTERN_DECL_NODE(batch_norm_grad);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(d_itermediate_out);
PATTERN_DECL_NODE(bn_x);
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_saved_variance);
PATTERN_DECL_NODE(bn_reserve_space);
PATTERN_DECL_NODE(d_bn_x);
PATTERN_DECL_NODE(d_bn_scale);
PATTERN_DECL_NODE(d_bn_bias);
};
// The following patterns are used to fuse elewise_add and act // The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y)) // formula: act(ele_add(x, y))
// op: elementwise_add + act // op: elementwise_add + act
......
...@@ -30,6 +30,8 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = { ...@@ -30,6 +30,8 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
"auc", "auc",
"batch_norm", "batch_norm",
"batch_norm_grad", "batch_norm_grad",
"fused_batch_norm_act",
"fused_batch_norm_act_grad",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"center_loss_grad", "center_loss_grad",
"crop", "crop",
......
include(operators) include(operators)
register_operators(EXCLUDES register_operators(EXCLUDES
fused_bn_activation_op
conv_fusion_op conv_fusion_op
fusion_transpose_flatten_concat_op fusion_transpose_flatten_concat_op
fusion_conv_inception_op fusion_conv_inception_op
...@@ -8,6 +9,11 @@ register_operators(EXCLUDES ...@@ -8,6 +9,11 @@ register_operators(EXCLUDES
fusion_group_op) fusion_group_op)
if (WITH_GPU) if (WITH_GPU)
# fused_bn_activation_op needs cudnn 7.4.1 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7401)
op_library(fused_bn_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
endif()
# conv_fusion_op needs cudnn 7 above # conv_fusion_op needs cudnn 7 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100) if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
op_library(conv_fusion_op) op_library(conv_fusion_op)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_bn_activation_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
void FusedBatchNormActOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true,
platform::errors::InvalidArgument(
"Input(Scale) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Bias"), true,
platform::errors::InvalidArgument(
"Input(Bias) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Mean"), true,
platform::errors::InvalidArgument(
"Input(Mean) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Variance"), true,
platform::errors::InvalidArgument(
"Input(Variance) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
platform::errors::InvalidArgument(
"Output(Y) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("MeanOut"), true,
platform::errors::InvalidArgument(
"Output(MeanOut) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VarianceOut"), true,
platform::errors::InvalidArgument(
"Output(VarianceOut) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SavedMean"), true,
platform::errors::InvalidArgument(
"Output(SavedMean) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SavedVariance"), true,
platform::errors::InvalidArgument(
"Output(SavedVariance) of BatchNormOp should not be null."));
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0],
platform::errors::PreconditionNotMet(
"Mean and MeanOut should share the same memory"));
PADDLE_ENFORCE_EQ(
ctx->Inputs("Variance")[0], ctx->Outputs("VarianceOut")[0],
platform::errors::PreconditionNotMet(
"Variance and VarianceOut should share the same memory"));
const auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2, platform::errors::PreconditionNotMet(
"ShapeError: the dimension of input "
"X must greater than or equal to 2."
"But received: the shape of input X "
"= [%s], the dimension of input X ="
"[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_LE(x_dims.size(), 5, platform::errors::PreconditionNotMet(
"ShapeError: the dimension of input "
"X must smaller than or equal to 5."
"But received: the shape of input X "
"= [%s], the dimension of input X ="
"[%d]",
x_dims, x_dims.size()));
const int64_t C = x_dims[x_dims.size() - 1];
auto scale_dim = ctx->GetInputDim("Scale");
auto bias_dim = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(
scale_dim.size(), 1UL,
platform::errors::PreconditionNotMet(
"ShapeError: the dimension of scale must equal to 1."
"But received: the shape of scale is [%s], the dimension "
"of scale is [%d]",
scale_dim, scale_dim.size()));
PADDLE_ENFORCE_EQ(bias_dim.size(), 1UL,
platform::errors::PreconditionNotMet(
"ShapeError: the dimension of bias must equal to 1."
"But received: the shape of bias is [%s],the dimension "
"of bias is [%d]",
bias_dim, bias_dim.size()));
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
framework::product(bias_dim) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0], C,
platform::errors::PreconditionNotMet(
"ShapeError: the shape of scale must equal to [%d]"
"But received: the shape of scale is [%d]",
C, scale_dim[0]));
PADDLE_ENFORCE_EQ(bias_dim[0], C,
platform::errors::PreconditionNotMet(
"ShapeError: the shape of bias must equal to [%d]"
"But received: the shape of bias is [%d]",
C, bias_dim[0]));
}
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C});
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto bn_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::PreconditionNotMet(
"Scale input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::PreconditionNotMet(
"Bias input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
platform::errors::PreconditionNotMet(
"Mean input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
platform::errors::PreconditionNotMet(
"Variance input should be of float type"));
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
framework::OpKernelType FusedBatchNormActOp::GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
void FusedBatchNormActOpMaker::Make() {
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' should be between 0.0 and 0.001."));
});
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
AddInput("X", "The input tensor");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size C "
"that is applied to the output");
AddInput("Bias",
"Bias is a 1-dimensional tensor of size C "
"that is applied to the output");
AddInput("Mean",
"The global mean (for training) or "
"estimated mean (for testing)");
AddInput("Variance",
"The global variance (for training) "
"or estimated Variance (for testing)");
AddOutput("Y", "result after normalization");
AddOutput("MeanOut",
"Share memory with Mean. "
"Store the global mean when training");
AddOutput("VarianceOut",
"Share memory with Variance. "
"Store the global Variance when training");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("ReserveSpace",
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel");
AddComment(R"DOC(
Fused Batch Normalization with activation.
Batch Norm has been implemented as discussed in the paper:
https://arxiv.org/pdf/1502.03167.pdf
Batch Norm can be used as a normalizer function for conv2d and fully_connected operations.
Now, the required data format for FusedBatchNormActOp is NHWC `[batch, in_height, in_width, in_channels]`.
)DOC");
}
void FusedBatchNormActGradOp::InferShape(
framework::InferShapeContext *ctx) const {
// check input
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Scale"), true,
platform::errors::InvalidArgument("Input(Scale) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Y")), true,
platform::errors::InvalidArgument("Input(Y@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedMean"), true,
platform::errors::InvalidArgument(
"Input(SavedMean) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedVariance"), true,
platform::errors::InvalidArgument(
"Input(SavedVariance) should not be null"));
// check output
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument("Output(X@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Scale")), true,
platform::errors::InvalidArgument(
"Output(Scale@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
platform::errors::InvalidArgument(
"Output(Bias@GRAD) should not be null."));
const auto x_dims = ctx->GetInputDim("X");
const int C = x_dims[x_dims.size() - 1];
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
framework::OpKernelType FusedBatchNormActGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find Y@GRAD in the execution context."));
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("Can not get the tensor value of Y@GRAD."));
}
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_batch_norm_act, ops::FusedBatchNormActOp,
ops::FusedBatchNormActOpMaker, ops::FusedBatchNormActOpInferVarType,
ops::FusedBatchNormActGradOpMaker<paddle::framework::OpDesc>,
ops::FusedBatchNormActGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_batch_norm_act_grad, ops::FusedBatchNormActGradOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/fused/fused_bn_activation_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T>
class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
float momentum = ctx.Attr<float>("momentum");
std::string act_type = ctx.Attr<std::string>("act_type");
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
// Get the size for each dimension.
// NHWC [batch_size, in_height, in_width, in_channels]
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5, true,
platform::errors::PreconditionNotMet(
"The Input dim size should be between 2 and 5"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// initialize them.
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto *y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
int N, C, H, W, D;
const DataLayout data_layout = DataLayout::kNHWC;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
// Only 1 element in normalization dimension,
// skip the batch norm calculation, let y = act(x).
auto x_v = framework::EigenVector<T>::Flatten(*x);
auto y_v = framework::EigenVector<T>::Flatten(*y);
auto &dev = *dev_ctx.eigen_device();
if (act_type == "relu") {
ReluFunctor<T>()(dev, x_v, y_v);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported activation type"));
}
return;
}
// ------------------- cudnn descriptors ---------------------
auto handle = dev_ctx.cudnn_handle();
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&data_desc_)."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&bn_param_desc_)."));
VLOG(3) << "Setting descriptors.";
std::vector<int> dims = {N, C, H, W, D};
std::vector<int> strides = {H * W * D * C, 1, W * D * C, D * C, C};
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()),
platform::errors::External(
"The error has happened when calling cudnnSetTensorNdDescriptor."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
data_desc_, mode_),
platform::errors::External("The error has happened when calling "
"cudnnDeriveBNTensorDescriptor."));
double this_factor = 1. - momentum;
cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
platform::ScopedActivationDescriptor scope_act_desc;
cudnnActivationDescriptor_t activation_desc_ =
scope_act_desc.descriptor<T>(act_type);
size_t workspace_size = 0;
size_t reserve_space_size = 0;
void *reserve_space_ptr = nullptr;
void *workspace_ptr = nullptr;
Tensor workspace_tensor;
// Create reserve space and workspace for batch norm.
// Create tensor for each batchnorm op, it will be used in the
// backward. Thus this tensor shouldn't be temp.
auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
PADDLE_ENFORCE_NOT_NULL(
reserve_space,
platform::errors::NotFound(
"The argument ReserveSpace of batch_norm op is not found."));
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/bnOps_,
/*xDesc=*/data_desc_,
/*zDesc=*/nullptr,
/*yDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/activation_desc_,
/*sizeInBytes=*/&workspace_size),
platform::errors::External(
"The error has happened when calling "
"cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize."));
// -------------- cudnn batchnorm reserve space --------------
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/bnOps_,
/*activationDesc=*/activation_desc_,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size),
platform::errors::External(
"The error has happened when calling "
"cudnnGetBatchNormalizationTrainingExReserveSpaceSize."));
reserve_space_ptr = reserve_space->mutable_data(ctx.GetPlace(), x->type(),
reserve_space_size);
workspace_ptr = workspace_tensor.mutable_data(ctx.GetPlace(), x->type(),
workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle, mode_, bnOps_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
nullptr, nullptr, data_desc_, y->template data<T>(), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
activation_desc_, workspace_ptr, workspace_size, reserve_space_ptr,
reserve_space_size),
platform::errors::External(
"The error has happened when calling "
"cudnnBatchNormalizationForwardTrainingEx."));
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(data_desc_)."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(bn_param_desc_)."));
}
};
template <typename T>
class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
std::string act_type = ctx.Attr<std::string>("act_type");
const auto *x = ctx.Input<Tensor>("X");
const auto *y = ctx.Input<Tensor>("Y");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
const auto &x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5, true,
platform::errors::PreconditionNotMet(
"The Input dim size should be between 2 and 5"));
int N, C, H, W, D;
const DataLayout data_layout = DataLayout::kNHWC;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
d_scale && d_bias, true,
platform::errors::PreconditionNotMet(
"Both the scale grad and the bias grad must not be null."));
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL,
platform::errors::PreconditionNotMet(
"The scale only has one dimension."));
PADDLE_ENFORCE_EQ(
scale->dims()[0], C,
platform::errors::PreconditionNotMet(
"The size of scale is equal to the channel of Input(X)."));
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
if (act_type == "relu") {
auto x_v = framework::EigenVector<T>::Flatten(*x);
auto y_v = framework::EigenVector<T>::Flatten(*y);
auto dx_v = framework::EigenVector<T>::Flatten(*d_x);
auto dy_v = framework::EigenVector<T>::Flatten(*d_y);
auto &dev = *dev_ctx.eigen_device();
ReluGradFunctor<T>()(dev, x_v, y_v, dy_v, dx_v);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported activation type"));
}
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
}
std::vector<int> dims = {N, C, H, W, D};
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&data_desc_)."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&bn_param_desc_)."));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()),
platform::errors::External(
"The error has happened when calling cudnnSetTensorNdDescriptor."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
data_desc_, mode_),
platform::errors::External("The error has happened when calling "
"cudnnDeriveBNTensorDescriptor."));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const auto *saved_mean_data =
saved_mean->template data<BatchNormParamType<T>>();
const auto *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();
size_t workspace_size = 0;
void *workspace_ptr = nullptr;
Tensor workspace_tensor;
auto reserve_space_size = reserve_space->memory_size();
cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
platform::ScopedActivationDescriptor scope_act_desc;
cudnnActivationDescriptor_t activation_desc_ =
scope_act_desc.descriptor<T>(act_type);
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetBatchNormalizationBackwardExWorkspaceSize(
/*handle=*/dev_ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/bnOps_,
/*xDesc=*/data_desc_,
/*yDesc=*/data_desc_,
/*dyDesc=*/data_desc_,
/*dzDesc=*/nullptr,
/*dxDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/activation_desc_,
/*sizeInBytes=*/&workspace_size),
platform::errors::External(
"The error has happened when calling "
"cudnnGetBatchNormalizationBackwardExWorkspaceSize."));
workspace_ptr = workspace_tensor.mutable_data(ctx.GetPlace(), x->type(),
workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationBackwardEx(
/*handle=*/dev_ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/bnOps_,
/*alphaDataDiff=*/CudnnDataType<T>::kOne(),
/*betaDataDiff=*/CudnnDataType<T>::kZero(),
/*alphaParamDiff=*/CudnnDataType<T>::kOne(),
/*betaParamDiff=*/CudnnDataType<T>::kZero(),
/*xDesc=*/data_desc_,
/*xData=*/x->template data<T>(),
/*yDesc=*/data_desc_,
/*yData=*/y->template data<T>(),
/*dyDesc=*/data_desc_,
/*dyData=*/d_y->template data<T>(),
/*dzDesc=*/nullptr,
/*dzData=*/nullptr,
/*dxDesc=*/data_desc_,
/*dxData=*/d_x->template mutable_data<T>(ctx.GetPlace()),
/*dBnScaleBiasDesc=*/bn_param_desc_,
/*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
/*bnBiasData=*/bias->template data<BatchNormParamType<T>>(),
/*dBnScaleData=*/d_scale
->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
/*dBnBiasData=*/d_bias
->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
/*epsilon=*/epsilon,
/*savedMean=*/saved_mean_data,
/*savedInvVariance=*/saved_var_data,
/*activationDesc=*/activation_desc_,
/*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/const_cast<T *>(reserve_space->template data<T>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size),
platform::errors::External("The error has happened when calling "
"cudnnBatchNormalizationBackwardEx."));
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(data_desc_)."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(bn_param_desc_)."));
}
};
} // namespace operators
} // namespace paddle
#if CUDNN_VERSION >= 7401
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fused_batch_norm_act,
ops::FusedBatchNormActKernel<plat::CUDADeviceContext, float>,
ops::FusedBatchNormActKernel<plat::CUDADeviceContext, double>,
ops::FusedBatchNormActKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_batch_norm_act_grad,
ops::FusedBatchNormActGradKernel<plat::CUDADeviceContext, float>,
ops::FusedBatchNormActGradKernel<plat::CUDADeviceContext, double>,
ops::FusedBatchNormActGradKernel<plat::CUDADeviceContext, plat::float16>);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedBatchNormActOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
};
class FusedBatchNormActGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusedBatchNormActOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
template <typename T>
class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("SavedMean", this->Output("SavedMean"));
op->SetInput("SavedVariance", this->Output("SavedVariance"));
op->SetInput("ReserveSpace", this->Output("ReserveSpace"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return op;
}
};
class FusedBatchNormActOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
template <typename DeviceContext, typename T>
class FusedBatchNormActKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class FusedBatchNormActGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
...@@ -1994,6 +1994,26 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1994,6 +1994,26 @@ All parameter, weight, gradient are variables in Paddle.
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True build_strategy.fuse_elewise_add_act_ops = True
)DOC") )DOC")
.def_property(
"fuse_bn_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
platform::errors::PreconditionNotMet(
"BuildStrategy is finlaized."));
self.fuse_bn_act_ops_ = b;
},
R"DOC((bool, optional): fuse_bn_act_ops indicate whether
to fuse batch_norm and activation_op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_act_ops = True
)DOC")
.def_property( .def_property(
"fuse_relu_depthwise_conv", "fuse_relu_depthwise_conv",
[](const BuildStrategy &self) { [](const BuildStrategy &self) {
......
...@@ -182,6 +182,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op) ...@@ -182,6 +182,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
list(REMOVE_ITEM TEST_OPS test_imperative_debug_string) list(REMOVE_ITEM TEST_OPS test_imperative_debug_string)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
if (APPLE OR WIN32) if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset) list(REMOVE_ITEM TEST_OPS test_dataset)
...@@ -301,6 +302,7 @@ py_test_modules(test_parallel_executor_seresnext_base_cpu MODULES test_parallel_ ...@@ -301,6 +302,7 @@ py_test_modules(test_parallel_executor_seresnext_base_cpu MODULES test_parallel_
py_test_modules(test_parallel_executor_seresnext_with_reduce_cpu MODULES test_parallel_executor_seresnext_with_reduce_cpu) py_test_modules(test_parallel_executor_seresnext_with_reduce_cpu MODULES test_parallel_executor_seresnext_with_reduce_cpu)
py_test_modules(test_parallel_executor_seresnext_with_fuse_all_reduce_cpu MODULES test_parallel_executor_seresnext_with_fuse_all_reduce_cpu) py_test_modules(test_parallel_executor_seresnext_with_fuse_all_reduce_cpu MODULES test_parallel_executor_seresnext_with_fuse_all_reduce_cpu)
py_test_modules(test_data_norm_op MODULES test_data_norm_op) py_test_modules(test_data_norm_op MODULES test_data_norm_op)
py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
if(NOT WIN32) if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer) py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer)
...@@ -330,6 +332,6 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu ...@@ -330,6 +332,6 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_parallel_executor_crf test_sync_batch_norm_op test_parallel_executor_crf test_sync_batch_norm_op
test_parallel_executor_feed_persistable_var test_parallel_executor_feed_persistable_var
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_optimizer_in_control_flow
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import unittest
class TestFuseBatchNormActPass(unittest.TestCase):
def build_program(self, main_program, startup_program, use_cuda, seed=1):
main_program.random_seed = seed
startup_program.random_seed = seed
with fluid.program_guard(main_program, startup_program):
x = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
y = fluid.layers.data(name="y", shape=[1], dtype='int64')
hidden1 = fluid.layers.conv2d(
input=x,
filter_size=3,
num_filters=32,
stride=1,
padding=1,
act=None,
bias_attr=False,
data_format='NHWC')
param_attr = fluid.ParamAttr(
name='batch_norm_w',
initializer=fluid.initializer.Constant(value=1.0))
bias_attr = fluid.ParamAttr(
name='batch_norm_b',
initializer=fluid.initializer.Constant(value=0.0))
hidden2 = fluid.layers.batch_norm(
input=hidden1,
param_attr=param_attr,
bias_attr=bias_attr,
act='relu',
data_layout='NHWC')
hidden3 = fluid.layers.fc(input=hidden2, size=128, act='relu')
hidden4 = fluid.layers.batch_norm(
input=hidden3, act='relu', data_layout='NHWC')
prediction = fluid.layers.fc(input=hidden4, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=y)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(learning_rate=0.001)
if use_cuda:
sgd = fluid.contrib.mixed_precision.decorate(
sgd, use_dynamic_loss_scaling=True, init_loss_scaling=128.0)
sgd.minimize(loss)
return x, y, loss
def check(self, place, use_cuda):
main_program = fluid.Program()
startup_program = fluid.Program()
x, y, loss = self.build_program(main_program, startup_program, use_cuda)
exe = fluid.Executor(place)
iters = 10
batch_size = 16
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
# close fused_bn_act_ops
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_act_ops = False
binary = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
loss_vals = []
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[loss])
loss_vals.append(loss_v[0][0])
# open fused_bn_act_ops
build_strategy_fused = fluid.BuildStrategy()
build_strategy_fused.fuse_bn_act_ops = True
binary_fused = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy_fused)
train_reader_fused = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
loss_vals_fused = []
scope_fused = fluid.Scope()
with fluid.scope_guard(scope_fused):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader_fused())
loss_v = exe.run(binary_fused,
feed=feeder.feed(data),
fetch_list=[loss])
loss_vals_fused.append(loss_v[0][0])
# check loss
for i in range(iters):
self.assertAlmostEqual(loss_vals[i], loss_vals_fused[i], delta=1e-5)
def test_fuse_bn_act_pass_cpu(self):
place = fluid.CPUPlace()
self.check(place, use_cuda=False)
def test_fuse_bn_act_pass_cuda(self):
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
self.check(place, use_cuda=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册