Unverified commit d402234b authored by chengduo, committed by GitHub

Feature/op_fuse_pass (#12440)

* Add Preface

* Add demo code

* Save file

* Refine code

* Seems to work

* use elementwise strategy

* Use ElementwiseComputeEx

* Add comments

* extract functions from operator

* Refine code

* Follow comment

* code refine

* add op_fuse pass

* add backward

* code refine

* use TopologySortOperations

* follow comments

* refine IsFusible

* code enhance

* fix op_fusion_pass

* refine code

* refine fuse_elemwise_act_op

* adjust the input and output

* refine logic

* add intermediate_edge

* disable inplace

* follow comments

* refine logic

* follow comments

* Remove the removable IntermediateOut

* change strategy

* code refine

* enable fuse backward

* code refine

* code refine

* rename unit test

* follow comments
Parent 5acdbbb4
@@ -150,11 +150,11 @@ else()
endif()
if (NOT WIN32)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fast_threaded_ssa_graph_executor)
fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass)
endif() # NOT WIN32
cc_library(prune SRCS prune.cc DEPS framework_proto)
@@ -54,6 +54,8 @@ struct BuildStrategy {
std::string debug_graphviz_path_{""};
bool fuse_elewise_add_act_ops_{false};
bool enable_data_balance_{false};
};
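For reference, fuse_elewise_add_act_ops_ is the user-facing switch for the new pass. A minimal sketch of enabling it from C++, assuming BuildStrategy lives in paddle::framework::details as in build_strategy.h (everything around it is omitted):

// Hedged sketch: turn on elementwise_add + activation fusion.
paddle::framework::details::BuildStrategy build_strategy;
build_strategy.fuse_elewise_add_act_ops_ = true;
// ParallelExecutor consults this flag in ApplyParallelExecutorPass (see the
// parallel_executor.cc hunk below) before running fuse_elewise_add_act_pass.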
@@ -37,6 +37,8 @@ pass_library(fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
graph = FuseActElewiseAdd(std::move(graph), act_types);
graph = FuseElewiseAddAct(std::move(graph), act_types);
// backward
{
std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
}
// Remove the removable intermediate_out.
RemoveIntermediateOut(graph.get());
return graph;
}
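The order above is deliberate: ele_add(x, act(y)) subgraphs are fused first, then act(ele_add(x, y)) ones, then the in-place backward pair, and finally dead IntermediateOut variables are pruned. A hedged sketch of driving the pass once it is registered (mirroring the parallel_executor.cc hunk below):

// Sketch: fetch the registered pass and apply it to a graph (see
// REGISTER_PASS at the bottom of this file and USE_PASS in
// parallel_executor.cc).
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "fuse_elewise_add_act_pass");
graph = pass->Apply(std::move(graph));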
// act(ele_add(x, y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act", graph.get());
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("elewise_add_act/x")
->AsInput()
->assert_is_op_input("elementwise_add", "X");
patterns::ElewiseAddAct elewise_add_act_pattern(gpd.mutable_pattern(),
"elementwise_add");
elewise_add_act_pattern(x, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseElewiseAddAct fuse";
GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_add, ele_add, elewise_add_act_pattern);
std::string ele_x_n = subgraph.at(x)->Name();
std::string ele_y_n = ele_y->Name();
std::string ele_out_n = ele_out->Name();
std::string act_out_n = act_out->Name();
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
g, act, ele_add, ele_x_n, ele_y_n, ele_out_n, act_out_n);
VLOG(4) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> "
<< ele_add->Name() << " -> " << ele_out_n << "\n"
<< "\t " << ele_out_n << " -> " << act->Name() << " -> "
<< act_out_n;
ReLinkNodes(g, ele_out, ele_add, act, elewise_add_act_node);
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
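Concretely, for act = relu the handler above collapses elementwise_add + relu into one fused_elemwise_activation op. A sketch of the op desc it ends up with (requires paddle/fluid/framework/op_desc.h; the variable names are placeholders, and the fields mirror CreateFuseElewiseAddActNode below):

// Sketch of the fused node built for act(ele_add(x, y)) with act = relu.
paddle::framework::OpDesc desc;
desc.SetType("fused_elemwise_activation");
desc.SetInput("X", {"x"});                       // ele_x
desc.SetInput("Y", {"y"});                       // ele_y
desc.SetOutput("Out", {"act_out"});              // final output
desc.SetOutput("IntermediateOut", {"ele_out"});  // elementwise_add result
desc.SetAttr("save_intermediate_out", true);
desc.SetAttr("functor_list",
             std::vector<std::string>({"relu", "elementwise_add"}));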
// ele_add(x, act(y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("act_elewise_add", graph.get());
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("act_elewise_add/x")
->AsInput()
->assert_is_ops_input(act_types, "X");
patterns::ActElewiseAdd act_elewise_add_pattern(gpd.mutable_pattern(),
"act_elewise_add");
act_elewise_add_pattern(x, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseElewiseAddAct fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_add, ele_add, act_elewise_add_pattern);
std::string act_i_n = subgraph.at(x)->Name();
std::string act_o_n = act_out->Name();
std::string elewise_add_x_n = ele_x->Name();
std::string elewise_add_out_n = ele_out->Name();
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
g, ele_add, act, elewise_add_x_n, act_i_n, act_o_n, elewise_add_out_n);
VLOG(4) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n
<< "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> "
<< ele_add->Name() << " -> " << elewise_add_out_n;
ReLinkNodes(g, act_out, act, ele_add, elewise_add_act_node);
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
// the backward of act(ele_add(x,y))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act_grad", graph.get());
GraphPatternDetector gpd;
auto *d_act_out = gpd.mutable_pattern()
->NewNode("elewise_add_act_grad_inplace/x")
->AsInput()
->assert_is_ops_input(act_types, GradVarName("Out"));
patterns::ElewiseAddActInplaceGrad elewise_add_act_grad_pattern(
gpd.mutable_pattern(), "elewise_add_act_grad_inplace");
elewise_add_act_grad_pattern(d_act_out, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseElewiseAddActGrad1 fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_add_grad, ele_add_grad,
elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_ele_x, d_ele_x, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_ele_y, d_ele_y, elewise_add_act_grad_pattern);
std::string d_act_out_n = subgraph.at(d_act_out)->Name();
std::string act_out_n = act_out->Name();
std::string d_itermediate_out_n = d_itermediate_out->Name();
std::string ele_y_n = ele_y->Name();
std::string d_ele_x_n = d_ele_x->Name();
std::string d_ele_y_n = d_ele_y->Name();
OpDesc desc;
desc.SetType("fused_elemwise_activation_grad");
desc.SetInput("IntermediateOut", {});
desc.SetInput("X", {});
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
desc.SetInput("Out", std::vector<std::string>({act_out_n}));
desc.SetInput(GradVarName("Out"), std::vector<std::string>({d_act_out_n}));
desc.SetOutput(GradVarName("X"), std::vector<std::string>({d_ele_x_n}));
desc.SetOutput(GradVarName("Y"), std::vector<std::string>({d_ele_y_n}));
desc.SetOutput(GradVarName("IntermediateOut"),
std::vector<std::string>({d_itermediate_out_n}));
desc.SetAttr("save_intermediate_out", false);
desc.SetAttr("functor_list",
std::vector<std::string>(
{act_grad->Op()->Type(), ele_add_grad->Op()->Type()}));
for (auto &n : {act_grad->Op(), ele_add_grad->Op()}) {
for (auto &m_ele : n->GetAttrMap()) {
desc.SetAttr(m_ele.first, m_ele.second);
}
}
auto fused_node = g->CreateOpNode(&desc);
VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
<< act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
<< d_itermediate_out_n << " and " << act_out_n << " -> "
<< ele_add_grad->Name() << " -> " << d_itermediate_out_n;
ReLinkNodes(g, d_itermediate_out, act_grad, ele_add_grad, fused_node);
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
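Only relu_grad qualifies for the in-place fusion above because the fused backward kernel must work from the activation's output alone, without keeping the pre-activation value alive. For relu this is a standard identity, not code from this patch:

// relu_grad needs only Out and Out@GRAD: since Out = max(X, 0),
// dX = dOut where Out > 0, and 0 elsewhere.
inline float ReluGrad(float out, float dout) {
  return out > 0.0f ? dout : 0.0f;
}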
Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
Graph *g, const Node *op_1, const Node *op_2, const std::string &ele_x_n,
const std::string &ele_y_n, const std::string &ele_out_n,
const std::string &act_out_n) const {
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({ele_x_n}));
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
desc.SetOutput("Out", std::vector<std::string>({act_out_n}));
desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
desc.SetType("fused_elemwise_activation");
desc.SetAttr("save_intermediate_out", true);
desc.SetAttr("functor_list", std::vector<std::string>(
{op_1->Op()->Type(), op_2->Op()->Type()}));
// Set attrs
for (auto &n : {op_1->Op(), op_2->Op()}) {
for (auto &m_ele : n->GetAttrMap()) {
desc.SetAttr(m_ele.first, m_ele.second);
}
}
auto elewise_add_act_node = g->CreateOpNode(&desc);
return elewise_add_act_node;
}
void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
std::unordered_set<const Node *> need_removed_nodes;
for (auto &cur_node : graph->Nodes()) {
if (cur_node->IsVar()) continue;
if (cur_node->Name() == "fused_elemwise_activation") {
bool save_intermediate_out =
boost::get<bool>(cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
PADDLE_ENFORCE(
save_intermediate_out && !intermediate_out_args.empty(),
"The %s should save the intermediate_out in the fusing stage.",
cur_node->Name());
// If the intermediate_out's output is empty, it should be removed.
auto cur_node_outputs = cur_node->outputs;
for (auto &out : cur_node_outputs) {
if (out->Name() == intermediate_out_args[0]) {
if (out->outputs.size() == 0) {
cur_node->outputs = this->RemoveNode(out, cur_node->outputs);
need_removed_nodes.insert(std::move(out));
cur_node->Op()->SetAttr("save_intermediate_out", false);
}
}
}
} else if (cur_node->Name() == "fused_elemwise_activation_grad") {
auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE(
!intermediate_out_grad_args.empty(),
"The %s should save the intermediate_out in the fusing stage.",
cur_node->Name());
auto cur_node_outputs = cur_node->outputs;
// If the intermediate_out_g's output is empty, it should be removed.
for (auto &out : cur_node_outputs) {
if (out->Name() == intermediate_out_grad_args[0] &&
out->outputs.empty()) {
cur_node->Op()->SetOutput(GradVarName("IntermediateOut"), {});
cur_node->outputs = this->RemoveNode(out, cur_node->outputs);
need_removed_nodes.insert(std::move(out));
}
}
}
}
GraphSafeRemoveNodes(graph, need_removed_nodes);
}
void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
const Node *intermediate_out,
Node *op_1, Node *op_2,
Node *fused_op) const { // removes op_1 and op_2
for (auto &in : op_1->inputs) {
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
}
std::unordered_set<const Node *> nodes2delete;
for (auto &out : op_1->outputs) {
if (out->IsCtrlVar()) {
auto result_iter = std::find_if(
op_2->inputs.begin(), op_2->inputs.end(),
[&out](const Node *node) -> bool { return node == out; });
if (result_iter == op_2->inputs.end()) {
IR_OP_VAR_LINK(fused_op, out);
} else {
nodes2delete.emplace(out);
}
} else {
PADDLE_ENFORCE(out == intermediate_out);
IR_OP_VAR_LINK(fused_op, out);
}
}
for (auto &in : op_2->inputs) {
if (in == intermediate_out || nodes2delete.count(in)) {
continue;
}
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs);
}
for (auto &out : op_2->outputs) {
IR_OP_VAR_LINK(fused_op, out);
}
nodes2delete.insert(std::move(op_1));
nodes2delete.insert(std::move(op_2));
GraphSafeRemoveNodes(graph, nodes2delete);
}
std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
bool has_replaced = false;
std::transform(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> Node * {
if (node == cur_node) {
has_replaced = true;
return new_node;
}
return node;
});
PADDLE_ENFORCE(has_replaced, "Cannot find %s in the node list.",
cur_node->Name());
return new_list;
}
std::vector<Node *> FuseElewiseAddActPass::RemoveNode(
Node *trg_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
auto end_iter =
std::copy_if(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> bool { return node != trg_node; });
new_list.resize(
static_cast<uint64_t>(std::distance(new_list.begin(), end_iter)));
return new_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_elewise_add_act_pass,
paddle::framework::ir::FuseElewiseAddActPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse elementwise_add and activation ops into a fused_elemwise_activation op.
*/
class FuseElewiseAddActPass : public FusePassBase {
public:
virtual ~FuseElewiseAddActPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
std::unique_ptr<ir::Graph> FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
/**
* Remove the removable intermediate_out:
* - if the intermediate_out is produced only for the backward op, but the
* fused backward op no longer uses it;
* - if the intermediate_out_grad is not used by any op.
*/
void RemoveIntermediateOut(Graph *graph) const;
std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
const std::vector<Node *> &nodes) const;
std::vector<Node *> RemoveNode(Node *trg_node,
const std::vector<Node *> &nodes) const;
void ReLinkNodes(Graph *graph, const Node *intermediate_out, Node *op_1,
Node *op_2, Node *fused_op) const;
Node *CreateFuseElewiseAddActNode(Graph *g, const Node *op_1,
const Node *op_2,
const std::string &ele_x_n,
const std::string &ele_y_n,
const std::string &ele_out_n,
const std::string &act_out_n) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -20,10 +20,10 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace framework {
namespace ir {
@@ -34,7 +34,7 @@ using string::Style;
size_t PDPattern::id_ = 0UL;
PDNode* PDPattern::NewNode(const std::string& name) {
PDNode *PDPattern::NewNode(const std::string &name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
"PDNode's name should be unique, get duplicate [%s]",
@@ -42,12 +42,12 @@ PDNode* PDPattern::NewNode(const std::string& name) {
}
nodes_.emplace_back(new PDNode(this, name));
auto* cur = nodes_.back().get();
auto *cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
}
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
"PDNode's name should be unique, get duplicate [%s]",
@@ -55,12 +55,12 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
}
nodes_.emplace_back(new PDNode(std::move(teller), this, name));
auto* cur = nodes_.back().get();
auto *cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
}
PDNode* PDPattern::RetrieveNode(const std::string& id) const {
PDNode *PDPattern::RetrieveNode(const std::string &id) const {
auto it = node_map_.find(id);
if (it == node_map_.end()) {
return nullptr;
@@ -69,14 +69,14 @@ PDNode* PDPattern::RetrieveNode(const std::string& id) const {
return it->second;
}
void PDPattern::AddEdge(PDNode* a, PDNode* b) {
void PDPattern::AddEdge(PDNode *a, PDNode *b) {
PADDLE_ENFORCE(a);
PADDLE_ENFORCE(b);
PADDLE_ENFORCE(a != b, "can't connect to the same nodes.");
edges_.emplace_back(a, b);
}
void GraphPatternDetector::operator()(Graph* graph,
void GraphPatternDetector::operator()(Graph *graph,
GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) {
return;
@@ -90,18 +90,18 @@ void GraphPatternDetector::operator()(Graph* graph,
if (subgraphs.empty()) return;
PrettyLogEndl(Style::detail(), "--- detect %d subgraphs", subgraphs.size());
int id = 0;
for (auto& g : subgraphs) {
for (auto &g : subgraphs) {
VLOG(3) << "optimizing #" << id++ << " subgraph";
handler(g, graph);
}
}
bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
VLOG(3) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false;
for (auto& node : GraphTraits::DFS(graph)) {
for (const auto& pdnode : pattern_.nodes()) {
for (auto &node : GraphTraits::DFS(graph)) {
for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) {
VLOG(4) << "pdnode " << pdnode->name() << " marked";
pdnodes2nodes_[pdnode.get()].insert(&node);
@@ -109,15 +109,15 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
}
}
// Check for early stop: some PDNode may fail to find a matched Node.
for (auto& pdnode : pattern_.nodes()) {
for (auto &pdnode : pattern_.nodes()) {
if (!pdnodes2nodes_.count(pdnode.get())) {
VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
// return false;
}
}
for (auto& item : pdnodes2nodes_) {
for (auto& n : item.second) {
GetMarkedNodes(const_cast<Graph*>(&graph)).insert(n);
for (auto &item : pdnodes2nodes_) {
for (auto &n : item.second) {
GetMarkedNodes(const_cast<Graph *>(&graph)).insert(n);
}
}
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
@@ -128,28 +128,28 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be dropped.
void GraphPatternDetector::ValidateByNodeRole(
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
std::vector<GraphPatternDetector::subgraph_t> result;
subgraphs->erase(
std::remove_if(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t& subgraph) -> bool {
[](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
// Collect the inputs and outputs.
std::unordered_set<Node*> ios;
for (auto& item : subgraph) {
std::unordered_set<Node *> ios;
for (auto &item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
}
}
for (auto& item : subgraph) {
for (auto &item : subgraph) {
if (item.first->IsIntermediate()) {
for (auto* x : item.second->inputs) {
for (auto *x : item.second->inputs) {
if (!ios.count(x)) {
return true;
}
}
for (auto* x : item.second->outputs) {
for (auto *x : item.second->outputs) {
if (!ios.count(x)) {
return true;
}
@@ -162,9 +162,9 @@ void GraphPatternDetector::ValidateByNodeRole(
}
struct HitGroup {
std::unordered_map<PDNode*, Node*> roles;
std::unordered_map<PDNode *, Node *> roles;
bool Match(Node* node, PDNode* pat) {
bool Match(Node *node, PDNode *pat) {
if (nodes_.count(node)) {
if (!roles.count(pat)) return false;
return roles[pat] == node;
@@ -172,18 +172,18 @@ struct HitGroup {
return !roles.count(pat) || roles.at(pat) == node;
}
void Register(Node* node, PDNode* pat) {
void Register(Node *node, PDNode *pat) {
roles[pat] = node;
nodes_.insert(node);
}
private:
std::unordered_set<Node*> nodes_;
std::unordered_set<Node *> nodes_;
};
// Tell whether Node a links to b.
bool IsNodesLink(Node* a, Node* b) {
for (auto* node : a->outputs) {
bool IsNodesLink(Node *a, Node *b) {
for (auto *node : a->outputs) {
if (b == node) {
return true;
}
@@ -198,10 +198,10 @@ GraphPatternDetector::DetectPatterns() {
std::vector<HitGroup> init_groups;
std::array<std::vector<HitGroup>, 2> bi_records;
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result;
for (auto* node : pdnodes2nodes_[first_pnode]) {
for (auto *node : pdnodes2nodes_[first_pnode]) {
HitGroup group;
group.roles[first_pnode] = node;
init_groups.emplace_back(group);
@@ -212,21 +212,21 @@ GraphPatternDetector::DetectPatterns() {
// Extend a PDNode to subgraphs by deducing the connection relations defined
// in edges of PDNodes.
for (const auto& edge : pattern_.edges()) {
for (const auto &edge : pattern_.edges()) {
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
auto& pre_groups = bi_records[step % 2];
auto& cur_groups = bi_records[1 - (step++ % 2)];
auto &pre_groups = bi_records[step % 2];
auto &cur_groups = bi_records[1 - (step++ % 2)];
cur_groups.clear();
if (pre_groups.empty()) break;
// source -> target
for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) {
for (Node *source : pdnodes2nodes_[edge.first]) {
for (Node *target : pdnodes2nodes_[edge.second]) {
VLOG(8) << "check " << source->id() << " -- " << target->id();
// TODO(Superjomn) add some prune strategies.
for (const auto& group : pre_groups) {
for (const auto &group : pre_groups) {
HitGroup new_group = group;
if (IsNodesLink(source, target) &&
new_group.Match(source, edge.first)) {
@@ -241,17 +241,17 @@
}
}
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
for (auto& group : cur_groups) {
for (auto& item : group.roles) {
for (auto &group : cur_groups) {
for (auto &item : group.roles) {
VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
}
VLOG(4) << "=========================================================";
}
}
for (auto& group : bi_records[step % 2]) {
for (auto &group : bi_records[step % 2]) {
GraphPatternDetector::subgraph_t subgraph;
for (auto& role : group.roles) {
for (auto &role : group.roles) {
subgraph.emplace(role.first, role.second);
}
result.emplace_back(subgraph);
@@ -260,16 +260,16 @@
}
void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
if (subgraphs->empty()) return;
std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set;
for (auto& g : *subgraphs) {
for (auto &g : *subgraphs) {
size_t key = 0;
for (auto& item : g) {
key ^= std::hash<void*>{}(item.first);
key ^= std::hash<void*>{}(item.second);
for (auto &item : g) {
key ^= std::hash<void *>{}(item.first);
key ^= std::hash<void *>{}(item.second);
}
if (!set.count(key)) {
result.emplace_back(g);
@@ -280,20 +280,20 @@ void GraphPatternDetector::UniquePatterns(
}
void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t>* subgraphs) {
std::vector<subgraph_t> *subgraphs) {
std::vector<subgraph_t> result;
std::unordered_set<Node*> node_set;
std::unordered_set<Node *> node_set;
for (const auto& subgraph : *subgraphs) {
for (const auto &subgraph : *subgraphs) {
bool valid = true;
for (auto& item : subgraph) {
for (auto &item : subgraph) {
if (item.first->IsIntermediate() && node_set.count(item.second)) {
valid = false;
break;
}
}
if (valid) {
for (auto& item : subgraph) {
for (auto &item : subgraph) {
node_set.insert(item.second);
}
result.push_back(subgraph);
@@ -307,71 +307,81 @@ std::string PDPattern::DotString() const {
Dot dot;
int id = 0;
// Create Nodes
std::unordered_map<PDNode*, std::string> node2dot;
for (const auto& node : nodes()) {
std::unordered_map<PDNode *, std::string> node2dot;
for (const auto &node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
// Create Edges
for (const auto& edge : edges()) {
for (const auto &edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto& src = node2dot.at(edge.first);
auto& trg = node2dot.at(edge.second);
auto &src = node2dot.at(edge.first);
auto &trg = node2dot.at(edge.second);
dot.AddEdge(src, trg, {});
}
return dot.Build();
}
PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
PDNode &PDNode::LinksTo(const std::vector<PDNode *> &others) {
// extend outlinks.
for (PDNode* x : others) {
for (PDNode *x : others) {
pattern_->AddEdge(this, x);
}
return *this;
}
PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
PDNode &PDNode::LinksFrom(const std::vector<PDNode *> &others) {
// extend inlinks.
for (PDNode* x : others) {
for (PDNode *x : others) {
pattern_->AddEdge(x, this);
}
return *this;
}
PDNode* PDNode::assert_is_op() {
asserts_.emplace_back([](Node* x) { return x && x->IsOp(); });
PDNode *PDNode::assert_is_op() {
asserts_.emplace_back([](Node *x) { return x && x->IsOp(); });
return this;
}
PDNode* PDNode::assert_is_op(const std::string& op_type) {
asserts_.emplace_back([op_type](Node* x) {
PDNode *PDNode::assert_is_op(const std::string &op_type) {
asserts_.emplace_back([op_type](Node *x) {
return x && x->IsOp() && x->Op()->Type() == op_type;
});
return this;
}
PDNode* PDNode::assert_is_var() {
asserts_.emplace_back([](Node* x) { return x && x->IsVar(); });
PDNode *PDNode::assert_is_var() {
asserts_.emplace_back([](Node *x) { return x && x->IsVar(); });
return this;
}
PDNode* PDNode::assert_var_not_persistable() {
PDNode *PDNode::assert_is_not_ctrl_var() {
asserts_.emplace_back([](Node *x) { return x && !x->IsCtrlVar(); });
return this;
}
PDNode *PDNode::assert_var_not_persistable() {
assert_is_var();
asserts_.emplace_back([](Node* x) { return !x->Var()->Persistable(); });
asserts_.emplace_back([](Node *x) { return !x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_persistable_var() {
PDNode *PDNode::assert_is_persistable_var() {
assert_is_var();
asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); });
asserts_.emplace_back([=](Node *x) { return x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth) {
PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->outputs) {
if (op->IsOp() && op->Op()->Type() == op_type &&
IsNthInput(x, op, argument, nth))
return true;
@@ -380,11 +390,12 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
});
return this;
}
PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth) {
PDNode *PDNode::assert_is_op_nth_output(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
if (op->IsOp() && op->Op()->Type() == op_type &&
IsNthOutput(x, op, argument, nth))
return true;
@@ -393,10 +404,11 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
});
return this;
}
PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
PDNode *PDNode::assert_is_only_input_of_op(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
op->inputs.size() == 1) {
return true;
@@ -406,10 +418,11 @@ PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
});
return this;
}
PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
PDNode *PDNode::assert_is_only_output_of_op(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
op->outputs.size() == 1) {
return true;
@@ -419,10 +432,11 @@ PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
});
return this;
}
PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
PDNode *PDNode::assert_is_op_output(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return true;
}
@@ -431,16 +445,17 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
});
return this;
}
PDNode* PDNode::assert_is_op_output(const std::string& op_type,
const std::string& argument) {
PDNode *PDNode::assert_is_op_output(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_output(op_type, argument, 0);
return this;
}
PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
PDNode *PDNode::assert_is_op_input(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return true;
}
@@ -449,72 +464,161 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
});
return this;
}
PDNode* PDNode::assert_is_op_input(const std::string& op_type,
const std::string& argument) {
PDNode *PDNode::assert_is_op_input(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_input(op_type, argument, 0);
return this;
}
PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
PDNode *PDNode::assert_op_has_n_inputs(const std::string &op_type, size_t n) {
assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
asserts_.emplace_back([=](Node *x) { return x->inputs.size() == n; });
return this;
}
PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) {
PDNode *PDNode::assert_op_has_n_outputs(const std::string &op_type, size_t n) {
assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; });
asserts_.emplace_back([=](Node *x) { return x->outputs.size() == n; });
return this;
}
PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
PDNode *PDNode::assert_more(PDNode::teller_t &&teller) {
asserts_.emplace_back(std::move(teller));
return this;
}
bool VarLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
asserts_.emplace_back([op_types](Node *x) {
return x && x->IsOp() && op_types.count(x->Op()->Type());
});
return this;
}
PDNode *PDNode::assert_is_ops_nth_input(
const std::unordered_set<std::string> &op_types,
const std::string &argument, int nth) {
assert_is_var();
assert_is_ops_input(op_types);
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->outputs) {
if (op->IsOp() && op_types.count(op->Op()->Type()) &&
IsNthInput(x, op, argument, nth))
return true;
}
return false;
});
return this;
}
PDNode *PDNode::assert_is_ops_nth_output(
const std::unordered_set<std::string> &op_types,
const std::string &argument, int nth) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
if (op->IsOp() && op_types.count(op->Op()->Type()) &&
IsNthOutput(x, op, argument, nth))
return true;
}
return false;
});
return this;
}
PDNode *PDNode::assert_is_ops_output(
const std::unordered_set<std::string> &op_types) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type())) {
return true;
}
}
return false;
});
return this;
}
PDNode *PDNode::assert_is_ops_output(
const std::unordered_set<std::string> &op_types,
const std::string &argument) {
assert_is_var();
assert_is_ops_nth_output(op_types, argument, 0);
return this;
}
PDNode *PDNode::assert_is_ops_input(
const std::unordered_set<std::string> &op_types) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->outputs) {
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type())) {
return true;
}
}
return false;
});
return this;
}
PDNode *PDNode::assert_is_ops_input(
const std::unordered_set<std::string> &op_types,
const std::string &argument) {
assert_is_var();
assert_is_ops_nth_input(op_types, argument, 0);
return this;
}
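Taken together, the new assert_is_ops* helpers let a single PDNode match any op type from a set instead of one hard-coded type. A hedged usage sketch, following the pattern construction in fuse_elewise_add_act_pass.cc above (node names are placeholders):

// Sketch: a variable node that must feed input "X" of some activation.
GraphPatternDetector gpd;
std::unordered_set<std::string> act_types = {"relu", "scale"};
auto *x = gpd.mutable_pattern()
              ->NewNode("act_elewise_add/x")
              ->AsInput()
              ->assert_is_ops_input(act_types, "X");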
bool VarLinksToOp(Node *node, const std::string &op_type) {
for (auto *out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth) {
bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->Op()->Input(argument).size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth) {
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->Op()->Output(argument).size() <= nth) return false;
return var->Name() == op->Op()->Output(argument)[nth];
}
void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
void GraphSafeRemoveNodes(Graph *graph,
const std::unordered_set<const Node *> &nodes) {
for (auto *node : nodes) {
graph->RemoveNode(const_cast<Node *>(node));
}
for (auto* node : graph->Nodes()) {
for (auto *node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
it = const_cast<Node *>(node)->inputs.erase(it);
} else {
it++;
}
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
it = const_cast<Node *>(node)->outputs.erase(it);
} else {
it++;
}
}
}
}
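GraphSafeRemoveNodes not only deletes the given nodes but also scrubs every surviving node's inputs/outputs lists of dangling pointers, which is why ReLinkNodes and RemoveIntermediateOut can hand it ops and variables in a single set. A hedged call-site sketch (all names are placeholders):

// Sketch: remove two fused-away ops plus a dead control-dep variable.
std::unordered_set<const Node *> dead_nodes = {op_1, op_2, ctrl_var};
GraphSafeRemoveNodes(graph, dead_nodes);  // dangling links are erased too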
bool VarLinksFromOp(Node* node, const std::string& op_type) {
for (auto* out : node->inputs) {
bool VarLinksFromOp(Node *node, const std::string &op_type) {
for (auto *out : node->inputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
@@ -522,30 +626,30 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
return false;
}
PDNode* patterns::ConvReLU::operator()(
paddle::framework::ir::PDNode* conv_input) {
PDNode *patterns::ConvReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
// Create variables
// Filter
auto* conv_weight_var = pattern->NewNode(conv_weight_repr())
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// Bias
auto* conv_bias_var = pattern->NewNode(conv_bias_repr())
auto *conv_bias_var = pattern->NewNode(conv_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Bias");
// intermediate variable, will be removed in the IR after fuse.
auto* conv_out_var = pattern->NewNode(conv_out_repr())
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu");
// output
auto* relu_out_var = pattern->NewNode(relu_out_repr())
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
@@ -555,18 +659,18 @@ PDNode* patterns::ConvReLU::operator()(
return relu_out_var;
}
PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("mul", "X");
auto* mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
auto* mul_w_var = pattern->NewNode(w_repr())
auto *mul_w_var = pattern->NewNode(w_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("mul", "Y");
auto* mul_out_var =
auto *mul_out_var =
pattern->NewNode(mul_out_repr())->assert_is_op_output("mul");
if (!with_bias) { // not with bias
@@ -577,14 +681,14 @@ PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
} else { // with bias
mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
// Create operators.
auto* elementwise_add = pattern->NewNode(elementwise_add_repr())
auto *elementwise_add = pattern->NewNode(elementwise_add_repr())
->assert_is_op("elementwise_add");
// Create variables.
auto* bias = pattern->NewNode(bias_repr())
auto *bias = pattern->NewNode(bias_repr())
->assert_is_op_input("elementwise_add")
->AsInput();
auto* fc_out = pattern->NewNode(Out_repr())
auto *fc_out = pattern->NewNode(Out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add");
@@ -594,11 +698,11 @@ PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
}
}
PDNode* patterns::LSTM::operator()(PDNode* x) {
PDNode *patterns::LSTM::operator()(PDNode *x) {
x->assert_is_op_input("lstm", "Input");
auto* lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
auto *arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__);
// Currently, the H0 and C0 are optional
@@ -619,11 +723,11 @@ PDNode* patterns::LSTM::operator()(PDNode* x) {
return Hidden;
}
PDNode* patterns::GRU::operator()(PDNode* x) {
PDNode *patterns::GRU::operator()(PDNode *x) {
x->assert_is_op_input("gru", "Input");
auto* gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru");
auto *gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru");
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
auto *arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__);
NEW_NODE(Weight, input);
@@ -648,6 +752,100 @@ PDNode* patterns::GRU::operator()(PDNode* x) {
return Hidden;
}
PDNode *patterns::ActElewiseAdd::operator()(
paddle::framework::ir::PDNode *in_var,
std::unordered_set<std::string> act_types) {
in_var->assert_is_ops_input(act_types, "X");
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var = pattern->NewNode(act_out_repr())
->assert_is_not_ctrl_var()
->assert_is_ops_output(act_types);
act_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto *ele_x_var = pattern->NewNode(ele_x_repr())
->assert_is_not_ctrl_var()
->assert_is_op_input("elementwise_add")
->AsInput();
auto *elementwise_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto *elewise_add_out = pattern->NewNode(elewise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
act->LinksFrom({in_var}).LinksTo({act_out_var});
elementwise_add->LinksFrom({act_out_var, ele_x_var})
.LinksTo({elewise_add_out});
return elewise_add_out;
}
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {
auto *ele_y_var = pattern->NewNode(ele_y_repr())
->assert_is_op_input("elementwise_add", "Y");
auto *ele_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto *ele_out_var = pattern->NewNode(elewise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out");
ele_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var});
act->LinksFrom({ele_out_var}).LinksTo({act_out_var});
return act_out_var;
}
PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
paddle::framework::ir::PDNode *d_act_out_var,
std::unordered_set<std::string> act_types) {
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
auto *act_grad = pattern->NewNode(act_grad_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_input(act_types, "Out");
auto *d_intermediate_var =
pattern->NewNode(d_itermediate_out_repr())
->assert_is_ops_output(act_types, GradVarName("X"));
act_grad->LinksFrom({d_act_out_var, act_out_var})
.LinksTo({d_intermediate_var});
auto *ele_y_var = pattern->NewNode(ele_y_repr())
->assert_is_not_ctrl_var()
->assert_is_op_input("elementwise_add_grad", "Y");
auto *ele_add_grad = pattern->NewNode(ele_add_grad_repr())
->assert_is_op("elementwise_add_grad");
auto *d_ele_x_var =
pattern->NewNode(d_ele_x_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("elementwise_add_grad", GradVarName("X"));
auto *d_ele_y_var =
pattern->NewNode(d_ele_y_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("elementwise_add_grad", GradVarName("Y"));
ele_add_grad->LinksFrom({d_intermediate_var, ele_y_var})
.LinksTo({d_ele_x_var, d_ele_y_var});
return ele_add_grad;
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -95,6 +95,7 @@ struct PDNode {
PDNode* assert_is_op();
PDNode* assert_is_op(const std::string& op_type);
PDNode* assert_is_var();
PDNode* assert_is_not_ctrl_var();
PDNode* assert_var_not_persistable();
PDNode* assert_is_persistable_var();
PDNode* assert_is_op_output(const std::string& op_type);
@@ -113,6 +114,20 @@ struct PDNode {
PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n);
PDNode* assert_more(teller_t&& teller);
PDNode* assert_is_ops_output(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops_output(const std::unordered_set<std::string>& op_types,
const std::string& argument);
PDNode* assert_is_ops_nth_input(
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types,
const std::string& argument);
PDNode* assert_is_ops_nth_output(
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
private:
PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
@@ -447,6 +462,68 @@ struct GRU : public PatternBase {
PATTERN_DECL_NODE(Hidden);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act
// named nodes: elementwise_add, act
// ele_x, ele_y, elewise_add_out, act_out
struct ElewiseAddAct : public PatternBase {
ElewiseAddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elewise_add_act") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(ele_add);
PATTERN_DECL_NODE(act);
// declare variable node's name
PATTERN_DECL_NODE(elewise_add_out);
PATTERN_DECL_NODE(ele_y);
PATTERN_DECL_NODE(act_out);
};
// formula: ele_add(x, act(y))
// op: elementwise_add + act
// named nodes: elementwise_add, act
// act_in, act_out, ele_x, elewise_add_out
struct ActElewiseAdd : public PatternBase {
ActElewiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "act_elewise_add") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(ele_add);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(ele_x);
PATTERN_DECL_NODE(elewise_add_out);
};
// the backward of act(ele_add(x, y))
// the act is inplace.
// op: elementwise_add_grad + act_grad
// named nodes: elementwise_add_grad, act_grad
// act_out, act_out_g, ele_y, d_itermediate_out, d_ele_x, d_ele_y
struct ElewiseAddActInplaceGrad : public PatternBase {
ElewiseAddActInplaceGrad(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elewise_add_act_grad1") {}
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(act_grad);
PATTERN_DECL_NODE(ele_add_grad);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(d_itermediate_out);
PATTERN_DECL_NODE(d_ele_x);
PATTERN_DECL_NODE(d_ele_y);
PATTERN_DECL_NODE(ele_y);
};
} // namespace patterns
// Link two ir::Nodes from each other.
@@ -454,6 +531,12 @@ struct GRU : public PatternBase {
a->outputs.push_back(b); \
b->inputs.push_back(a);
// Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \
op->outputs.push_back(out_var); \
out_var->inputs.clear(); \
out_var->inputs.push_back(op);
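Unlike the plain node-linking macro above, IR_OP_VAR_LINK makes op the sole producer of out_var by clearing the variable's previous inputs first. A hedged illustration (node names are placeholders):

// After IR_OP_VAR_LINK(fused_op, act_out):
//   fused_op->outputs contains act_out, and
//   act_out->inputs == {fused_op}  (any old producer is dropped).
IR_OP_VAR_LINK(fused_op, act_out);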
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -48,6 +48,10 @@ class Node {
bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; }
bool IsCtrlVar() const {
return type_ == Type::kVariable &&
Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
......
@@ -57,6 +57,21 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
graph = viz_pass->Apply(std::move(graph));
}
// Apply op fusion.
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass =
ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass");
graph = fuse_elewise_add_act_pass->Apply(std::move(graph));
// Apply a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
}
// Convert graph to run on multi-devices.
auto multi_devices_pass =
ir::PassRegistry::Instance().Get("multi_devices_pass");
@@ -359,6 +374,7 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework
} // namespace paddle
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
@@ -987,18 +987,28 @@ void FusedElemwiseAndActComputeWithBroadcast(
}
// --- backward
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut>
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
dx_[i] = UseIntermediateOut ? dx_op_(x_[i], y_[i], intermediate_out_[i],
out_[i], dout_[i])
: dx_op_(x_[i], y_[i], out_[i], dout_[i]);
dx_[i] = UseIntermediateOut
? dx_op_.UseIntermediateOut(
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
}
if (dy_ != nullptr) {
dy_[i] = UseIntermediateOut ? dy_op_(x_[i], y_[i], intermediate_out_[i],
out_[i], dout_[i])
: dy_op_(x_[i], y_[i], out_[i], dout_[i]);
dy_[i] = UseIntermediateOut
? dy_op_.UseIntermediateOut(
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
}
if (dintermediate_ != nullptr) {
dintermediate_[i] =
UseIntermediateOut
? dintermediate_op_.UseIntermediateOut(
x_[i], intermediate_out_[i], out_[i], dout_[i])
: dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
}
}
@@ -1009,37 +1019,44 @@ struct FusedElemwiseAndActGradNoBroadcast {
const T *dout_;
DX_OP dx_op_;
DY_OP dy_op_;
DIntermediate_OP dintermediate_op_;
T *dx_;
T *dy_;
T *dintermediate_;
};
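Note the interface change: DX_OP, DY_OP, and DIntermediate_OP are no longer plain callables; they must expose UseIntermediateOut and Recompute members. A minimal sketch of the contract these templates now assume (the functor itself is hypothetical; only the member shapes are inferred from the call sites above):

// Hypothetical grad functor satisfying the new contract.
template <typename T>
struct ExampleDxOp {
  // Used when the forward IntermediateOut was kept alive.
  HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
                                  T dout) const {
    return dout;  // placeholder; a real op differentiates here
  }
  // Used when the intermediate result must be recomputed on the fly.
  HOSTDEVICE T Recompute(T x, T y, T out, T dout) const {
    return dout;  // placeholder
  }
};
// DIntermediate_OP::UseIntermediateOut takes (x, intermediate_out, out,
// dout) -- one argument fewer, as the functor body above shows.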
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut>
typename DIntermediate_OP, bool UseIntermediateOut>
void FusedElemwiseAndActGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *intermediate_out,
const framework::Tensor *out, const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
framework::Tensor *dx, framework::Tensor *dy,
framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(
FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, UseIntermediateOut>{
FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut>{
x->data<T>(), y->data<T>(),
intermediate_out ? intermediate_out->data<T>() : nullptr,
out->data<T>(), dout->data<T>(), dx_op, dy_op,
out->data<T>(), dout->data<T>(), dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
ctx.GetPlace())});
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int64_t tmp_out_idx, x_idx, y_idx;
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
@@ -1055,9 +1072,11 @@ static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
@@ -1071,9 +1090,11 @@ static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
if (BcastY) {
if (i == 0) {
dy[y_idx] = tmp;
@@ -1084,18 +1105,34 @@ static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
dy[y_idx] = tmp;
}
}
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
x[x_idx], intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx],
out[offset], dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
if (i == 0) {
d_intermediate[tmp_out_idx] = tmp;
} else {
d_intermediate[tmp_out_idx] += tmp;
}
}
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int64_t tmp_out_idx, x_idx, y_idx;
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
@@ -1112,9 +1149,11 @@ static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -1128,9 +1167,11 @@ static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
if (BcastY) {
if (i == 0 && k == 0) {
dy[y_idx] = tmp;
......@@ -1141,21 +1182,40 @@ static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
dy[y_idx] = tmp;
}
}
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
x[x_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx],
                                                 out[offset], dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
if (i == 0) {
d_intermediate[tmp_out_idx] = tmp;
} else {
d_intermediate[tmp_out_idx] += tmp;
}
}
}
}
}
}
}
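Broadcast2 generalizes this to a [pre, n, post] view of x, with the broadcast operand indexed by the middle coordinate j. A small sketch of the offset arithmetic, assuming y is the broadcast side (names illustrative, not part of the patch):

#include <cassert>

// Sketch of the Broadcast2 index mapping: x is viewed as [pre, n, post],
// y as [n]; the arithmetic below mirrors the kernel parameters above.
int main() {
  const int pre = 2, n = 3, post = 4;
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < post; ++k) {
        int offset = (i * n + j) * post + k;   // linear index into x/out/dout
        int y_idx = j;                         // broadcast index into y
        assert(offset >= 0 && offset < pre * n * post);
        assert(y_idx == (offset / post) % n);  // recover j from the flat offset
      }
  return 0;
}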
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0), inter_val(0);
int64_t tmp_out_idx, x_idx, y_idx;
do {
......@@ -1170,10 +1230,12 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
}
if (dx != nullptr) {
T tmp =
UseIntermediateOut
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -1182,23 +1244,38 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
}
}
if (dy != nullptr) {
T tmp =
UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
                        x[x_idx], intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
inter_val += tmp;
}
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
if (BcastY) {
if (dy) {
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
......@@ -1206,41 +1283,49 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
}
} else {
if (dx) {
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
if (!SameShapeOfIntermediateOutAndOut) {
if (d_intermediate) {
inter_val = paddle::platform::reduceSum(inter_val, tid, h);
if (threadIdx.x == 0) {
d_intermediate[j] = inter_val;
}
}
}
}
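The CUDA kernel assigns one block per broadcast index j; each thread accumulates a private partial sum over rows t, t + B, t + 2B, ... (the do/while stride loop), which reduceSum then combines -- this is also why the clamp of h is hoisted above both the dy and the d_intermediate reductions. A host-side C++ emulation of that pattern, as a sketch only (B stands in for ELEMWISE_MAX_BLOCK_DIM; names illustrative):

#include <cassert>
#include <vector>

// Host-side emulation of the per-thread strided accumulation above:
// "block" j owns column j, "thread" t accumulates rows t, t+B, t+2B, ...
// and the partial sums are then reduced (reduceSum in the real kernel).
int main() {
  const int h = 10, w = 4, B = 3;                 // B "threads" per block
  std::vector<float> dout(h * w, 1.f);
  for (int j = 0; j < w; ++j) {                   // one block per column
    std::vector<float> val(B, 0.f);               // per-thread partial sums
    for (int t = 0; t < B; ++t)
      for (int i = t; i < h; i += B)              // the stride loop above
        val[t] += dout[i * w + j];
    float dy_j = 0.f;                             // stands in for reduceSum
    for (int t = 0; t < B; ++t) dy_j += val[t];
    assert(dy_j == h);                            // h ones per column
  }
  return 0;
}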
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(
cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  int grid_size = w;
FusedElemwiseAndActGradBroadcast1CUDAKernel<
T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
      SameShapeOfIntermediateOutAndOut><<<grid_size, block_size, 0, stream>>>(
x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
dx, dy, d_intermediate);
}
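The launch configuration follows directly: one block per column, at most ELEMWISE_MAX_BLOCK_DIM threads, with longer columns handled by the in-kernel stride loop. A tiny sketch; 1024 is only an assumed stand-in for the real ELEMWISE_MAX_BLOCK_DIM value:

#include <algorithm>
#include <cassert>

int main() {
  const int kMaxBlockDim = 1024;                // assumed stand-in value
  int h = 50000, w = 256;
  int block_size = std::min(kMaxBlockDim, h);   // threads per block
  int grid_size = w;                            // one block per column
  assert(block_size == 1024 && grid_size == 256);
  return 0;
}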
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
int tid = threadIdx.x;
int j = blockIdx.x;
T val(0), inter_val(0);
int ttid = tid;
int64_t tmp_out_idx, x_idx, y_idx;
while (true) {
......@@ -1259,10 +1344,12 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
}
if (dx != nullptr) {
T tmp =
UseIntermediateOut
? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
......@@ -1271,24 +1358,38 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
}
}
if (dy != nullptr) {
T tmp =
UseIntermediateOut
? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
if (d_intermediate != nullptr) {
T tmp = UseIntermediateOut
? dintermediate_op.UseIntermediateOut(
                        x[x_idx], intermediate_out[tmp_out_idx], out[offset],
dout[offset])
: dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
dout[offset]);
if (SameShapeOfIntermediateOutAndOut) {
d_intermediate[tmp_out_idx] = tmp;
} else {
inter_val += tmp;
}
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
if (BcastY) {
if (dy) {
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
......@@ -1296,40 +1397,51 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
}
} else {
if (dx) {
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
if (!SameShapeOfIntermediateOutAndOut) {
if (d_intermediate) {
inter_val = paddle::platform::reduceSum(inter_val, tid, h);
if (threadIdx.x == 0) {
d_intermediate[j] = inter_val;
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CUDA(
cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
T *dintermediate) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int grid_size = n;
FusedElemwiseAndActGradBroadcast2CUDAKernel<
T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
      SameShapeOfIntermediateOutAndOut><<<grid_size, block_size, 0, stream>>>(
x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op,
dintermediate_op, dx, dy, dintermediate);
}
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename DIntermediate_OP, bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *intermediate_out,
const framework::Tensor *out, const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
......@@ -1341,70 +1453,82 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, DIntermediate_OP,
UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dintermediate_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
ctx.GetPlace()));
}
}
}
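The dispatch above reduces to how x's shape is split around y's: if the trailing block collapses (post == 1), the Broadcast1 path runs with h = pre and w = n; otherwise Broadcast2 runs on (pre, n, post). A sketch that mirrors the role of get_mid_dims in this file (simplified and illustrative, not the real helper):

#include <cassert>
#include <vector>

// Decompose x's shape as [pre, n, post] around y's shape at a given axis.
void mid_dims(const std::vector<int> &x, const std::vector<int> &y, int axis,
              int *pre, int *n, int *post) {
  *pre = 1; *n = 1; *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x[i];
  for (size_t i = 0; i < y.size(); ++i) *n *= y[i];
  for (size_t i = axis + y.size(); i < x.size(); ++i) *post *= x[i];
}

int main() {
  int pre, n, post;
  mid_dims({2, 3, 4, 5}, {4, 5}, 2, &pre, &n, &post);
  assert(pre == 6 && n == 20 && post == 1);  // post == 1 -> Broadcast1, h=6, w=20
  mid_dims({2, 3, 4, 5}, {3, 4}, 1, &pre, &n, &post);
  assert(pre == 2 && n == 12 && post == 5);  // post > 1 -> Broadcast2
  return 0;
}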
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename DIntermediate_OP, bool UseIntermediateOut,
bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeEx(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *out,
const framework::Tensor *intermediate_out, const framework::Tensor *dout,
int axis, framework::Tensor *dx, framework::Tensor *dy,
framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
DIntermediate_OP dintermediate_op) {
const framework::DDim &x_dim = x->dims();
const framework::DDim &y_dim = y->dims();
if (UseIntermediateOut) {
PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
}
if (x_dim == y_dim) {
FusedElemwiseAndActGradComputeNoBroadcast<
DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>(
ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
dintermediate, dx_op, dy_op, dintermediate_op);
  } else {  // The shapes of X and Y differ, so one operand is broadcast.
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
......@@ -1420,16 +1544,16 @@ void FusedElemwiseAndActGradComputeEx(
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
FusedElemwiseAndActGradComputeWithBroadcast<
DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
true /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
dintermediate, dx_op, dy_op, dintermediate_op);
} else {
FusedElemwiseAndActGradComputeWithBroadcast<
DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
false /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
ctx, y_dim, x_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
dintermediate, dx_op, dy_op, dintermediate_op);
}
}
}
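The choice of which operand is broadcast follows a single rule: Y is broadcast iff X has at least as many dimensions, with rank ties broken by the first dimension where the sizes differ -- the same rule the op definition later factors into IsBcastY. A minimal sketch (illustrative only):

#include <cassert>
#include <vector>

bool is_bcast_y(const std::vector<int> &x, const std::vector<int> &y) {
  bool bcast_y = x.size() >= y.size();
  if (x.size() == y.size())
    for (size_t i = 0; i < x.size(); ++i)
      if (x[i] < y[i]) { bcast_y = false; break; }
  return bcast_y;
}

int main() {
  assert(is_bcast_y({8, 16, 32}, {16, 32}));   // Y is a trailing slice of X
  assert(!is_bcast_y({16, 32}, {8, 16, 32}));  // roles swapped
  assert(!is_bcast_y({1, 32}, {8, 32}));       // same rank, X smaller in dim 0
  return 0;
}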
......@@ -1444,7 +1568,7 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
framework::Tensor *intermediate_out) {
if (KeepIntermediateOut) {
PADDLE_ENFORCE(intermediate_out,
"The keep_intermediate_value is opened, "
"The save_intermediate_out is opened, "
"intermediate_out should not be nullptr.");
}
......
......@@ -13,18 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {
"elementwise_add", "elementwise_mul", "elementwise_add_grad",
......@@ -32,10 +25,17 @@ static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
return binary_fun.count(functor_list[1]) != 0;
}
bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> InplaceOpSet = {"relu", "relu_grad"};
bool is_in_place = false;
for (auto &func_name : functor_list) {
is_in_place |= (InplaceOpSet.count(func_name) == 1);
}
return is_in_place;
}
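A standalone restatement of these two predicates with example inputs, to make their contracts concrete (simplified sketch: the real ones also accept the *_grad functor names and enforce the list size):

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

bool is_unary_compound(const std::vector<std::string> &f) {
  static const std::unordered_set<std::string> binary = {"elementwise_add",
                                                         "elementwise_mul"};
  return binary.count(f[1]) != 0;  // Binary in slot 1 => Z = Unary(Binary(X, Y))
}
bool has_in_place_unary(const std::vector<std::string> &f) {
  static const std::unordered_set<std::string> in_place = {"relu", "relu_grad"};
  for (const auto &name : f)
    if (in_place.count(name)) return true;
  return false;
}

int main() {
  assert(is_unary_compound({"relu", "elementwise_add"}));   // relu(x + y)
  assert(!is_unary_compound({"elementwise_add", "relu"}));  // x + relu(y)
  assert(has_in_place_unary({"elementwise_add", "relu"}));
  assert(!has_in_place_unary({"elementwise_add", "scale"}));
  return 0;
}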
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
......@@ -86,20 +86,12 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = IsBcastY(x_dim, y_dim);
auto &out_dim = bcast_y ? x_dim : y_dim;
std::string out_lod = bcast_y ? "X" : "Y";
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null.");
......@@ -123,6 +115,20 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
ctx->ShareLoD(out_lod, /*->*/ "Out");
}
static bool IsBcastY(const framework::DDim &x_dim,
const framework::DDim &y_dim) {
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
return bcast_y;
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -157,17 +163,7 @@ class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("scale",
"scale is used by scale_op, the default value is 0.0.")
.SetDefault(0.0);
AddAttr<bool>("save_intermediate_out",
"Whether to save the intermediate_out.")
.SetDefault(false);
AddAttr<std::vector<std::string>>("functor_list",
......@@ -227,30 +223,38 @@ class FusedElemwiseActivationGradMaker
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType(this->ForwardOpType() + "_grad");
for (auto &input_param : this->InputNames()) {
grad_op->SetInput(input_param, this->Input(input_param));
grad_op->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param, true));
}
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetAttrMap(this->Attrs());
std::vector<std::string> functor_names =
boost::get<std::vector<std::string>>(grad_op->GetAttr("functor_list"));
functor_names[0] += "_grad";
functor_names[1] += "_grad";
grad_op->SetAttr("functor_list", functor_names);
if (boost::get<bool>(grad_op->GetAttr("save_intermediate_out"))) {
PADDLE_ENFORCE_NE(Output("IntermediateOut").size(), 0);
grad_op->SetInput("IntermediateOut", this->Output("IntermediateOut"));
grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
this->OutputGrad("IntermediateOut"));
} else {
grad_op->SetInput("IntermediateOut", {});
grad_op->SetOutput(framework::GradVarName("IntermediateOut"), {});
}
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
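The maker derives the backward functor list purely textually: each forward functor name gets a "_grad" suffix, order preserved. A minimal illustration (illustrative names, not part of the patch):

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> functors = {"elementwise_add", "relu"};
  for (auto &name : functors) name += "_grad";  // what the grad maker does
  assert(functors[0] == "elementwise_add_grad");
  assert(functors[1] == "relu_grad");
  return 0;
}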
......@@ -261,56 +265,65 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
auto functor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
"Input(IntermediateOut) should not be null");
} else {
PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
}
}
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
auto inter_grad_name = framework::GradVarName("IntermediateOut");
if (ctx->HasOutput(x_grad_name)) {
if (ctx->HasInputs("X")) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", x_grad_name);
} else {
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, if not, we could not infer the shape of dx.
// Currently, only when Binary is elementwise_add or elementwise_sub,
// the "X" could be absent.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent.");
ctx->SetOutputDim(x_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
}
}
if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", y_grad_name);
}
if (ctx->HasOutput(inter_grad_name)) {
// For Unary(Binary(X, Y)), IntermediateOut should not be empty.
if (IsUnaryCompound(functor_list)) {
ctx->SetOutputDim(inter_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), inter_grad_name);
} else {
ctx->SetOutputDim(inter_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", inter_grad_name);
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
// PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
......
......@@ -26,6 +26,24 @@ limitations under the License. */
namespace paddle {
namespace operators {
/**
* Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
 * final out.
*/
bool IsUnaryCompound(const std::vector<std::string> &functor_list);
/**
* For the in-place unary functor, the inputs of op_desc only have Out and
* Out@Grad.
*/
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);
/**
* Whether the Input(X) could be absent.
*/
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);
template <typename DeviceContext, typename T, typename BinaryFunctor,
typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
......@@ -39,7 +57,7 @@ static void RunBinaryCompoundFunctor(
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis");
if (ctx.Attr<bool>("keep_intermediate_value")) {
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
......@@ -71,7 +89,7 @@ static void RunUnaryCompoundFunctors(
paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
compound_func(unary_functor, binary_functor);
if (ctx.Attr<bool>("keep_intermediate_value")) {
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
......@@ -89,7 +107,7 @@ static void RunUnaryCompoundFunctors(
}
template <typename DeviceContext, typename T, typename BinaryGradFunctor,
typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
static void RunBinaryCompoundGradFunctors(
const framework::ExecutionContext &ctx,
const BinaryGradFunctor &binary_grad_functor,
......@@ -98,7 +116,7 @@ static void RunBinaryCompoundGradFunctors(
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
// Z = Binary(X, Unary(Y))
int axis = ctx.Attr<int>("axis");
......@@ -107,32 +125,40 @@ static void RunBinaryCompoundGradFunctors(
UnaryFunctor>;
using BinaryCompoundDyFunctor =
paddle::operators::math::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
  using BinaryCompoundDIntermediateOutFunctor =
      paddle::operators::math::BinaryCompoundGradDIntermediateOutFunctor<
T, BinaryGradFunctor, UnaryFunctor>;
if (in_intermediate_out) {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermediateOutFunctor, true /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, d_intermediate_out,
BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor),
        BinaryCompoundDIntermediateOutFunctor(binary_grad_functor,
unary_functor));
} else {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermediateOutFunctor, false /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, d_intermediate_out,
BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor),
        BinaryCompoundDIntermediateOutFunctor(binary_grad_functor,
unary_functor));
}
}
template <typename DeviceContext, typename T, typename UnaryGradFunctor,
typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
static void RunUnaryCompoundGradFunctors(
const framework::ExecutionContext &ctx,
const UnaryGradFunctor &unary_grad_functor,
......@@ -141,36 +167,44 @@ static void RunUnaryCompoundGradFunctors(
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
// Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis");
using UnaryCompoundDxFunctor =
paddle::operators::math::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDyFunctor =
paddle::operators::math::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDIntermediateFunctor =
paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
T, UnaryGradFunctor, BinaryFunctor, InPlace>;
if (in_intermediate_out) {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, d_intermediate_out,
UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
} else {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, d_intermediate_out,
UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
}
}
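A scalar check of the chain rule these functors implement for Z = relu(X + Y): d_intermediate = dout * relu'(x + y), and dx = dy = d_intermediate since AddGradFunctor's partials are both 1. A sketch with made-up values (not part of the patch):

#include <cassert>

int main() {
  float x = 2.f, y = -0.5f, dout = 3.f;
  float intermediate = x + y;                       // Binary(X, Y), here 1.5
  float out = intermediate > 0 ? intermediate : 0;  // Unary = relu
  float relu_grad = out > 0 ? 1.f : 0.f;            // the UseOut form: from out
  float d_intermediate = dout * relu_grad;          // DIntermediate functor
  float dx = d_intermediate * 1.f;                  // d(x + y)/dx == 1
  float dy = d_intermediate * 1.f;                  // d(x + y)/dy == 1
  assert(out == 1.5f && dx == 3.f && dy == 3.f);
  return 0;
}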
......@@ -226,72 +260,67 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
}
}
template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
auto funcs_str = functors[0] + "," + functors[1];
  // TODO(zcd): The following code can be refined, for example, by using
  // registration.
if (funcs_str == "elementwise_add_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "scale_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_add_grad,relu_grad") {
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::ReluGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "relu_grad,elementwise_add_grad") {
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::ReluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
}
......@@ -313,9 +342,9 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);
if (ctx.Attr<bool>("keep_intermediate_value")) {
if (ctx.Attr<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The keep_intermediate_value is enable, so the "
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty.");
auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
......@@ -331,65 +360,63 @@ template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto in_y = ctx.Input<framework::Tensor>("Y");
PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
auto in_out = ctx.Input<framework::Tensor>("Out");
PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
auto in_out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(in_out_grad != nullptr,
"Input(Out@Grad) should not be nullptr.");
framework::Tensor *in_x =
const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
framework::Tensor *x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
framework::Tensor *y_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
framework::GradVarName("IntermediateOut"));
auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");
if (ctx.Attr<bool>("recomputation")) {
PADDLE_ENFORCE(
x != nullptr,
"The recomputation is opened, so Input(X) should not be absent.");
// Get intermediate_out
framework::Tensor *in_intermediate_out = nullptr;
if (ctx.Attr<bool>("save_intermediate_out")) {
      // If save_intermediate_out is true, then for Unary(Binary(x, y)) and
      // Binary(x, Unary(y)), Binary(x, y) and Unary(y) do not need to be
      // recomputed.
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'save_intermediate_out' is opened, "
"so the number of 'Out' should be two.");
} else {
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
}
}
// Get in_x
if (ctx.HasInput("X")) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
} else {
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, in_y and in_out.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent.");
in_x = const_cast<framework::Tensor *>(in_out_grad);
in_out = const_cast<framework::Tensor *>(in_out_grad);
}
bool has_in_place = HasInPlaceUnary(functor_list);
if (has_in_place) {
RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad, d_intermediate_out);
} else {
RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad, d_intermediate_out);
}
}
};
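Aliasing in_x and in_out to out_grad when "X" is absent is safe only because InputXCanBeAbsent restricts this path to elementwise_add, whose partials are constants: the aliased values are never actually read, and the tensors merely supply a correctly-sized buffer. A scalar sketch of that argument (AddGrad below mirrors AddGradFunctor's contract; purely illustrative):

#include <cassert>

struct AddGrad {
  float Dx(float, float) { return 1.f; }  // d(x + y)/dx, independent of x
  float Dy(float, float) { return 1.f; }  // d(x + y)/dy, independent of y
};

int main() {
  AddGrad g;
  float dout = 0.25f;
  float bogus = -1e9f;                  // stands in for the aliased x/out data
  float dx = dout * g.Dx(bogus, bogus); // the value of "x" is irrelevant
  assert(dx == 0.25f);
  return 0;
}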
......
......@@ -22,11 +22,11 @@ namespace paddle {
namespace operators {
namespace math {
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2)
: func1_(func1), func2_(func2) {}
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }
......@@ -40,11 +40,11 @@ struct BinaryCompoundFunctor {
UnaryFunctor func2_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2)
: func1_(func1), func2_(func2) {}
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }
......@@ -58,23 +58,19 @@ struct UnaryCompoundFunctor {
BinaryFunctor func2_;
};
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
......@@ -83,8 +79,9 @@ struct BinaryCompoundGradDxFunctor {
UnaryFun unary_fun_;
};
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun, bool InPlace>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
......@@ -93,13 +90,19 @@ struct BinaryCompoundGradDyFunctor {
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
if (InPlace) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_.UseOut(intermediate_out);
} else {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_.UseXAndOut(y, intermediate_out);
}
}
private:
......@@ -108,8 +111,9 @@ struct BinaryCompoundGradDyFunctor {
DUnaryFun d_unary_fun_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool InPlace>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
......@@ -118,22 +122,23 @@ struct UnaryCompoundGradDxFunctor {
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
}
return base * d_binary_fun_.Dx(x, y);
}
......@@ -144,8 +149,9 @@ struct UnaryCompoundGradDxFunctor {
DBinaryFun d_binary_fun_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool InPlace>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
......@@ -154,22 +160,23 @@ struct UnaryCompoundGradDyFunctor {
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
}
return base * d_binary_fun_.Dy(x, y);
}
......@@ -180,6 +187,56 @@ struct UnaryCompoundGradDyFunctor {
DBinaryFun d_binary_fun_;
};
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDIntermediateOutFunctor {
  BinaryCompoundGradDIntermediateOutFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y));
}
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun, bool InPlace>
struct UnaryCompoundGradDIntermediateFunctor {
UnaryCompoundGradDIntermediateFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun)
: d_unary_fun_(d_unary_fun), binary_fun_(binary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
if (InPlace) {
return dout * d_unary_fun_.UseOut(out);
} else {
return dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
}
}
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
T dout) {
if (InPlace) {
return dout * d_unary_fun_.UseOut(out);
} else {
return dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
}
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
};
} // namespace math
} // namespace operators
} // namespace paddle
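The InPlace flag threaded through these functors marks the case where the unary op may have overwritten its input buffer with its output (relu can run in place), so the backward must take the UseOut form. For relu that loses nothing, since relu' is recoverable from the output alone (out > 0 iff the pre-activation was > 0). A small sketch (illustrative, not part of the patch):

#include <cassert>

struct ReluGrad {
  float UseX(float x) { return x > 0 ? 1.f : 0.f; }
  float UseOut(float out) { return out > 0 ? 1.f : 0.f; }
};

int main() {
  ReluGrad g;
  for (float x : {-2.f, -0.5f, 0.5f, 3.f}) {
    float out = x > 0 ? x : 0.f;         // forward relu, possibly in place
    assert(g.UseX(x) == g.UseOut(out));  // both forms agree
  }
  return 0;
}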
......@@ -58,9 +58,9 @@ template <typename T>
struct ScaleGradFunctor {
explicit ScaleGradFunctor(T coeff) : coeff_(coeff) {}
inline HOSTDEVICE T UseX(T x) { return coeff_; }
inline HOSTDEVICE T UseOut(T out) { return coeff_; }
inline HOSTDEVICE T UseXAndOut(T x, T out) { return coeff_; }
private:
T coeff_;
......@@ -73,9 +73,9 @@ struct ReluFunctor {
template <typename T>
struct ReluGradFunctor {
inline HOSTDEVICE T UseX(T x) { return x > 0 ? 1 : 0; }
inline HOSTDEVICE T UseOut(T out) { return out > 0 ? 1 : 0; }
inline HOSTDEVICE T UseXAndOut(T x, T out) { return out > 0 ? 1 : 0; }
};
} // namespace math
......
......@@ -670,7 +670,14 @@ All parameter, weight, gradient are variables in Paddle.
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property("fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
self.fuse_elewise_add_act_ops_ = b;
});
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
......@@ -47,8 +47,7 @@ def get_numeric_gradient(place,
input_to_check,
output_names,
delta=0.005,
in_place=False):
# FIXME: change this method by compile time concepts
set_input(scope, op, inputs, place)
......@@ -59,8 +58,6 @@ def get_numeric_gradient(place,
sum = []
op.run(scope, place)
for output_name in output_names:
sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).sum() / len(output_names)
......@@ -407,14 +404,13 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
places = self._get_places()
for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta,
in_place, max_relative_error,
user_defined_grads)
def check_grad_with_place(self,
place,
......@@ -424,8 +420,7 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......@@ -448,8 +443,7 @@ class OpTest(unittest.TestCase):
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
]
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
......
......@@ -38,6 +38,7 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None,
use_parallel_executor=True,
use_reduce=False,
fuse_elewise_add_act_ops=False,
optimizer=fluid.optimizer.Adam,
use_fast_executor=False):
def run_executor(exe, feed, fetch_list, program=None):
......@@ -78,6 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
if use_parallel_executor:
exe = fluid.ParallelExecutor(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
def simple_fc_net(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
for _ in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
act='relu',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
def fc_with_batchnorm(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
for _ in range(2):
hidden = fluid.layers.fc(
hidden,
size=200,
act='relu',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=4)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare_fuse_elewise_add_act_ops(self,
model,
use_cuda,
random_data=True):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data(random_data)
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_elewise_add_act_ops=False,
memory_opt=False,
optimizer=_optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_elewise_add_act_ops=True,
memory_opt=False,
optimizer=_optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)

    def test_simple_fc_with_fuse_op(self):
        self._compare_fuse_elewise_add_act_ops(simple_fc_net, True)
        self._compare_fuse_elewise_add_act_ops(simple_fc_net, False)

    def test_batchnorm_fc_with_fuse_op(self):
        self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, True)
        self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, False)


if __name__ == '__main__':
    unittest.main()
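
For context on what the test above toggles: a minimal sketch of enabling the new pass outside the test harness, assuming the C++ `BuildStrategy::fuse_elewise_add_act_ops_` flag added in this commit is mirrored as `fuse_elewise_add_act_ops` on the Python-side `fluid.BuildStrategy` (the same keyword that `check_network_convergence` accepts above). The attribute and wiring here are illustrative, not the committed API.

import paddle.fluid as fluid

# Hedged sketch: assumes fluid.BuildStrategy mirrors the C++ flag
# BuildStrategy::fuse_elewise_add_act_ops_ introduced by this commit.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True  # defaults to False

# `loss` would come from a network such as simple_fc_net above.
exe = fluid.ParallelExecutor(
    use_cuda=False,
    loss_name=loss.name,
    build_strategy=build_strategy)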
......@@ -48,7 +48,7 @@ def create_test_class(test_case, callback, attrs):
                'X': OpTest.np_dtype_to_fluid_dtype(self.x),
                'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
            }
            if self.attrs["keep_intermediate_value"]:
            if self.attrs["save_intermediate_out"]:
                self.outputs = {
                    'Out': self.out,
                    "IntermediateOut": self.intermediate_out
......@@ -73,22 +73,19 @@ def create_test_class(test_case, callback, attrs):
        def test_check_output(self):
            self.check_output()

        # FIXME(zcd): the intermediate_out_grad is not checked.
        def test_check_grad_normal(self):
            if self.attrs["keep_intermediate_value"]:
                self.check_grad(
                    ['X', 'Y'], ['Out', 'IntermediateOut'],
                    max_relative_error=0.005,
                    sum_outputs=['Out'])
            if self.attrs["save_intermediate_out"]:
                self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
            else:
                self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)

        def test_check_grad_ingore_x(self):
            if self.attrs["keep_intermediate_value"]:
            if self.attrs["save_intermediate_out"]:
                self.check_grad(
                    ['Y'], ['Out', 'IntermediateOut'],
                    ['Y'], ['Out'],
                    max_relative_error=0.005,
                    no_grad_set=set("X"),
                    sum_outputs=['Out'])
                    no_grad_set=set("X"))
            else:
                self.check_grad(
                    ['Y'], ['Out'],
......@@ -96,12 +93,11 @@ def create_test_class(test_case, callback, attrs):
                    no_grad_set=set("X"))

        def test_check_grad_ingore_y(self):
            if self.attrs["keep_intermediate_value"]:
            if self.attrs["save_intermediate_out"]:
                self.check_grad(
                    ['X'], ['Out', 'IntermediateOut'],
                    ['X'], ['Out'],
                    max_relative_error=0.005,
                    no_grad_set=set("Y"),
                    sum_outputs=['Out'])
                    no_grad_set=set("Y"))
            else:
                self.check_grad(
                    ['X'], ['Out'],
......@@ -303,38 +299,31 @@ for mode in {0, 1}:
    relu_add_func = partial(relu_add_func, mode=mode)
    add_relu_func = partial(add_relu_func, mode=mode)

    for recomputation in {True, False}:
        for keep_intermediate_value in {True, False}:
            suffix = ("_keep_intermediate_value" if keep_intermediate_value else "") \
                + ("_recomputation" if recomputation else "") \
    for save_intermediate_out in {True, False}:
        suffix = ("_save_intermediate_out" if save_intermediate_out else "") \
            + ("_mode_" + str(mode))
        create_test_class('scale_add' + suffix, scale_add_func, {
            'scale': scale,
            'functor_list': ["scale", "elementwise_add"],
            'keep_intermediate_value': keep_intermediate_value,
            'recomputation': recomputation
            'save_intermediate_out': save_intermediate_out,
        })
        create_test_class('add_scale' + suffix, add_scale_func, {
            'scale': scale,
            'functor_list': ["elementwise_add", "scale"],
            'keep_intermediate_value': keep_intermediate_value,
            'recomputation': recomputation
            'save_intermediate_out': save_intermediate_out,
        })
        create_test_class('add_relu' + suffix, add_relu_func, {
            'functor_list': ["elementwise_add", "relu"],
            'keep_intermediate_value': keep_intermediate_value,
            'recomputation': recomputation
            'save_intermediate_out': save_intermediate_out,
        })
        create_test_class('relu_add' + suffix, relu_add_func, {
            'functor_list': ["relu", "elementwise_add"],
            'keep_intermediate_value': keep_intermediate_value,
            'recomputation': recomputation
            'save_intermediate_out': save_intermediate_out,
        })
        create_test_class('mul_scale' + suffix, mul_scale_func, {
            'scale': scale,
            'functor_list': ["elementwise_mul", "scale"],
            'keep_intermediate_value': keep_intermediate_value,
            'recomputation': recomputation
            'save_intermediate_out': save_intermediate_out,
        })

if __name__ == '__main__':
......
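
For readers scanning the generated matrix: after the rename, each case carries only the functor pair (plus `scale` where relevant) and the new `save_intermediate_out` flag; the old `recomputation` axis is gone. One hypothetical case from the loop above, expanded by hand purely for illustration:

# One concrete instance of the generated matrix (mode=0,
# save_intermediate_out=True); the suffix string follows the
# construction in the loop above.
create_test_class('add_relu_save_intermediate_out_mode_0', add_relu_func, {
    'functor_list': ["elementwise_add", "relu"],
    'save_intermediate_out': True,
})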
......@@ -79,7 +79,7 @@ class TestReshapeOpWithInputShape(OpTest):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out", sum_outputs=["Out"])
        self.check_grad(["X"], "Out")

if __name__ == "__main__":
......
......@@ -34,7 +34,7 @@ class TestTransposeOp(OpTest):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', sum_outputs=['Out'])
        self.check_grad(['X'], 'Out')

    def initTestCase(self):
        self.shape = (3, 4)
......
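
Both hunks make the same simplification: the `sum_outputs` hint is dropped from `check_grad` for ops that carry an auxiliary XShape output, presumably no longer needed after this commit's IntermediateOut cleanup. A minimal sketch of the resulting test pattern, assuming this repo's OpTest helper (op name, shapes, and attrs here are illustrative):

# Hypothetical OpTest subclass following the simplified pattern above:
# XShape is excluded from output checking, and check_grad no longer
# passes a sum_outputs argument.
class TestTransposeLikeOp(OpTest):
    def setUp(self):
        self.op_type = "transpose2"  # illustrative op with an XShape output
        x = np.random.random((3, 4)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'axis': [1, 0]}
        self.outputs = {
            'XShape': np.random.random((3, 4)).astype("float32"),
            'Out': x.transpose([1, 0])
        }

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')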