未验证 提交 d402234b 编写于 作者: C chengduo 提交者: GitHub

Feature/op_fuse_pass (#12440)

* Add Preface

* Add demo code

* Save file

* Refine code

* seems can work

* use elementwise strategy

* Use ElementwiseComputeEx

* Add comments

* extract functions from operator

* Refine code

* Follow comment

* code refine

* add op_fuse  pass

* add backward

* code refine

* use TopologySortOperations

* follow comments

* refine IsFusible

* code enhance

* fix op_fusion_pass

* refine code

* refine fuse_elemwise_act_op

* adjust the input and output

* refine logic

* add intermediate_edge

* disable inplace

* follow comments

* refine logic

* follow comments

* Remove the removable IntermediateOut

* change strategy

* code refine

* enable fuse backward

* code refine

* code refine

* rename unit test

* follow comments
上级 5acdbbb4
......@@ -150,11 +150,11 @@ else()
endif()
if (NOT WIN32)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fast_threaded_ssa_graph_executor)
fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass)
endif() # NOT WIN32
cc_library(prune SRCS prune.cc DEPS framework_proto)
......
......@@ -54,6 +54,8 @@ struct BuildStrategy {
std::string debug_graphviz_path_{""};
bool fuse_elewise_add_act_ops_{false};
bool enable_data_balance_{false};
};
......
......@@ -37,6 +37,8 @@ pass_library(fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
graph = FuseActElewiseAdd(std::move(graph), act_types);
graph = FuseElewiseAddAct(std::move(graph), act_types);
// backward
{
std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
}
// Remove the removable intermediate_out.
RemoveIntermediateOut(graph.get());
return graph;
}
// ele_add(x, act(y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act", graph.get());
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("elewise_add_act/x")
->AsInput()
->assert_is_op_input("elementwise_add", "X");
patterns::ElewiseAddAct elewise_add_act_pattern(gpd.mutable_pattern(),
"elementwise_add");
elewise_add_act_pattern(x, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseElewiseAddAct fuse";
GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, elewise_add_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_add, ele_add, elewise_add_act_pattern);
std::string ele_x_n = subgraph.at(x)->Name();
std::string ele_y_n = ele_y->Name();
std::string ele_out_n = ele_out->Name();
std::string act_out_n = act_out->Name();
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
g, act, ele_add, ele_x_n, ele_y_n, ele_out_n, act_out_n);
VLOG(4) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> "
<< ele_add->Name() << " -> " << ele_out_n << "\n"
<< "\t " << ele_out_n << " -> " << act->Name() << " -> "
<< act_out_n;
ReLinkNodes(g, ele_out, ele_add, act, elewise_add_act_node);
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
// act(ele_add(x,y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("act_elewise_add", graph.get());
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("act_elewise_add/x")
->AsInput()
->assert_is_ops_input(act_types, "X");
patterns::ActElewiseAdd act_elewise_add_pattern(gpd.mutable_pattern(),
"act_elewise_add");
act_elewise_add_pattern(x, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseElewiseAddAct fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_add, ele_add, act_elewise_add_pattern);
std::string act_i_n = subgraph.at(x)->Name();
std::string act_o_n = act_out->Name();
std::string elewise_add_x_n = ele_x->Name();
std::string elewise_add_out_n = ele_out->Name();
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
g, ele_add, act, elewise_add_x_n, act_i_n, act_o_n, elewise_add_out_n);
VLOG(4) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n
<< "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> "
<< ele_add->Name() << " -> " << elewise_add_out_n;
ReLinkNodes(g, act_out, act, ele_add, elewise_add_act_node);
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
// the backward of act(ele_add(x,y))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act_grad", graph.get());
GraphPatternDetector gpd;
auto *d_act_out = gpd.mutable_pattern()
->NewNode("elewise_add_act_grad_inplace/x")
->AsInput()
->assert_is_ops_input(act_types, GradVarName("Out"));
patterns::ElewiseAddActInplaceGrad elewise_add_act_grad_pattern(
gpd.mutable_pattern(), "elewise_add_act_grad_inplace");
elewise_add_act_grad_pattern(d_act_out, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseElewiseAddActGrad1 fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_add_grad, ele_add_grad,
elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_ele_x, d_ele_x, elewise_add_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_ele_y, d_ele_y, elewise_add_act_grad_pattern);
std::string d_act_out_n = subgraph.at(d_act_out)->Name();
std::string act_out_n = act_out->Name();
std::string d_itermediate_out_n = d_itermediate_out->Name();
std::string ele_y_n = ele_y->Name();
std::string d_ele_x_n = d_ele_x->Name();
std::string d_ele_y_n = d_ele_y->Name();
OpDesc desc;
desc.SetType("fused_elemwise_activation_grad");
desc.SetInput("IntermediateOut", {});
desc.SetInput("X", {});
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
desc.SetInput("Out", std::vector<std::string>({act_out_n}));
desc.SetInput(GradVarName("Out"), std::vector<std::string>({d_act_out_n}));
desc.SetOutput(GradVarName("X"), std::vector<std::string>({d_ele_x_n}));
desc.SetOutput(GradVarName("Y"), std::vector<std::string>({d_ele_y_n}));
desc.SetOutput(GradVarName("IntermediateOut"),
std::vector<std::string>({d_itermediate_out_n}));
desc.SetAttr("save_intermediate_out", false);
desc.SetAttr("functor_list",
std::vector<std::string>(
{act_grad->Op()->Type(), ele_add_grad->Op()->Type()}));
for (auto &n : {act_grad->Op(), ele_add_grad->Op()}) {
for (auto &m_ele : n->GetAttrMap()) {
desc.SetAttr(m_ele.first, m_ele.second);
}
}
auto fused_node = g->CreateOpNode(&desc);
VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
<< act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
<< d_itermediate_out_n << " and " << act_out_n << " -> "
<< ele_add_grad->Name() << " -> " << d_itermediate_out_n;
ReLinkNodes(g, d_itermediate_out, act_grad, ele_add_grad, fused_node);
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
Graph *g, const Node *op_1, const Node *op_2, const std::string &ele_x_n,
const std::string &ele_y_n, const std::string &ele_out_n,
const std::string &act_out_n) const {
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({ele_x_n}));
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
desc.SetOutput("Out", std::vector<std::string>({act_out_n}));
desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
desc.SetType("fused_elemwise_activation");
desc.SetAttr("save_intermediate_out", true);
desc.SetAttr("functor_list", std::vector<std::string>(
{op_1->Op()->Type(), op_2->Op()->Type()}));
// Set attrs
for (auto &n : {op_1->Op(), op_2->Op()}) {
for (auto &m_ele : n->GetAttrMap()) {
desc.SetAttr(m_ele.first, m_ele.second);
}
}
auto elewise_add_act_node = g->CreateOpNode(&desc);
return elewise_add_act_node;
}
void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
std::unordered_set<const Node *> need_removed_nodes;
for (auto &cur_node : graph->Nodes()) {
if (cur_node->IsVar()) continue;
if (cur_node->Name() == "fused_elemwise_activation") {
bool save_intermediate_out =
boost::get<bool>(cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
PADDLE_ENFORCE(
save_intermediate_out && !intermediate_out_args.empty(),
"The %s should save the intermediate_out in the fusing stage.",
cur_node->Name());
// If the intermediate_out's output is empty, it should be removed.
auto cur_node_outputs = cur_node->outputs;
for (auto &out : cur_node_outputs) {
if (out->Name() == intermediate_out_args[0]) {
if (out->outputs.size() == 0) {
cur_node->outputs = this->RemoveNode(out, cur_node->outputs);
need_removed_nodes.insert(std::move(out));
cur_node->Op()->SetAttr("save_intermediate_out", false);
}
}
}
} else if (cur_node->Name() == "fused_elemwise_activation_grad") {
auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE(
!intermediate_out_grad_args.empty(),
"The %s should save the intermediate_out in the fusing stage.",
cur_node->Name());
auto cur_node_outputs = cur_node->outputs;
// If the intermediate_out_g's output is empty, it should be removed.
for (auto &out : cur_node_outputs) {
if (out->Name() == intermediate_out_grad_args[0] &&
out->outputs.empty()) {
cur_node->Op()->SetOutput(GradVarName("IntermediateOut"), {});
cur_node->outputs = this->RemoveNode(out, cur_node->outputs);
need_removed_nodes.insert(std::move(out));
}
}
}
}
GraphSafeRemoveNodes(graph, need_removed_nodes);
}
void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
const Node *intermediate_out,
Node *op_1, Node *op_2,
Node *fused_op) const { // delete act
for (auto &in : op_1->inputs) {
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
}
std::unordered_set<const Node *> nodes2delete;
for (auto &out : op_1->outputs) {
if (out->IsCtrlVar()) {
auto result_iter = std::find_if(
op_2->inputs.begin(), op_2->inputs.end(),
[&out](const Node *node) -> bool { return node == out; });
if (result_iter == op_2->inputs.end()) {
IR_OP_VAR_LINK(fused_op, out);
} else {
nodes2delete.emplace(out);
}
} else {
PADDLE_ENFORCE(out == intermediate_out);
IR_OP_VAR_LINK(fused_op, out);
}
}
for (auto &in : op_2->inputs) {
if (in == intermediate_out || nodes2delete.count(in)) {
continue;
}
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs);
}
for (auto &out : op_2->outputs) {
IR_OP_VAR_LINK(fused_op, out);
}
nodes2delete.insert(std::move(op_1));
nodes2delete.insert(std::move(op_2));
GraphSafeRemoveNodes(graph, nodes2delete);
}
std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
bool has_replaced = false;
std::transform(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> Node * {
if (node == cur_node) {
has_replaced = true;
return new_node;
}
return node;
});
PADDLE_ENFORCE(has_replaced, "Not find %s in the node list.",
cur_node->Name());
return new_list;
}
std::vector<Node *> FuseElewiseAddActPass::RemoveNode(
Node *trg_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
auto end_iter =
std::copy_if(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> bool { return node != trg_node; });
new_list.resize(
static_cast<uint64_t>(std::distance(new_list.begin(), end_iter)));
return new_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_elewise_add_act_pass,
paddle::framework::ir::FuseElewiseAddActPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the ElewiseAdd and activation
*/
class FuseElewiseAddActPass : public FusePassBase {
public:
virtual ~FuseElewiseAddActPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
std::unique_ptr<ir::Graph> FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
/**
* Remove the removable intermediate_out.
* - If the intermediate_out is only used by the backward op, but the
* backward op doesn't use intermediate_out.
* - If the intermediate_out_grad is not used by any op.
*/
void RemoveIntermediateOut(Graph *graph) const;
std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
const std::vector<Node *> &nodes) const;
std::vector<Node *> RemoveNode(Node *trg_node,
const std::vector<Node *> &nodes) const;
void ReLinkNodes(Graph *graph, const Node *intermediate_out, Node *op_1,
Node *op_2, Node *fused_op) const;
Node *CreateFuseElewiseAddActNode(Graph *g, const Node *op_1,
const Node *op_2,
const std::string &ele_x_n,
const std::string &ele_y_n,
const std::string &ele_out_n,
const std::string &act_out_n) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -95,6 +95,7 @@ struct PDNode {
PDNode* assert_is_op();
PDNode* assert_is_op(const std::string& op_type);
PDNode* assert_is_var();
PDNode* assert_is_not_ctrl_var();
PDNode* assert_var_not_persistable();
PDNode* assert_is_persistable_var();
PDNode* assert_is_op_output(const std::string& op_type);
......@@ -113,6 +114,20 @@ struct PDNode {
PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n);
PDNode* assert_more(teller_t&& teller);
PDNode* assert_is_ops_output(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops_output(const std::unordered_set<std::string>& op_types,
const std::string& argument);
PDNode* assert_is_ops_nth_input(
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types,
const std::string& argument);
PDNode* assert_is_ops_nth_output(
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
private:
PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
......@@ -447,6 +462,68 @@ struct GRU : public PatternBase {
PATTERN_DECL_NODE(Hidden);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act
// named nodes: elementwise_add, act
// ele_x, ele_y, elewise_add_out, act_out
struct ElewiseAddAct : public PatternBase {
ElewiseAddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elewise_add_act") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(ele_add);
PATTERN_DECL_NODE(act);
// declare variable node's name
PATTERN_DECL_NODE(elewise_add_out);
PATTERN_DECL_NODE(ele_y);
PATTERN_DECL_NODE(act_out);
};
// formula: ele_add(x, act(y))
// op: elementwise_add + act
// named nodes: elementwise_add, act
// act_in, act_out, ele_x, elewise_add_out
struct ActElewiseAdd : public PatternBase {
ActElewiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "act_elewise_add") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(ele_add);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(ele_x);
PATTERN_DECL_NODE(elewise_add_out);
};
// the backward of act(ele_add(x, y))
// the act is inplace.
// op: elementwise_add_grad + act_grad
// named nodes: elementwise_add_grad, act_grad
// act_out, act_out_g, ele_y, d_itermediate_out, d_ele_x, d_ele_y
struct ElewiseAddActInplaceGrad : public PatternBase {
ElewiseAddActInplaceGrad(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elewise_add_act_grad1") {}
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(act_grad);
PATTERN_DECL_NODE(ele_add_grad);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(d_itermediate_out);
PATTERN_DECL_NODE(d_ele_x);
PATTERN_DECL_NODE(d_ele_y);
PATTERN_DECL_NODE(ele_y);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......@@ -454,6 +531,12 @@ struct GRU : public PatternBase {
a->outputs.push_back(b); \
b->inputs.push_back(a);
// Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \
op->outputs.push_back(out_var); \
out_var->inputs.clear(); \
out_var->inputs.push_back(op);
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -48,6 +48,10 @@ class Node {
bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; }
bool IsCtrlVar() const {
return type_ == Type::kVariable &&
Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
......
......@@ -57,6 +57,21 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
graph = viz_pass->Apply(std::move(graph));
}
// Apply op fusion.
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass =
ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass");
graph = fuse_elewise_add_act_pass->Apply(std::move(graph));
// Apply a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
}
// Convert graph to run on multi-devices.
auto multi_devices_pass =
ir::PassRegistry::Instance().Get("multi_devices_pass");
......@@ -359,6 +374,7 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework
} // namespace paddle
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
......
......@@ -13,18 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
/*
* Whether the compound function is Unary(Binary(X, Y)).
* For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final
* out.
*/
static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {
"elementwise_add", "elementwise_mul", "elementwise_add_grad",
......@@ -32,10 +25,17 @@ static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
return binary_fun.count(functor_list[1]) != 0;
}
/*
* Whether the Input(X) could be absent.
*/
static bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> InplaceOpSet = {"relu", "relu_grad"};
bool is_in_place = false;
for (auto &func_name : functor_list) {
is_in_place |= (InplaceOpSet.count(func_name) == 1);
}
return is_in_place;
}
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
......@@ -86,20 +86,12 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
bool bcast_y = IsBcastY(x_dim, y_dim);
auto &out_dim = bcast_y ? x_dim : y_dim;
std::string out_lod = bcast_y ? "X" : "Y";
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null.");
......@@ -123,6 +115,20 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
ctx->ShareLoD(out_lod, /*->*/ "Out");
}
static bool IsBcastY(const framework::DDim &x_dim,
const framework::DDim &y_dim) {
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
return bcast_y;
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -157,17 +163,7 @@ class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("scale",
"scale is used by scale_op, the default value is 0.0.")
.SetDefault(0.0);
AddAttr<bool>(
"recomputation",
"Whether to recompute the Out."
"The computation of fused_elemwise_activation_grad has two methods to "
"get the dx and dy, one is to use the 'Out', and the other is not. "
"The former method will save the time of recomputing the 'Out', but it "
"must occupy the memory to store the 'out'. While, the later method "
"can avoid occupying the memory, but it must recompute the 'Out'. "
"It is useful for Unary(Binary(X, Y)). The default value is true.")
.SetDefault(true);
AddAttr<bool>("keep_intermediate_value",
AddAttr<bool>("save_intermediate_out",
"Whether to save the intermediate_out.")
.SetDefault(false);
AddAttr<std::vector<std::string>>("functor_list",
......@@ -227,30 +223,38 @@ class FusedElemwiseActivationGradMaker
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *op_desc_ptr = new framework::OpDesc();
op_desc_ptr->SetType(this->ForwardOpType() + "_grad");
auto *grad_op = new framework::OpDesc();
grad_op->SetType(this->ForwardOpType() + "_grad");
for (auto &input_param : this->InputNames()) {
op_desc_ptr->SetInput(input_param, this->Input(input_param));
op_desc_ptr->SetOutput(framework::GradVarName(input_param),
grad_op->SetInput(input_param, this->Input(input_param));
grad_op->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param, true));
}
for (auto &output_param : this->OutputNames()) {
op_desc_ptr->SetInput(output_param, this->Output(output_param));
op_desc_ptr->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
}
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op_desc_ptr->SetAttrMap(this->Attrs());
grad_op->SetAttrMap(this->Attrs());
std::vector<std::string> functor_names =
boost::get<std::vector<std::string>>(
op_desc_ptr->GetAttr("functor_list"));
boost::get<std::vector<std::string>>(grad_op->GetAttr("functor_list"));
functor_names[0] += "_grad";
functor_names[1] += "_grad";
op_desc_ptr->SetAttr("functor_list", functor_names);
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
grad_op->SetAttr("functor_list", functor_names);
if (boost::get<bool>(grad_op->GetAttr("save_intermediate_out"))) {
PADDLE_ENFORCE_NE(Output("IntermediateOut").size(), 0);
grad_op->SetInput("IntermediateOut", this->Output("IntermediateOut"));
grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
this->OutputGrad("IntermediateOut"));
} else {
grad_op->SetInput("IntermediateOut", {});
grad_op->SetOutput(framework::GradVarName("IntermediateOut"), {});
}
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
......@@ -261,56 +265,65 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
auto functor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
"Input(IntermediateOut) should not be null");
} else {
PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
}
}
auto funtor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
auto inter_grad_name = framework::GradVarName("IntermediateOut");
if (ctx->HasOutput(x_grad_name)) {
if (ctx->HasInputs("X")) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", x_grad_name);
} else {
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, if not, we could not infer the shape of dx.
// Currently, only when Binary is elementwise_add or elementwise_sub,
// the "X" could be absent.
PADDLE_ENFORCE(InputXCanBeAbsent(funtor_list),
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent.");
// For Unary(Binary(X, Y)), IntermediateOut should not be empty.
if (IsUnaryCompound(funtor_list)) {
PADDLE_ENFORCE(
ctx->HasInputs("IntermediateOut"),
"If the compound_functor is Unary(Binary(X, Y)) and Binary "
"is elementwise_add, the intermediate_out must be not absent.");
}
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, otherwise, we could not infer the shape of dx.
ctx->SetOutputDim(x_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
}
}
if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", y_grad_name);
}
if (ctx->HasOutput(inter_grad_name)) {
// For Unary(Binary(X, Y)), IntermediateOut should not be empty.
if (IsUnaryCompound(functor_list)) {
ctx->SetOutputDim(inter_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), inter_grad_name);
} else {
ctx->SetOutputDim(inter_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", inter_grad_name);
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
// PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
......
......@@ -26,6 +26,24 @@ limitations under the License. */
namespace paddle {
namespace operators {
/**
* Whether the compound function is Unary(Binary(X, Y)).
* For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final
* out.
*/
bool IsUnaryCompound(const std::vector<std::string> &functor_list);
/**
* For the in-place unary functor, the inputs of op_desc only have Out and
* Out@Grad.
*/
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);
/**
* Whether the Input(X) could be absent.
*/
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);
template <typename DeviceContext, typename T, typename BinaryFunctor,
typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
......@@ -39,7 +57,7 @@ static void RunBinaryCompoundFunctor(
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis");
if (ctx.Attr<bool>("keep_intermediate_value")) {
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
......@@ -71,7 +89,7 @@ static void RunUnaryCompoundFunctors(
paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
compound_func(unary_functor, binary_functor);
if (ctx.Attr<bool>("keep_intermediate_value")) {
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
......@@ -89,7 +107,7 @@ static void RunUnaryCompoundFunctors(
}
template <typename DeviceContext, typename T, typename BinaryGradFunctor,
typename UnaryFunctor, typename UnaryGradFunctor>
typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
static void RunBinaryCompoundGradFunctors(
const framework::ExecutionContext &ctx,
const BinaryGradFunctor &binary_grad_functor,
......@@ -98,7 +116,7 @@ static void RunBinaryCompoundGradFunctors(
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) {
framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
// Z = Binary(X, Unary(Y))
int axis = ctx.Attr<int>("axis");
......@@ -107,32 +125,40 @@ static void RunBinaryCompoundGradFunctors(
UnaryFunctor>;
using BinaryCompoundDyFunctor =
paddle::operators::math::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>;
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
using BinaryCompoundDIntermedaiteOutFunctor =
paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
T, BinaryGradFunctor, UnaryFunctor>;
if (in_intermediate_out) {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
true /*UseIntermediateOut*/,
BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
y_grad, d_intermediate_out,
BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor));
unary_grad_functor),
BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
unary_functor));
} else {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
false /*UseIntermediateOut*/,
BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
y_grad, d_intermediate_out,
BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor));
unary_grad_functor),
BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
unary_functor));
}
}
template <typename DeviceContext, typename T, typename UnaryGradFunctor,
typename BinaryFunctor, typename BinaryGradFunctor,
bool Recomputation = true>
typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
static void RunUnaryCompoundGradFunctors(
const framework::ExecutionContext &ctx,
const UnaryGradFunctor &unary_grad_functor,
......@@ -141,36 +167,44 @@ static void RunUnaryCompoundGradFunctors(
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) {
framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
// Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis");
using UnaryCompoundDxFunctor =
paddle::operators::math::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDyFunctor =
paddle::operators::math::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDIntermediateFunctor =
paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
T, UnaryGradFunctor, BinaryFunctor, InPlace>;
if (in_intermediate_out) {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>(
UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
y_grad, d_intermediate_out,
UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor));
binary_grad_functor),
UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
} else {
FusedElemwiseAndActGradComputeEx<DeviceContext, T, UnaryCompoundDxFunctor,
UnaryCompoundDyFunctor,
false /*UseIntermediateOut*/,
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
y_grad, d_intermediate_out,
UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor));
binary_grad_functor),
UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
}
}
......@@ -226,72 +260,67 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
}
}
template <typename DeviceContext, typename T, bool ReComputation>
static void RunGradFunctors(const framework::ExecutionContext &ctx,
const framework::Tensor *in_x,
const framework::Tensor *in_y,
const framework::Tensor *in_out,
template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad,
framework::Tensor *x_grad,
framework::Tensor *y_grad) {
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
auto funcs_str = functors[0] + "," + functors[1];
// TODO(zcd): The following code can be refined. for example, use registrition
if (funcs_str == "elementwise_add_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::AddGradFunctor<T>,
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>>(
paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "scale_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::ScaleGradFunctor<T>,
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>,
ReComputation /*Recomputation*/>(
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_add_grad,relu_grad") {
RunBinaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::AddGradFunctor<T>,
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::ReluGradFunctor<T>>(
paddle::operators::math::ReluGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "relu_grad,elementwise_add_grad") {
RunUnaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::ReluGradFunctor<T>,
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>,
ReComputation /*Recomputation*/>(
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::ReluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::MulGradFunctor<T>,
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>>(
paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
}
......@@ -313,9 +342,9 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);
if (ctx.Attr<bool>("keep_intermediate_value")) {
if (ctx.Attr<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The keep_intermediate_value is enable, so the "
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty.");
auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
......@@ -331,65 +360,63 @@ template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.Input<framework::Tensor>("X");
auto y = ctx.Input<framework::Tensor>("Y");
auto in_y = ctx.Input<framework::Tensor>("Y");
PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
auto in_out = ctx.Input<framework::Tensor>("Out");
PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
auto in_out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(in_out_grad != nullptr,
"Input(Out@Grad) should not be nullptr.");
framework::Tensor *in_x =
const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
framework::Tensor *x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
framework::Tensor *y_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
framework::GradVarName("IntermediateOut"));
PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr.");
auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");
if (ctx.Attr<bool>("recomputation")) {
PADDLE_ENFORCE(
x != nullptr,
"The recomputation is opened, so Input(X) should not be absent.");
// Get intermediate_out
framework::Tensor *in_intermediate_out = nullptr;
if (ctx.Attr<bool>("save_intermediate_out")) {
// if save_intermediate_out is true, for Unary(Binary(x, y)) and
// Binary(x, Unary(y)), the Binary(x, y) and Unary(y) not need to
// recompute.
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'save_intermediate_out' is opened, "
"so the number of 'Out' should be two.");
} else {
PADDLE_ENFORCE(in_out != nullptr,
"The recomputation is disabled, so the Input('Out') "
"should not be empty.");
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
}
}
framework::Tensor *in_x;
auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");
// Get in_x
if (ctx.HasInput("X")) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
} else {
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, and in_outs.
if (x == nullptr) {
PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" ||
functor_list[1] == "elementwise_add_grad",
// in_x, in_y and in_out.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent.");
in_x = const_cast<framework::Tensor *>(in_out_grad);
in_out = const_cast<framework::Tensor *>(in_out_grad);
} else {
in_x = const_cast<framework::Tensor *>(x);
}
framework::Tensor *in_intermediate_out;
if (ctx.Attr<bool>("keep_intermediate_value")) {
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'keep_intermediate_value' is opened, "
"so the number of 'Out' should be two.");
} else {
in_intermediate_out = nullptr;
}
if (ctx.Attr<bool>("recomputation")) {
RunGradFunctors<DeviceContext, T, true /*Recomputation*/>(
ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad);
bool has_in_place = HasInPlaceUnary(functor_list);
if (has_in_place) {
RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad, d_intermediate_out);
} else {
RunGradFunctors<DeviceContext, T, false /*Recomputation*/>(
ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad);
RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad, d_intermediate_out);
}
}
};
......
......@@ -22,11 +22,11 @@ namespace paddle {
namespace operators {
namespace math {
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = BinaryFunctor(X, UnaryFunctor(Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }
......@@ -40,11 +40,11 @@ struct BinaryCompoundFunctor {
UnaryFunctor func2_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = UnaryFunctor(BinaryFunctor(X, Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }
......@@ -58,23 +58,19 @@ struct UnaryCompoundFunctor {
BinaryFunctor func2_;
};
// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get
// the dx, one is to use the 'out', and the other is not to use it.
// the former method will save the time of recomputing the
// 'out', but it must occupy the memory to store the 'out'.
// While the later method can avoid occupying this memory,
// but it must recompute the 'out'.
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
......@@ -83,8 +79,9 @@ struct BinaryCompoundGradDxFunctor {
UnaryFun unary_fun_;
};
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun>
typename DUnaryFun, bool InPlace>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
......@@ -93,13 +90,19 @@ struct BinaryCompoundGradDyFunctor {
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_(y);
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
if (InPlace) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_(y, intermediate_out);
d_unary_fun_.UseOut(intermediate_out);
} else {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_.UseXAndOut(y, intermediate_out);
}
}
private:
......@@ -108,8 +111,9 @@ struct BinaryCompoundGradDyFunctor {
DUnaryFun d_unary_fun_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
typename DBinaryFun, bool InPlace>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
......@@ -118,22 +122,23 @@ struct UnaryCompoundGradDxFunctor {
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
}
return base * d_binary_fun_.Dx(x, y);
}
......@@ -144,8 +149,9 @@ struct UnaryCompoundGradDxFunctor {
DBinaryFun d_binary_fun_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
typename DBinaryFun, bool InPlace>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
......@@ -154,22 +160,23 @@ struct UnaryCompoundGradDyFunctor {
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
}
return base * d_binary_fun_.Dy(x, y);
}
......@@ -180,6 +187,56 @@ struct UnaryCompoundGradDyFunctor {
DBinaryFun d_binary_fun_;
};
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDIntermedaiteOutFunctor {
BinaryCompoundGradDIntermedaiteOutFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y));
}
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun, bool InPlace>
struct UnaryCompoundGradDIntermediateFunctor {
UnaryCompoundGradDIntermediateFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun)
: d_unary_fun_(d_unary_fun), binary_fun_(binary_fun) {}
inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
if (InPlace) {
return dout * d_unary_fun_.UseOut(out);
} else {
return dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
}
}
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
T dout) {
if (InPlace) {
return dout * d_unary_fun_.UseOut(out);
} else {
return dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
}
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -58,9 +58,9 @@ template <typename T>
struct ScaleGradFunctor {
explicit ScaleGradFunctor(T coeff) : coeff_(coeff) {}
inline HOSTDEVICE T operator()(T x) { return coeff_; }
inline HOSTDEVICE T operator()(T x, T out) { return coeff_; }
inline HOSTDEVICE T UseX(T x) { return coeff_; }
inline HOSTDEVICE T UseOut(T out) { return coeff_; }
inline HOSTDEVICE T UseXAndOut(T x, T out) { return coeff_; }
private:
T coeff_;
......@@ -73,9 +73,9 @@ struct ReluFunctor {
template <typename T>
struct ReluGradFunctor {
inline HOSTDEVICE T operator()(T x) { return x > 0 ? 1 : 0; }
inline HOSTDEVICE T operator()(T x, T out) { return x > 0 ? 1 : 0; }
inline HOSTDEVICE T UseX(T x) { return x > 0 ? 1 : 0; }
inline HOSTDEVICE T UseOut(T out) { return out > 0 ? 1 : 0; }
inline HOSTDEVICE T UseXAndOut(T x, T out) { return out > 0 ? 1 : 0; }
};
} // namespace math
......
......@@ -670,7 +670,14 @@ All parameter, weight, gradient are variables in Paddle.
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; });
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property("fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
self.fuse_elewise_add_act_ops_ = b;
});
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
......@@ -47,8 +47,7 @@ def get_numeric_gradient(place,
input_to_check,
output_names,
delta=0.005,
in_place=False,
sum_outputs=None):
in_place=False):
# FIXME: change this method by compile time concepts
set_input(scope, op, inputs, place)
......@@ -59,8 +58,6 @@ def get_numeric_gradient(place,
sum = []
op.run(scope, place)
for output_name in output_names:
if sum_outputs and output_name not in sum_outputs:
continue
sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).sum() / len(output_names)
......@@ -407,14 +404,13 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
sum_outputs=None):
user_defined_grads=None):
places = self._get_places()
for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta,
in_place, max_relative_error,
user_defined_grads, sum_outputs)
user_defined_grads)
def check_grad_with_place(self,
place,
......@@ -424,8 +420,7 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
sum_outputs=None):
user_defined_grads=None):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......@@ -448,8 +443,7 @@ class OpTest(unittest.TestCase):
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place,
sum_outputs=sum_outputs) for input_to_check in inputs_to_check
in_place=in_place) for input_to_check in inputs_to_check
]
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
......
......@@ -38,6 +38,7 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None,
use_parallel_executor=True,
use_reduce=False,
fuse_elewise_add_act_ops=False,
optimizer=fluid.optimizer.Adam,
use_fast_executor=False):
def run_executor(exe, feed, fetch_list, program=None):
......@@ -78,6 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
if use_parallel_executor:
exe = fluid.ParallelExecutor(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
def simple_fc_net(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
for _ in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
act='relu',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
def fc_with_batchnorm(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
for _ in range(2):
hidden = fluid.layers.fc(
hidden,
size=200,
act='relu',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=4)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare_fuse_elewise_add_act_ops(self,
model,
use_cuda,
random_data=True):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data(random_data)
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_elewise_add_act_ops=False,
memory_opt=False,
optimizer=_optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_elewise_add_act_ops=True,
memory_opt=False,
optimizer=_optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
def test_simple_fc_with_fuse_op(self):
self._compare_fuse_elewise_add_act_ops(simple_fc_net, True)
self._compare_fuse_elewise_add_act_ops(simple_fc_net, False)
def test_batchnorm_fc_with_fuse_op(self):
self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, True)
self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, False)
if __name__ == '__main__':
unittest.main()
......@@ -48,7 +48,7 @@ def create_test_class(test_case, callback, attrs):
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
if self.attrs["keep_intermediate_value"]:
if self.attrs["save_intermediate_out"]:
self.outputs = {
'Out': self.out,
"IntermediateOut": self.intermediate_out
......@@ -73,22 +73,19 @@ def create_test_class(test_case, callback, attrs):
def test_check_output(self):
self.check_output()
# FIXME(zcd): the intermediate_out_grad is not checked.
def test_check_grad_normal(self):
if self.attrs["keep_intermediate_value"]:
self.check_grad(
['X', 'Y'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
sum_outputs=['Out'])
if self.attrs["save_intermediate_out"]:
self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
else:
self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
def test_check_grad_ingore_x(self):
if self.attrs["keep_intermediate_value"]:
if self.attrs["save_intermediate_out"]:
self.check_grad(
['Y'], ['Out', 'IntermediateOut'],
['Y'], ['Out'],
max_relative_error=0.005,
no_grad_set=set("X"),
sum_outputs=['Out'])
no_grad_set=set("X"))
else:
self.check_grad(
['Y'], ['Out'],
......@@ -96,12 +93,11 @@ def create_test_class(test_case, callback, attrs):
no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
if self.attrs["keep_intermediate_value"]:
if self.attrs["save_intermediate_out"]:
self.check_grad(
['X'], ['Out', 'IntermediateOut'],
['X'], ['Out'],
max_relative_error=0.005,
no_grad_set=set("Y"),
sum_outputs=['Out'])
no_grad_set=set("Y"))
else:
self.check_grad(
['X'], ['Out'],
......@@ -303,38 +299,31 @@ for mode in {0, 1}:
relu_add_func = partial(relu_add_func, mode=mode)
add_relu_func = partial(add_relu_func, mode=mode)
for recomputation in {True, False}:
for keep_intermediate_value in {True, False}:
suffix = ("_keep_intermediate_value" if keep_intermediate_value else "") \
+ ("_recomputation" if recomputation else "") \
for save_intermediate_out in {True, False}:
suffix = ("_save_intermediate_out" if save_intermediate_out else "") \
+ ("_mode_"+ str(mode))
create_test_class('scale_add' + suffix, scale_add_func, {
'scale': scale,
'functor_list': ["scale", "elementwise_add"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
'save_intermediate_out': save_intermediate_out,
})
create_test_class('add_scale' + suffix, add_scale_func, {
'scale': scale,
'functor_list': ["elementwise_add", "scale"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
'save_intermediate_out': save_intermediate_out,
})
create_test_class('add_relu' + suffix, add_relu_func, {
'functor_list': ["elementwise_add", "relu"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
'save_intermediate_out': save_intermediate_out,
})
create_test_class('relu_add' + suffix, relu_add_func, {
'functor_list': ["relu", "elementwise_add"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
'save_intermediate_out': save_intermediate_out,
})
create_test_class('mul_scale' + suffix, mul_scale_func, {
'scale': scale,
'functor_list': ["elementwise_mul", "scale"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
'save_intermediate_out': save_intermediate_out,
})
if __name__ == '__main__':
......
......@@ -79,7 +79,7 @@ class TestReshapeOpWithInputShape(OpTest):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out", sum_outputs=["Out"])
self.check_grad(["X"], "Out")
if __name__ == "__main__":
......
......@@ -34,7 +34,7 @@ class TestTransposeOp(OpTest):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(['X'], 'Out', sum_outputs=['Out'])
self.check_grad(['X'], 'Out')
def initTestCase(self):
self.shape = (3, 4)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册