diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index ff443e82ae75151448605e08347ba324c68e055f..ed4e67879c795258683b094cfaeaff9063d66848 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -36,6 +36,7 @@ paddle.fluid.default_startup_program ArgSpec(args=[], varargs=None, keywords=Non paddle.fluid.default_main_program ArgSpec(args=[], varargs=None, keywords=None, defaults=None) paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.name_scope ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 7722c9401e0e7c071adb7bee9b35306431bb7a11..0bfff745493d069e948e6d277ec2bbfb0673a70b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( ir::Graph *result, const std::string &loss_grad_name) const { for (size_t i = 0; i < places_.size(); ++i) { -// Insert ScaleCost OpHandle -#ifdef PADDLE_WITH_CUDA - auto *communication_dev_ctx = - nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i]) - : platform::DeviceContextPool::Instance().Get(places_[i]); -#else - auto *communication_dev_ctx = - platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); -#endif + // Insert ScaleCost OpHandle + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *op_handle = new ScaleLossGradOpHandle( result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), - local_scopes_.size(), local_scopes_[i], places_[i], - communication_dev_ctx); + local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx); result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -744,7 +736,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, .emplace(varname, op_dev_id); } } else { - PADDLE_ENFORCE( + PADDLE_THROW( "the distribute training related op should be in [split_byref, " "concat]."); } diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc index 2876de88f174b1fa4ce0eacb8687e15e723bf1fc..d2d051a69a33a38535e67227d4cc62f5b35e430c 100644 --- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc @@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) { auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - auto* while_pat_node = gpd.pattern().RetriveNode("while"); + auto* while_pat_node = gpd.pattern().RetrieveNode("while"); auto* while_node = subgraph.at(while_pat_node); marked_nodes.insert(while_node); }; diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 201160f29df1ee5473ba5e6cf434fa246e015a12..513742bab69d465aac1bfb7bcef2fe89108c14a0 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) { } void BuildFCPattern(PDPattern* pattern) { - // make sure the selected MUL op has one input argument is a parameter. - auto* mul_parameter_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && node->outputs.size() == 1UL && - node->outputs.front()->Op()->Type() == "mul" && node->Var() && - node->Var()->Persistable(); // check is a parameter - }, - "mul_weight" /*name*/); - - auto* mul_tmp_input_var = pattern->NewNode( - [](Node* node) { - bool result = - node->IsVar() && node->outputs.size() >= 1UL && node->Var() && - !node->Var()->Persistable(); // this input is not an parameter. - if (!result) return false; - // check whether one output is MUL op. - for (auto* op : node->outputs) { - if (op->IsOp() && op->Op()->Type() == "mul") return true; - } - return false; - }, - "mul_tmp_var" /*name*/); - - // select a MUL op - auto* mul_op = pattern->NewNode( - [](Node* node) { - return node->IsOp() && // start from an Op - node->Op()->Type() == "mul"; // type is mul - // the output should be consumed only by one element_add, that check - // leaves in a Var PDNode. - }, - "mul" /*name*/); - - // make sure the MUL op's output has only one consumer and links to an - // ELEMENTWISE_ADD op. - auto* mul_out_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && // starts from a Var - node->outputs.size() == 1UL && // only has one consumer - node->outputs.front()->IsOp() && // check basic logic - node->Var() && // not a ControlDepVar - node->outputs.front()->Op()->Type() == - "elementwise_add"; // a very strong validation - }, - "mul_out"); - // this check is not essential, just to make the corresponding variable Node - // retrival easier. - auto* elementwise_add_tmp_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && node->outputs.size() >= 1UL && node->Var() && - VarOutLinksToOp(node, "elementwise_add"); - }, - "elementwise_add_tmpvar"); - - // select an ELEMENTWISE_ADD op - auto* elementwise_add_op = pattern->NewNode( - [](Node* node) { - return node->IsOp() && node->Op()->Type() == "elementwise_add"; - }, - "elementwise_add" /*name*/); - - // get the ELEMENTWISE_ADD op's output - auto* elementwise_add_out_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && node->inputs.size() == 1UL && node->Var() && - node->inputs.front()->Op()->Type() == "elementwise_add"; - }, - "elementwise_add_out"); - - mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var}) - .LinksTo({mul_out_var}); + // Create Operators + auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul"); + auto* elementwise_add_op = + pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add"); + // Create variables + // w + auto* mul_weight_var = pattern->NewNode("mul_weight") + ->AsInput() + ->assert_is_op_nth_input("mul", "Y", 0); + // x + auto* mul_tmp_var = pattern->NewNode("mul_tmp_var") + ->AsInput() + ->assert_is_op_nth_input("mul", "X", 0); + // intermediate variable, will be removed in the IR after fuse. + auto* mul_out_var = pattern->NewNode("mul_out") + ->AsIntermediate() + ->assert_is_only_output_of_op("mul") + ->assert_is_op_input("elementwise_add"); + // bias + auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar") + ->assert_is_op_input("elementwise_add") + ->AsInput(); + // output + auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out") + ->AsOutput() + ->assert_is_op_output("elementwise_add"); + + mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var}); elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var}) .LinksTo({elementwise_add_out_var}); } @@ -120,18 +77,20 @@ bool LinksReplace(std::vector* links, Node* from, Node* to) { std::unique_ptr FCFusePass::ApplyImpl( std::unique_ptr graph) const { PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("fc", graph.get()); std::unordered_set nodes2delete; GraphPatternDetector gpd; BuildFCPattern(gpd.mutable_pattern()); -#define GET_NODE(id) \ - PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \ - "pattern has no Node called %s", #id); \ - auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \ +#define GET_NODE(id) \ + PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \ + "pattern has no Node called %s", #id); \ + auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); + int found_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "handle FC fuse"; @@ -176,10 +135,13 @@ std::unique_ptr FCFusePass::ApplyImpl( graph->RemoveNode(mul); graph->RemoveNode(elementwise_add); graph->RemoveNode(mul_out); // tmp variable + + found_fc_count++; }; gpd(graph.get(), handler); + AddStatis(found_fc_count); return graph; } diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h index 31ed0e362f760319130135ad49fe2bb4e68e6786..6c69539d1e48268afc2435f8f73b3818d13107cd 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_fuse_pass.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" @@ -23,7 +24,7 @@ namespace ir { /* * Fuse the MUL and ELEMENTWISE_ADD to a FCOp. */ -class FCFusePass : public Pass { +class FCFusePass : public FusePassBase { public: virtual ~FCFusePass() {} diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index 87ba417b1a43475f48380009f8e5cd84699b8e40..06286a109d01af638e74e06ccc83e2a5500663ea 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::vector& outputs) { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); - op->SetInput("Xs", inputs); - op->SetOutput("Ys", outputs); + if (type == "mul") { + op->SetInput("X", {inputs[0]}); + op->SetInput("Y", {inputs[1]}); + } else if (type == "elementwise_add") { + op->SetInput("X", inputs); + } + op->SetOutput("Out", outputs); } // a->OP0->b diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index daecf3b407c5b40c0ad6c3a75d7fbad3fe45c664..5852705b6b8d1c650faeae3dc810aac65353b459 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -36,7 +36,7 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node")); + auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node")); marked_nodes.insert(id); }; gpd(graph.get(), handler); @@ -64,9 +64,9 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( #undef GET_NODE #undef SET_IN - LOG(INFO) << "hidden_n: " << hidden_n->Name(); - LOG(INFO) << "cell: " << cell_n->Name(); - LOG(INFO) << "xx: " << xx_n->Name(); + VLOG(4) << "hidden_n: " << hidden_n->Name(); + VLOG(4) << "cell: " << cell_n->Name(); + VLOG(4) << "xx: " << xx_n->Name(); op_desc.SetInput("H0", {}); op_desc.SetInput("C0", {}); diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h index bf6a0ae8274cecc785ffb269b0b574a42ee7d418..877bbeb502252cac77095981641d7ce283ca1eb7 100644 --- a/paddle/fluid/framework/ir/fuse_pass_base.h +++ b/paddle/fluid/framework/ir/fuse_pass_base.h @@ -22,21 +22,37 @@ namespace paddle { namespace framework { namespace ir { -static const char kParamScopeAttr[] = "param_scope"; +static const char kParamScopeAttr[] = "__param_scope__"; +static const char kFuseStatisAttr[] = "__fuse_statis__"; class FusePassBase : public Pass { public: - void Init(Graph* graph) const { graph_ = graph; } + void Init(const std::string& repr, Graph* graph) const { + repr_ = repr; + graph_ = graph; + } Scope* param_scope() const { PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); return graph_->Get(kParamScopeAttr); } + void AddStatis(int count_of_fused) const { + PADDLE_ENFORCE(graph_); + PADDLE_ENFORCE(!repr_.empty()); + if (!graph_->Has(kFuseStatisAttr)) { + graph_->Set(kFuseStatisAttr, new std::unordered_map); + } + auto& info = + graph_->Get>(kFuseStatisAttr); + info[repr_] = count_of_fused; + } + virtual ~FusePassBase() {} protected: mutable Graph* graph_; + mutable std::string repr_; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index dce4be8ff04204a134441410646c9a01b5dd40a3..945ab110b148c320b6626cadaa47d483df68419e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -27,6 +27,19 @@ namespace ir { size_t PDPattern::id_ = 0UL; +PDNode* PDPattern::NewNode(const std::string& name) { + if (!name.empty()) { + PADDLE_ENFORCE_EQ(node_map_.count(name), 0, + "PDNode's name should be unique, get duplicate [%s]", + name); + } + + nodes_.emplace_back(new PDNode(this, name)); + auto* cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { if (!name.empty()) { PADDLE_ENFORCE_EQ(node_map_.count(name), 0, @@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { return cur; } -PDNode* PDPattern::RetriveNode(const std::string& id) const { +PDNode* PDPattern::RetrieveNode(const std::string& id) const { auto it = node_map_.find(id); if (it == node_map_.end()) { return nullptr; @@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph, auto subgraphs = DetectPatterns(); UniquePatterns(&subgraphs); RemoveOverlappedMatch(&subgraphs); + ValidateByNodeRole(&subgraphs); + if (subgraphs.empty()) return; LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern"; int id = 0; for (auto& g : subgraphs) { @@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { } } } + // Check to early stop if some PDNode can't find matched Node. + for (auto& pdnode : pattern_.nodes()) { + if (!pdnodes2nodes_.count(pdnode.get())) { + VLOG(4) << pdnode->name() << " can't find matched Node, early stop"; + + return false; + } + } VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; return !pdnodes2nodes_.empty(); } +// The intermediate Nodes can only link to the nodes inside the pattern, or this +// subgraph will be droped. +void GraphPatternDetector::ValidateByNodeRole( + std::vector* subgraphs) { + std::vector result; + + subgraphs->erase( + std::remove_if( + subgraphs->begin(), subgraphs->end(), + [](const GraphPatternDetector::subgraph_t& subgraph) -> bool { + // Collect the inputs and outputs. + std::unordered_set ios; + for (auto& item : subgraph) { + if (!item.first->IsIntermediate()) { + ios.insert(item.second); + } + } + for (auto& item : subgraph) { + if (item.first->IsIntermediate()) { + for (auto* x : item.second->inputs) { + if (!ios.count(x)) { + return true; + } + } + for (auto* x : item.second->outputs) { + if (!ios.count(x)) { + return true; + } + } + } + } + return false; + }), + subgraphs->end()); +} + struct HitGroup { std::unordered_map roles; @@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() { // in edges of PDNodes. for (const auto& edge : pattern_.edges()) { VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); + // TODO(Superjomn) Fix bug here, the groups might be duplicate here. // Each role has two PDNodes, which indicates two roles. // Detect two Nodes that can match these two roles and they are connected. auto& pre_groups = bi_records[step % 2]; @@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() { // source -> target for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* target : pdnodes2nodes_[edge.second]) { + VLOG(8) << "check " << source->id() << " -- " << target->id(); // TODO(Superjomn) add some prune strategies. for (const auto& group : pre_groups) { HitGroup new_group = group; @@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() { } } VLOG(3) << "step " << step << " get records: " << cur_groups.size(); + for (auto& group : cur_groups) { + for (auto& item : group.roles) { + VLOG(4) << "node " << item.second->id() << " as " << item.first->name(); + } + VLOG(4) << "========================================================="; + } } for (auto& group : bi_records[step % 2]) { @@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector& others) { return *this; } +PDNode* PDNode::assert_is_op() { + asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); }); + return this; +} +PDNode* PDNode::assert_is_op(const std::string& op_type) { + asserts_.emplace_back([this, op_type](Node* x) { + return x && x->IsOp() && x->Op()->Type() == op_type; + }); + return this; +} +PDNode* PDNode::assert_is_var() { + asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); }); + return this; +} +PDNode* PDNode::assert_var_not_persistable() { + assert_is_var(); + asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); }); + return this; +} +PDNode* PDNode::assert_is_persistable_var() { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); }); + return this; +} +PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, int nth) { + assert_is_var(); + assert_is_op_input(op_type); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->outputs) { + if (IsNthInput(x, op, argument, nth)) return true; + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type, + const std::string& argument, int nth) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->inputs) { + if (IsNthOutput(x, op, argument, nth)) return true; + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->outputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type && + op->inputs.size() == 1) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->inputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type && + op->outputs.size() == 1) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_op_output(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->inputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_op_input(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->outputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) { + assert_is_op(op_type); + asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; }); + return this; +} +PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) { + assert_is_op(op_type); + asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; }); + return this; +} +PDNode* PDNode::assert_more(PDNode::teller_t&& teller) { + asserts_.emplace_back(std::move(teller)); + return this; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 0ac34a57aacdc4fcd3d6bcaa0b72b1d6dabb3abd..f8488c84962d1caa6e7817b3c0349d6da3a59182 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -39,14 +39,24 @@ struct PDNode { // tell whether an ir::Node* is a candidation for a PDNode. using teller_t = std::function; enum class Type { kOp, kVar }; + enum class Role { + kUnknown, // No role, + kInput, // an input and will be retained, + kOutput, // an output and will be retained, + kIntermediate // will be removed after handler. + }; // this link to others PDNode& LinksTo(const std::vector& others); PDNode& LinksFrom(const std::vector& others); bool Tell(Node* node) const { - PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); - return teller_(node); + if (teller_) return teller_(node); + + for (auto& asrt : asserts_) { + if (!asrt(node)) return false; + } + return true; } bool IsOp() const { return type_ == Type::kOp; } @@ -54,10 +64,52 @@ struct PDNode { const std::string& name() const { return name_; } - PDNode(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete; + PDNode(const PDNode&) = delete; + + // Mark this node is an Input of a subgraph and will be retained. + PDNode* AsInput() { + role_ = Role::kInput; + return this; + } + // Mark this node is an Output of a subgraph and will be retained. + PDNode* AsOutput() { + role_ = Role::kOutput; + return this; + } + // Mark this node will be removed, so all the links should be inside a matched + // sub-graph. + PDNode* AsIntermediate() { + role_ = Role::kIntermediate; + return this; + } + + bool IsIntermediate() const { return role_ == Role::kIntermediate; } + bool IsInput() const { return role_ == Role::kInput; } + bool IsOutput() const { return role_ == Role::kOutput; } + + // Assertions, helper functions to simplify the pattern definition. + PDNode* assert_is_op(); + PDNode* assert_is_op(const std::string& op_type); + PDNode* assert_is_var(); + PDNode* assert_var_not_persistable(); + PDNode* assert_is_persistable_var(); + PDNode* assert_is_op_output(const std::string& op_type); + PDNode* assert_is_op_input(const std::string& op_type); + PDNode* assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, int nth); + PDNode* assert_is_op_nth_output(const std::string& op_type, + const std::string& argument, int nth); + PDNode* assert_is_only_input_of_op(const std::string& op_type); + PDNode* assert_is_only_output_of_op(const std::string& op_type); + PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n); + PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n); + PDNode* assert_more(teller_t&& teller); private: + PDNode(PDPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : pattern_(pattern), name_(name), type_(type) {} PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "", Type type = Type::kVar) : teller_(std::move(teller)), @@ -71,10 +123,13 @@ struct PDNode { friend class PDPattern; + // Will removed latter. teller_t teller_; + std::vector asserts_; PDPattern* pattern_; std::string name_; Type type_; + Role role_{Role::kUnknown}; }; /* @@ -87,19 +142,18 @@ struct PDNode { * This pattern can be defined as with the following pseudo codes * * // Create two operator PDNodes. - * MUL = PDPattern.NewNode() - * ELE = PDPattern.NewNode() + * MUL = PDPattern.NewNode().assert_is_op("mul"); + * ELE = PDPattern.NewNode().assert_is_op("elementwise_add"); * // Create the variable PDNodes. - * MUL_out = PDPattern.NewNode() - * // Add teller to define some rules that help to filter the target Nodes. - * MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul"; - * ELE.teller = lambda(node): \ - * node->IsOp() && node->Op()->Type == "elementwise_add"; - * MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs) - * && (ELE in node->outputs) + * MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \ + * .assert_is_op_input("elementwise_add") \ + * .AsIntermediate(); + * // Add relations. + * MUL->LinksTo({MUL_out}); + * MUL_out->LinksTo({ELE}); * - * One can add more specific tellers for PDNodes or edges, both the Operator - * and Variable Nodes can be ruled in PDNode.teller. + * One can add more specific asserts for PDNodes or edges, both the Operator + * and Variable Nodes can be ruled in PDNode.assert_more(...). * * PDPattern can record the general patterns, such as the pattern represents * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. @@ -112,7 +166,8 @@ class PDPattern { void AddEdge(PDNode* a, PDNode* b); PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID()); - PDNode* RetriveNode(const std::string& id) const; + PDNode* NewNode(const std::string& name = NewID()); + PDNode* RetrieveNode(const std::string& id) const; const std::vector>& nodes() const { return nodes_; } const std::vector& edges() const { return edges_; } @@ -185,6 +240,9 @@ class GraphPatternDetector { // Remove overlapped match subgraphs, when overlapped, keep the previous one. void RemoveOverlappedMatch(std::vector* subgraphs); + // Validate whether the intermediate nodes are linked by external nodes. + void ValidateByNodeRole(std::vector* subgraphs); + #ifdef PADDLE_WITH_TESTING FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph); FRIEND_TEST(GraphPatternDetecter, DetectPatterns); @@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument, return var->Name() == op->Op()->Input(argument)[nth]; } +static bool IsNthOutput(Node* var, Node* op, const std::string& argument, + size_t nth) { + PADDLE_ENFORCE(var->IsVar()); + PADDLE_ENFORCE(op->IsOp()); + if (op->inputs.size() <= nth) return false; + return var->Name() == op->Op()->Output(argument)[nth]; +} + static void GraphSafeRemoveNodes(Graph* graph, const std::unordered_set& nodes) { for (auto* node : nodes) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc index a4d0646230c0fdfb7e1970523799e7db10c75538..7e5c86b033a7c69a306491cf4bf8d099018c5f19 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc @@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ASSERT_LE(count, 2); } +TEST(GraphPatternDetector, IntermediateCheck) { + ProgramDesc program; + Graph graph(program); + BuildGraph(&graph); + + // o2->v2->o3 + // o2->v2->o4 + // check o2+o3 fuse, should fail because v2 also link to o4. + GraphPatternDetector detector; + auto* op2 = detector.mutable_pattern()->NewNode( + [](Node* x) { return x && x->IsOp() && x->Name() == "op2"; }, "op2"); + auto* op3 = detector.mutable_pattern()->NewNode( + [](Node* x) { return x && x->IsOp() && x->Name() == "op3"; }, "op3"); + auto* v2 = + detector.mutable_pattern() + ->NewNode( + [](Node* x) { return x && x->IsVar() && x->Name() == "var2"; }, + "var2") + ->AsIntermediate(); + v2->LinksFrom({op2}).LinksTo({op3}); + + int count = 0; + detector(&graph, [&](const GraphPatternDetector::subgraph_t& g, + Graph* graph) { ++count; }); + EXPECT_EQ(count, 0); + + count = 0; + v2->AsInput(); + detector(&graph, [&](const GraphPatternDetector::subgraph_t& g, + Graph* graph) { ++count; }); + ASSERT_EQ(count, 1); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index 3a114c6a237ea4411a8c4dd4b3ee6a00b7729d7c..4c7ffe69e933de3d52c8f762a1eeb73de17e0561 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -16,13 +16,27 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/inference/analysis/dot.h" +#include "paddle/fluid/string/printf.h" namespace paddle { namespace framework { namespace ir { -static const char kGraphVizPath[] = "graph_viz_path"; using inference::analysis::Dot; +namespace { +const char kGraphVizPath[] = "graph_viz_path"; + +std::string FormatName(const Node* node) { + if (!node->IsOp() || !node->Op() || + !node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) { + return node->Name(); + } + const std::string full_scope = boost::get( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())); + return string::Sprintf("%s%s", full_scope.c_str(), node->Name().c_str()); +} +} // namespace std::unique_ptr GraphVizPass::ApplyImpl( std::unique_ptr graph) const { @@ -54,7 +68,7 @@ std::unique_ptr GraphVizPass::ApplyImpl( auto marked_nodes = ConsumeMarkedNodes(graph.get()); // Create nodes for (const Node* n : graph->Nodes()) { - std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")"; + std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")"; if (n->IsOp()) { decltype(op_attrs) attr = marked_nodes.count(n) ? marked_op_attrs : op_attrs; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 79ec70a1039b8719e7f7e6845f3bb083372ccfa9..d53d789d3ad27b8f9606a396264d91e5f07a9d10 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -55,11 +55,11 @@ class Node { std::string Name() const { return name_; } VarDesc* Var() { - PADDLE_ENFORCE(type_ == Type::kVariable); + PADDLE_ENFORCE(IsVar()); return var_desc_.get(); } - OpDesc* Op() { + OpDesc* Op() const { PADDLE_ENFORCE(IsOp()); return op_desc_.get(); } diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index 9bb5c232e5c2269643ddef7ed9c938e0332f7274..a776a898a5ee13b4dde12460dce71433268fb9d4 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { std::unique_ptr SeqConcatFcFusePass::ApplyImpl( std::unique_ptr graph) const { - FusePassBase::Init(graph.get()); + FusePassBase::Init("seq_concat_fc_fuse", graph.get()); GraphPatternDetector detector; auto* pattern = detector.mutable_pattern(); auto* concat_out = BuildSeqExpandConcatPattern(pattern); BuildFCPattern(pattern, concat_out); -#define GET_NODE(id, pattern) \ - PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \ - "pattern has no Node called %s", #id); \ - auto* id = subgraph.at(pattern.RetriveNode(#id)); \ +#define GET_NODE(id, pattern) \ + PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \ + "pattern has no Node called %s", #id); \ + auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 2288c7fe6609a765612b468d69ad35101b92b384..4fa047bf3ee3d06ac4aec5d2cc6a355965836d42 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -129,6 +129,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, "Optimized for variable") .SetDefault({}); + AddAttr(OpNamescopeAttrName(), "Operator name with namesope.") + .SetDefault(""); + Validate(); } diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 80970291c9c234f1306162f4ffa3c2528f88c35f..18827385ad659922230ff68709a2926a8c9013ac 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker { public: static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleVarAttrName() { return "op_role_var"; } + static const char *OpNamescopeAttrName() { return "op_namescope"; } void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 08a55a73e3a318cd8cfe25c64ad2ff6955b7e445..e6e63544ffa2de09e39b02769aaaf0793d6b1111 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager { void AddGraphvizDebugerPass(Pass* pass) { auto* debuger_pass = pass->CreateGraphvizDebugerPass(); if (debuger_pass) { - LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]"; Register(debuger_pass->repr(), debuger_pass); } } diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index cccd6b55ad493f9cb0eeedeab02c1a3970a55fb5..1a65e85dd237eb1bacd3c15b4538a9835ec4b9e0 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -16,10 +16,13 @@ #include #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/inference/analysis/ut_helper.h" +#include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/platform/profiler.h" DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); @@ -31,6 +34,8 @@ namespace paddle { namespace inference { namespace analysis { +using namespace framework; + TEST(Analyzer, analysis_without_tensorrt) { FLAGS_IA_enable_tensorrt_subgraph_engine = false; Argument argument; @@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path, EXPECT_NEAR(data[i], base_data[i], 1e-3); } } + + if (use_analysis && activate_ir) { + AnalysisPredictor *analysis_predictor = + dynamic_cast(predictor.get()); + auto &fuse_statis = analysis_predictor->analysis_argument() + .Get>( + framework::ir::kFuseStatisAttr); + for (auto &item : fuse_statis) { + LOG(INFO) << "fused " << item.first << " " << item.second; + } + + ASSERT_TRUE(fuse_statis.count("fc")); + EXPECT_EQ(fuse_statis.at("fc"), 1); + } } // Directly infer with the original model. diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 4401d5c5a3ca8da1c04336de4be8397334d46d9e..3a4ffe967e67ab0487192bbf12d4d5a15f536aa3 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -64,7 +64,8 @@ struct Argument { template void Set(const std::string& key, T* data) { PADDLE_ENFORCE_NOT_NULL(data); - PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key); + PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]", + key); attrs_[key] = data; attr_deleters_[key] = [data, key, this]() { VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 8ca402da31f52f1a68a04b5de368c9c659a3a108..80c85555e722433f3657e880520b3fe459f6ce1a 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/inference/analysis/analyzer.h" @@ -34,7 +35,6 @@ std::vector ExtractParameters( bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { ANALYSIS_ARGUMENT_CHECK_FIELD(argument) ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) - PADDLE_ENFORCE(!argument->transformed_program_desc); // The transformed_program_desc should inherit all the VarDesc and BlockDesc // from the original program desc. The operators of the main block(the first // block) should rewritten by data flow graph. @@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { } } - if (argument_->Has("param_scope")) { + if (argument_->Has(framework::ir::kParamScopeAttr)) { LOG(WARNING) << "parameter changes in the scope takes effect"; } diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc index 5e53fff39213b53bc78e9272a7efd26d7ee91023..fc60ca3bd0bf706407defb2655a093d999aef7c2 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" @@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir, const std::string &prog_file, const std::string ¶m_file) { PADDLE_ENFORCE(argument_); - argument_->Set("param_scope", new framework::Scope); + argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope); // Load parameters. VLOG(3) << "Loading parameters from " << model_dir; - LoadParams(&argument_->Get("param_scope"), model_dir, - prog_file, param_file); + LoadParams(&argument_->Get(framework::ir::kParamScopeAttr), + model_dir, prog_file, param_file); } bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir, diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h index 29008105f82989f5797116e78990853880708936..6731b1f759363eec5dd8645783212a72ace67b2f 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h @@ -14,12 +14,14 @@ #pragma once +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/pass.h" namespace paddle { namespace inference { namespace analysis { +using namespace framework; static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__"; @@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass { ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path); // Load program. auto program = LoadProgramDesc(*argument->fluid_model_program_path); - argument->origin_program_desc.reset( - new framework::proto::ProgramDesc(program)); + argument->origin_program_desc.reset(new proto::ProgramDesc(program)); // Create main data flow graph. if (!argument->main_dfg) { argument->main_dfg.reset(new DataFlowGraph); } - argument->Set("ir_program_desc", new framework::ProgramDesc(program)); + argument->Set("ir_program_desc", new ProgramDesc(program)); LOG(INFO) << "Loading parameters"; // Load parameters to argument if needed. @@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass { void Run(DataFlowGraph *graph) override { // Call all the IR Passes - IRPassManager ir_passes( - argument_->Get("ir_program_desc"), nullptr); + IRPassManager ir_passes(argument_->Get("ir_program_desc"), + nullptr); // Pass the scope from analysis to IR if needed. - if (argument_->Has("param_scope")) { + if (argument_->Has(ir::kParamScopeAttr)) { // Here the address is passed, attention that IR doesn't own the scope, so // the real scope in analysis should live during the IR phase. ir_passes.graph().Set( - "param_scope", new framework::Scope *( - &argument_->Get("param_scope"))); + ir::kParamScopeAttr, + new Scope *(&argument_->Get(ir::kParamScopeAttr))); } const auto &ir_passes_to_apply = @@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass { PADDLE_ENFORCE(argument_->main_dfg.get()); argument_->main_dfg->Build(ir_passes.graph()); + // inherit the arguments from ir. + if (ir_passes.graph().Has(ir::kFuseStatisAttr)) { + argument_->Set( + ir::kFuseStatisAttr, + new std::unordered_map( + ir_passes.graph().Get>( + ir::kFuseStatisAttr))); + } } void EnableParamModify(const std::string &model_dir, @@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass { private: // Load parameters from a single file or from a directory. - bool LoadParams(framework::Scope *scope, const std::string &dir, + bool LoadParams(Scope *scope, const std::string &dir, const std::string &prog_file, const std::string ¶m_file); private: diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 5da5241e49a2f7c8c0951e1a3c31784b8af65134..ea0f2241d7dbab8f79ec9349effbe96112748e34 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" @@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program, framework::Scope *scope) : program_(program) { graph_.reset(new framework::ir::Graph(program)); - if (scope) graph_->Set("param_scope", new framework::Scope *(scope)); + if (scope) + graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope)); } void IRPassManager::Apply(const std::vector &passes) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7cdaf4c6a232f2cdfdd4fa27797de632bbe9c560..33862232bdaae817b9ca72879605386c32ed3e8b 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -12,121 +12,96 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/inference/api/analysis_predictor.h" #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/inference/analysis/analyzer.h" -#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/utils/singleton.h" namespace paddle { -using inference::analysis::Argument; -using inference::Singleton; -using inference::analysis::Analyzer; -using framework::proto::ProgramDesc; - -/* This predictor is based on the original native predictor with IR and Analysis - * support. It will optimize IR and Parameters in the runtime. - * TODO(Superjomn) Replace the Navive predictor? - */ -class AnalysisPredictor : public NativePaddlePredictor { - public: - explicit AnalysisPredictor(const NativeConfig& config) - : NativePaddlePredictor(config), config_(config) {} - - bool Init(const std::shared_ptr& parent_scope) { - VLOG(3) << "Predictor::init()"; - if (config_.use_gpu) { - place_ = paddle::platform::CUDAPlace(config_.device); - } else { - place_ = paddle::platform::CPUPlace(); - } - PADDLE_ENFORCE(!parent_scope); - if (parent_scope) { - scope_ = parent_scope; - sub_scope_ = &(parent_scope->NewScope()); - } else { - paddle::framework::InitDevices(false); - scope_.reset(new paddle::framework::Scope()); - } - - executor_.reset(new paddle::framework::Executor(place_)); - - // Initialize the inference program - if (!config_.model_dir.empty()) { - // Parameters are saved in separate files sited in - // the specified `dirname`. - inference_program_ = paddle::inference::Load( - executor_.get(), scope_.get(), config_.model_dir); - } else if (!config_.prog_file.empty() && !config_.param_file.empty()) { - // All parameters are saved in a single file. - // The file names should be consistent with that used - // in Python API `fluid.io.save_inference_model`. - inference_program_ = paddle::inference::Load( - executor_.get(), scope_.get(), config_.prog_file, config_.param_file); - } else { - LOG(ERROR) << "fail to load inference model."; - return false; - } - - OptimizeInferenceProgram(); - ctx_ = executor_->Prepare(*inference_program_, 0); - - VLOG(5) << "to create variables"; - PADDLE_ENFORCE(scope_.get()); - executor_->CreateVariables(*inference_program_, - sub_scope_ ? sub_scope_ : scope_.get(), 0); - - // Get the feed_target_names and fetch_target_names - PrepareFeedFetch(); - return true; +bool AnalysisPredictor::Init( + const std::shared_ptr& parent_scope) { + VLOG(3) << "Predictor::init()"; + if (config_.use_gpu) { + place_ = paddle::platform::CUDAPlace(config_.device); + } else { + place_ = paddle::platform::CPUPlace(); } - - bool Run(const std::vector& inputs, - std::vector* output_data, - int batch_size = -1) override { - return NativePaddlePredictor::Run(inputs, output_data, batch_size); + PADDLE_ENFORCE(!parent_scope); + if (parent_scope) { + scope_ = parent_scope; + sub_scope_ = &(parent_scope->NewScope()); + } else { + paddle::framework::InitDevices(false); + scope_.reset(new paddle::framework::Scope()); } - void OptimizeInferenceProgram() { - LOG(INFO) << "optimize begin"; - FLAGS_IA_enable_ir = true; - FLAGS_IA_enable_tensorrt_subgraph_engine = false; - FLAGS_IA_output_storage_path = ""; // Don't output the model. - // Analyze inference_program - Argument argument; - if (!config_.model_dir.empty()) { - argument.fluid_model_dir.reset(new std::string(config_.model_dir)); - } else { - PADDLE_ENFORCE( - !config_.param_file.empty(), - "Either model_dir or (param_file, prog_file) should be set."); - PADDLE_ENFORCE(!config_.prog_file.empty()); - argument.fluid_model_program_path.reset( - new std::string(config_.prog_file)); - argument.fluid_model_param_path.reset( - new std::string(config_.param_file)); - } - argument.origin_program_desc.reset( - new ProgramDesc(*inference_program_->Proto())); - Singleton::Global().Run(&argument); - CHECK(argument.transformed_program_desc); - VLOG(5) << "to prepare executor"; - // LOG(INFO) << "transformed_parogram_desc " << - // argument.transformed_program_desc->DebugString(); - inference_program_.reset( - new framework::ProgramDesc(*argument.transformed_program_desc)); - PADDLE_ENFORCE(argument.Has("param_scope")); - // Update scope. - scope_.reset(argument.Release("param_scope")); - LOG(INFO) << "optimize end =="; + executor_.reset(new paddle::framework::Executor(place_)); + + // Initialize the inference program + if (!config_.model_dir.empty()) { + // Parameters are saved in separate files sited in + // the specified `dirname`. + inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(), + config_.model_dir); + } else if (!config_.prog_file.empty() && !config_.param_file.empty()) { + // All parameters are saved in a single file. + // The file names should be consistent with that used + // in Python API `fluid.io.save_inference_model`. + inference_program_ = paddle::inference::Load( + executor_.get(), scope_.get(), config_.prog_file, config_.param_file); + } else { + LOG(ERROR) << "fail to load inference model."; + return false; } - private: - NativeConfig config_; -}; + OptimizeInferenceProgram(); + ctx_ = executor_->Prepare(*inference_program_, 0); + + VLOG(5) << "to create variables"; + PADDLE_ENFORCE(scope_.get()); + executor_->CreateVariables(*inference_program_, + sub_scope_ ? sub_scope_ : scope_.get(), 0); + // Get the feed_target_names and fetch_target_names + PrepareFeedFetch(); + return true; +} + +void AnalysisPredictor::OptimizeInferenceProgram() { + LOG(INFO) << "optimize begin"; + FLAGS_IA_enable_ir = true; + FLAGS_IA_enable_tensorrt_subgraph_engine = false; + FLAGS_IA_output_storage_path = ""; // Don't output the model. + // Analyze inference_program + if (!config_.model_dir.empty()) { + argument_.fluid_model_dir.reset(new std::string(config_.model_dir)); + } else { + PADDLE_ENFORCE( + !config_.param_file.empty(), + "Either model_dir or (param_file, prog_file) should be set."); + PADDLE_ENFORCE(!config_.prog_file.empty()); + argument_.fluid_model_program_path.reset( + new std::string(config_.prog_file)); + argument_.fluid_model_param_path.reset(new std::string(config_.param_file)); + } + argument_.origin_program_desc.reset( + new ProgramDesc(*inference_program_->Proto())); + Analyzer().Run(&argument_); + CHECK(argument_.transformed_program_desc); + VLOG(5) << "to prepare executor"; + // LOG(INFO) << "transformed_parogram_desc " << + // argument.transformed_program_desc->DebugString(); + inference_program_.reset( + new framework::ProgramDesc(*argument_.transformed_program_desc)); + PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr)); + // Update scope. + scope_.reset( + argument_.Release(framework::ir::kParamScopeAttr)); + LOG(INFO) << "optimize end =="; +} template <> std::unique_ptr CreatePaddlePredictor< diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h new file mode 100644 index 0000000000000000000000000000000000000000..e32b6185f6044ab3577bde0a8f8dcf2391688aa8 --- /dev/null +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/analyzer.h" +#include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +namespace paddle { + +using inference::analysis::Argument; +using inference::analysis::Analyzer; +using framework::proto::ProgramDesc; + +/* This predictor is based on the original native predictor with IR and Analysis + * support. It will optimize IR and Parameters in the runtime. + * TODO(Superjomn) Replace the Navive predictor? + */ +class AnalysisPredictor : public NativePaddlePredictor { + public: + explicit AnalysisPredictor(const NativeConfig& config) + : NativePaddlePredictor(config), config_(config) {} + + bool Init(const std::shared_ptr& parent_scope); + + bool Run(const std::vector& inputs, + std::vector* output_data, + int batch_size = -1) override { + return NativePaddlePredictor::Run(inputs, output_data, batch_size); + } + + void OptimizeInferenceProgram(); + + Argument& analysis_argument() { return argument_; } + + private: + NativeConfig config_; + Argument argument_; +}; + +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 88e0383146c1adf2752a362091996bad9cfcce5e..b97dad20db0b003b4886b7c7cfd1c8de8bf44ab9 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -24,7 +24,7 @@ limitations under the License. */ #endif #ifdef PADDLE_WITH_MKLDNN -#include +#include "mkldnn.hpp" #endif #include diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index e4415ed15c791100a5b309e73d7deb5943f71b97..f577068d1f39a3083a54f106d006f9982304411e 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -43,6 +43,9 @@ void BindConstValue(pybind11::module* m) { op_proto_and_checker_maker.def( "kOpRoleVarAttrName", framework::OpProtoAndCheckerMaker::OpRoleVarAttrName); + op_proto_and_checker_maker.def( + "kOpNameScopeAttrName", + framework::OpProtoAndCheckerMaker::OpNamescopeAttrName); } } // namespace pybind diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fbe766336b19719b7a3eac41ad5e877ef3ec4181..b0e0d27ff7a0c603523065d34169b1b73eabdac3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -43,6 +43,7 @@ __all__ = [ 'default_main_program', 'program_guard', 'get_var', + 'name_scope', ] EMPTY_VAR_NAME = core.kEmptyVarName() @@ -52,6 +53,70 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix() CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() +class NameScope(object): + def __init__(self, name="", parent=None): + self._children = dict() + self._name = name + self._parent = parent + + def child(self, prefix): + if prefix not in self._children: + new_child = NameScope(prefix, self) + self._children[prefix] = [new_child] + else: + new_child = NameScope(prefix + "_%d" % len(self._children[prefix]), + self) + self._children[prefix].append(new_child) + return new_child + + def parent(self): + return self._parent + + def name(self): + return self._name + + +_name_scope = NameScope() + + +@contextlib.contextmanager +def name_scope(prefix=None): + """ + Generate hierarchical name prefix for the operators. + + Note: This should only used for debugging and visualization purpose. + Don't use it for serious analysis such as graph/program transformations. + + Args: + prefix(str): prefix. + + Examples: + .. code-block:: python + with name_scope("encoder"): + ... + with name_scope("decoder"): + ... + with name_scope("attention"): + ... + """ + # TODO(panyx0718): Only [0-9a-z]. + assert prefix, "namescope prefix cannot be empty." + global _name_scope + _name_scope = _name_scope.child(prefix) + yield + _name_scope = _name_scope.parent() + + +def _full_name_scope(): + global _name_scope + scope = _name_scope + name = "" + while scope: + name = scope.name() + "/" + name + scope = scope.parent() + return name + + def generate_control_dev_var_name(): import random return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random()) @@ -515,6 +580,9 @@ class Operator(object): self.desc.set_type(type) proto = OpProtoHolder.instance().get_op_proto(type) + namescope_var_name = op_maker.kOpNameScopeAttrName() + op_attrs[namescope_var_name] = _full_name_scope() + def find_name(var_list, name): for var_name in var_list: if var_list[var_name] is not None and var_name == name: diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 6b9749a5799ecc0b26babbf088614d6b5de2a5dd..33d6311b9717c66f0d6782eb6b3e348cd4c02a69 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -15,7 +15,7 @@ from __future__ import print_function import re from collections import defaultdict -from paddle.fluid.framework import Program, Variable +from paddle.fluid.framework import Program, Variable, name_scope from . import framework from . import layers from .backward import append_backward @@ -237,7 +237,7 @@ class Optimizer(object): if param_and_grad[1] is None: continue with param_and_grad[0].block.program.optimized_guard( - param_and_grad): + param_and_grad), name_scope("optimizer"): if param_and_grad[0].trainable is True: optimize_op = self._append_optimize_op(loss.block, param_and_grad) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index b9387ae9d83f36a491414764619b86e39368d266..58875a1dd19fd91f6f2bed928397ee7f73302dff 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -82,8 +82,18 @@ class TestDistRunnerBase(object): strategy = fluid.ExecutionStrategy() strategy.num_threads = 1 strategy.allow_op_delay = False + build_stra = fluid.BuildStrategy() + + if args.use_reduce: + build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce + else: + build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce + exe = fluid.ParallelExecutor( - True, loss_name=avg_cost.name, exec_strategy=strategy) + True, + loss_name=avg_cost.name, + exec_strategy=strategy, + build_strategy=build_stra) feed_var_list = [ var for var in trainer_prog.global_block().vars.values() @@ -123,6 +133,7 @@ def runtime_main(test_class): '--current_endpoint', type=str, required=False, default="") parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--mem_opt', action='store_true') + parser.add_argument('--use_reduce', action='store_true') args = parser.parse_args() @@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase): self._python_interp = "python" self._sync_mode = True self._mem_opt = False + self._use_reduce = False self._setup_config() def start_pserver(self, model_file, check_error_log): - ps0_ep, ps1_ep = self._ps_endpoints.split(",") - ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s" - sync_mode_str = "--sync_mode" if self._sync_mode else "" - mem_opt_str = "--mem_opt" if self._mem_opt else "" + ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist" ps0_cmd = ps_cmd % \ (self._python_interp, model_file, self._ps_endpoints, ps0_ep, - self._trainers, sync_mode_str, mem_opt_str) + self._trainers) ps1_cmd = ps_cmd % \ (self._python_interp, model_file, self._ps_endpoints, ps1_ep, - self._trainers, sync_mode_str, mem_opt_str) + self._trainers) + + if self._sync_mode: + ps0_cmd += " --sync_mode" + ps1_cmd += " --sync_mode" + if self._mem_opt: + ps0_cmd += " --mem_opt" + ps1_cmd += " --mem_opt" ps0_pipe = subprocess.PIPE ps1_pipe = subprocess.PIPE @@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase): self._wait_ps_ready(ps1.pid) ps0_ep, ps1_ep = self._ps_endpoints.split(",") - tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s" - sync_mode_str = "--sync_mode" if self._sync_mode else "" - mem_opt_str = "--mem_opt" if self._mem_opt else "" + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist" tr0_cmd = tr_cmd % \ (self._python_interp, model_file, self._ps_endpoints, - 0, ps0_ep, - self._trainers, sync_mode_str, mem_opt_str) + 0, ps0_ep, self._trainers) tr1_cmd = tr_cmd % \ (self._python_interp, model_file, self._ps_endpoints, - 1, ps1_ep, - self._trainers, sync_mode_str, mem_opt_str) + 1, ps1_ep, self._trainers) + + if self._sync_mode: + tr0_cmd += " --sync_mode" + tr1_cmd += " --sync_mode" + if self._mem_opt: + tr0_cmd += " --mem_opt" + tr1_cmd += " --mem_opt" + if self._use_reduce: + tr0_cmd += " --use_reduce" + tr1_cmd += " --use_reduce" env0 = {"CUDA_VISIBLE_DEVICES": "0"} env1 = {"CUDA_VISIBLE_DEVICES": "1"} @@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase): # FIXME: use terminate() instead of sigkill. os.kill(ps0.pid, signal.SIGKILL) os.kill(ps1.pid, signal.SIGKILL) + ps0.terminate() + ps1.terminate() ps0.wait() ps1.wait() FNULL.close() diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 157243df47189bddd494e5d533fdc34a28100c57..59a137c18c9435ef5c5772d0cc08f197c1d86603 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -20,6 +20,7 @@ from test_dist_base import TestDistBase class TestDistMnist2x2(TestDistBase): def _setup_config(self): self._sync_mode = True + self._use_reduce = False def test_se_resnext(self): self.check_with_place("dist_mnist.py", delta=1e-7) @@ -37,10 +38,30 @@ class TestDistMnist2x2WithMemopt(TestDistBase): class TestDistMnistAsync(TestDistBase): def _setup_config(self): self._sync_mode = False + self._use_reduce = False def test_se_resnext(self): self.check_with_place("dist_mnist.py", delta=200) +# FIXME(typhoonzero): enable these tests once we have 4 +# 4 GPUs on CI machine, and the base class should be updated. +# +# class TestDistMnist2x2ReduceMode(TestDistBase): +# def _setup_config(self): +# self._sync_mode = True +# self._use_reduce = True + +# def test_se_resnext(self): +# self.check_with_place("dist_mnist.py", delta=1e-7) + +# class TestDistMnistAsyncReduceMode(TestDistBase): +# def _setup_config(self): +# self._sync_mode = False +# self._use_reduce = True + +# def test_se_resnext(self): +# self.check_with_place("dist_mnist.py", delta=200) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_name_scope.py b/python/paddle/fluid/tests/unittests/test_name_scope.py new file mode 100644 index 0000000000000000000000000000000000000000..08c802e20d2bb364ef7f116ee0042a2ad21a9b2b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_name_scope.py @@ -0,0 +1,45 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + + +class TestNameScope(unittest.TestCase): + def test_name_scope(self): + with fluid.name_scope("s1"): + a = fluid.layers.data(name='data', shape=[1], dtype='int32') + b = a + 1 + with fluid.name_scope("s2"): + c = b * 1 + with fluid.name_scope("s3"): + d = c / 1 + with fluid.name_scope("s1"): + f = fluid.layers.pow(d, 2.0) + with fluid.name_scope("s4"): + g = f - 1 + + for op in fluid.default_main_program().block(0).ops: + if op.type == 'elementwise_add': + self.assertEqual(op.desc.attr("op_namescope"), '/s1/') + elif op.type == 'elementwise_mul': + self.assertEqual(op.desc.attr("op_namescope"), '/s1/s2/') + elif op.type == 'elementwise_div': + self.assertEqual(op.desc.attr("op_namescope"), '/s1/s3/') + elif op.type == 'elementwise_sub': + self.assertEqual(op.desc.attr("op_namescope"), '/s4/') + elif op.type == 'pow': + self.assertEqual(op.desc.attr("op_namescope"), '/s1_1/') diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 6d01955993324498de42462b7f85ef6f8e444505..cac132e6e08a8a9ec595236b1a990c0900ea4f0f 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual( set(mul_op.attr_names), - set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"])) + set([ + "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var", + "op_namescope" + ])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py index 5b96d641d667eee1aa0c7c6019bf92494f777259..af3745987aa3eae96968bdc6b5c9cd951e9ca6fa 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py @@ -67,18 +67,20 @@ def fc_with_batchnorm(use_feed): hidden = img for _ in range(1): - hidden = fluid.layers.fc( - hidden, - size=200, - act='tanh', - bias_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant(value=1.0))) - - hidden = fluid.layers.batch_norm(input=hidden) - - prediction = fluid.layers.fc(hidden, size=10, act='softmax') - loss = fluid.layers.cross_entropy(input=prediction, label=label) - loss = fluid.layers.mean(loss) + with fluid.name_scope("hidden"): + hidden = fluid.layers.fc( + hidden, + size=200, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + hidden = fluid.layers.batch_norm(input=hidden) + with fluid.name_scope("fc_layer"): + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + with fluid.name_scope("loss"): + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) return loss diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 4eb87b6a77e998a2d70ed6ebfb9df90c96a8dc09..a6266a7b0c9ac40eac7b2823fc7ddf38f55357a9 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -273,6 +273,10 @@ class DistributeTranspiler(object): name=framework.generate_control_dev_var_name()) grad_name_to_send_dummy_out[grad_varname] = dummy_output + # get send op_role_var, if not splited, the grad should have .trainer suffix + # if splited, grad should be the original grad var name (split_by_ref and send + # will be on the same place). ParallelExecutor + # will use op_role_var to get expected device place to run this op. program.global_block()._insert_op( index=index + 1, type="send", @@ -281,8 +285,10 @@ class DistributeTranspiler(object): attrs={ "epmap": eplist, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, - OP_ROLE_VAR_ATTR_NAME: - [self.grad_name_to_param_name[grad_varname], grad_varname], + OP_ROLE_VAR_ATTR_NAME: [ + self.grad_name_to_param_name[grad_varname], + splited_grad_varname + ], "sync_mode": not self.sync_mode, }) for _, var in enumerate(splited_vars): @@ -326,6 +332,15 @@ class DistributeTranspiler(object): recv_dep_in = grad_name_to_send_dummy_out[ self.param_name_to_grad_name[param_varname]] all_recv_outputs.extend(splited_var) + # get recv op_role_var, if not splited, the grad should have .trainer suffix + # if splited, grad should be the original grad var name. ParallelExecutor + # will use op_role_var to get expected device place to run this op. + orig_grad_name = self.param_name_to_grad_name[param_varname] + recv_op_role_var_name = orig_grad_name + splited_trainer_grad = self.grad_var_mapping[orig_grad_name] + if len(splited_trainer_grad) == 1: + recv_op_role_var_name = splited_trainer_grad[0].name + program.global_block().append_op( type="recv", inputs={"X": [recv_dep_in]}, @@ -333,10 +348,8 @@ class DistributeTranspiler(object): attrs={ "epmap": eps, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, - OP_ROLE_VAR_ATTR_NAME: [ - param_varname, - self.param_name_to_grad_name[param_varname] - ], + OP_ROLE_VAR_ATTR_NAME: + [param_varname, recv_op_role_var_name], "sync_mode": not self.sync_mode })