diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index a33b5a9c9312c93247a1e1f3431061a5aad6c884..c29337cba1fe859e4968cb800e4e7d9ff6a54d31 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -64,6 +64,41 @@ can also contain other things that describe some properties of the `Graph` or `Graph` nodes. `Attribute` can be passed across `Pass`. However, it should be used with care. +```cpp +class Graph { + public: + explicit Graph(const ProgramDesc &program); + + bool Has(const std::string &attr_name) const; + + template + AttrType &Get(const std::string &attr_name) const; + + template + void Set(const std::string &attr_name, AttrType *attr); + const std::unordered_set &Nodes() const; + + // Create a normal variable with non-null VarDesc. + ir::Node *CreateVarNode(VarDesc *var_desc); + + // Create a normal runnable operator with OpDesc. + ir::Node *CreateOpNode(OpDesc *op_desc); + + // Create a control dependency var that connects 2 operations. The + // var doesn't hold any data. Other than that, it's no different from + // other var, considering dependency analysis. + ir::Node *CreateControlDepVar(); + + // A more free style way of creating a graph node. Mostly use for test + // or "copy" from another node. Avoid using it if possible. + ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type); + + // Clear all node information of the graph and return the ownership of the + // nodes. + std::vector> ReleaseNodes(); +}; +``` + #### Pass `Pass` represents a transformation of `Graph`. Its input @@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example, a `Pass` can simply print out the `Graph`. A `Pass` can also fuse some `Graph`'s `Node`s. +```cpp +class Pass { + public: + + std::unique_ptr Apply(std::unique_ptr graph) const { + // Some correctness check. + auto new_graph = ApplyImpl(std::move(graph)); + // Some correctness check. + return new_graph; + } + + // Get a reference to the attributed previously set. + template + AttrType &Get(const std::string &attr_name) const; + + // Set a pointer to the attribute. Pass takes ownership of the attribute. + template + void Set(const std::string &attr_name, AttrType *attr) ; + + // Set a pointer to the attribute. Pass doesn't take ownership. Caller + // should delete the attribute. + template + void SetNotOwned(const std::string &attr_name, AttrType *attr); + + protected: + virtual std::unique_ptr ApplyImpl(std::unique_ptr graph) const = 0; +}; + +// In my_pass.cc +class MyPass : public Pass { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override { + // do something. + return graph; + } +} +REGISTER_PASS(my_pass, MyPass) +.RequirePassAttr("places") +.RequireGraphAttr("dep_vars"); + + +// To use the pass. +auto my_pass = ir::PassRegistry::Instance().Get("my_pass"); +graph = my_pass->Apply(std::move(graph)); +// Note: to force link my_pass.cc, in the code: +USE_PASS(my_pass); +``` + #### Optimize `Optimize` contains a series of `Pass` with defined order. @@ -86,4 +169,17 @@ maintaining the original modeling logic. * Graph is transformed from raw model logic to a form that is efficient to execute. -Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor +``` +// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor +auto graph = Graph(program); +graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah)); +// For more complex Pass, Optimize Process can provide Pass attributes. +auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass"); +mem_opt_pass.SetNotOwned("optimize_level", 1); +mem_opt_pass->Apply(std::move(graph)); +graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah)); +graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah)); +Executor exe; +exe.Run(graph); + +``` diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6efb03dabe89b28f3ff1a55c4a940dfe74e8001d..5f3bfa296546fcbc6a3410d7ae072ff74954bc74 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -170,6 +170,7 @@ paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], var paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) @@ -201,7 +202,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs= paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.While.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) @@ -225,17 +225,14 @@ paddle.fluid.layers.DynamicRNN.static_input ArgSpec(args=['self', 'x'], varargs= paddle.fluid.layers.DynamicRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.DynamicRNN.update_memory ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.StaticRNN.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.memory ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)) paddle.fluid.layers.StaticRNN.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None) -paddle.fluid.layers.StaticRNN.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.step ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.step_output ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.update_memory ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.reorder_lod_tensor_by_rank ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.__init__ ArgSpec(args=['self', 'places', 'use_nccl', 'name'], varargs=None, keywords=None, defaults=(False, None)) -paddle.fluid.layers.ParallelDo.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.do ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.get_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index df2a7bf90d9be480c514d9dc70571c7f56fd8db2..139411f3e0d945f9265d19a28487c05d06722d69 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -99,7 +99,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 9df7df1f42886d40210b16aa2ae5823e3310bfe7..5d652d37307d0a55ffee14930ae180dcd3e27841 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) - -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index b7b67916205689753bc3f9fe844945ee3e78eeb4..5ca2ed8f96244a11925dfa6af8e48458cf334ecd 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -34,30 +34,22 @@ namespace paddle { namespace framework { namespace details { +static const char kLossVarName[] = "loss_var_name"; +static const char kPlaces[] = "places"; +static const char kParams[] = "params"; +static const char kLocalScopes[] = "local_scopes"; +static const char kStrategy[] = "strategy"; + +void MultiDevSSAGraphBuilder::Init() const { + loss_var_name_ = Get(kLossVarName); + places_ = Get>(kPlaces); + local_scopes_ = Get>(kLocalScopes); + strategy_ = Get(kStrategy); #ifdef PADDLE_WITH_CUDA -MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( - const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - nccl_ctxs_(nccl_ctxs), - strategy_(strategy) { -#else -MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( - const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, const BuildStrategy &strategy) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - strategy_(strategy) { + nccl_ctxs_ = &Get("nccl_ctxs"); #endif - for (auto &p : params) { + + for (auto &p : Get>(kParams)) { grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); @@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ir::Node *node, size_t place_id) const { auto p = places_[place_id]; - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); @@ -239,8 +231,9 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { return sorted_ret; } -std::unique_ptr MultiDevSSAGraphBuilder::Apply( +std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( std::unique_ptr graph) const { + Init(); // Give the topology sort order and rebuild the graph structure. std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); auto nodes = graph->ReleaseNodes(); @@ -254,9 +247,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 - result.Set("vars", new GraphVars(places_.size())); - result.Set("dep_vars", new GraphDepVars); - result.Set("ops", new GraphOps); + result.Set(kGraphVars, new GraphVars(places_.size())); + result.Set(kGraphDepVars, new GraphDepVars); + result.Set(kGraphOps, new GraphOps); + result.Set(kShardedVarDevice, new ShardedVarDevice); // find send/recv vars so that we can place the distributed training // related op in the place 0 @@ -289,11 +283,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(node); + int op_dev_id = GetOpDeviceID(result, node); if (op_dev_id != -1) { // This op only runs on one specific device. CreateComputationalOp(&result, node, op_dev_id); for (ir::Node *n : node->outputs) { - var_name_on_devices_.emplace(n->Name(), op_dev_id); + graph->Get(kShardedVarDevice) + .emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's @@ -330,7 +325,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( case BuildStrategy::ReduceStrategy::kReduce: cur_device_id = GetAppropriateDeviceID({g_name}); CreateReduceOp(&result, g_name, cur_device_id); - var_name_on_devices_.emplace(g_name, cur_device_id); + graph->Get(kShardedVarDevice) + .emplace(g_name, cur_device_id); bcast_var_name_set[cur_device_id].emplace(p_name); break; case BuildStrategy::ReduceStrategy::kAllReduce: @@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), local_scopes_, places_); #endif - result->Get("ops").emplace_back(op_handle); + result->Get(kGraphOps).emplace_back(op_handle); auto *in = - result->Get("vars").at(src_dev_id).at(p_name).back().get(); + result->Get(kGraphVars).at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->Get("vars").at(i).at(p_name); + auto &vars = result->Get(kGraphVars).at(i).at(p_name); auto *out_var = new VarHandle( result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), i, p_name, p); @@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const { - result->Get("ops").emplace_back( + result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); @@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new AllReduceOpHandle( + result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back(new AllReduceOpHandle( + result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), local_scopes_, places_)); #endif - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->Get("vars")[i][og]; + auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); @@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ir::Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), local_scopes_, places_)); #endif - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { - auto &vars = result->Get("vars")[i][d_name]; + auto &vars = result->Get(kGraphVars)[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); auto var = new VarHandle( @@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( return is_pg_once; } -int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { +int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, + ir::Node *node) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } @@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); - int dev_id = GetVarDeviceID(param_grad[1]); + int dev_id = GetVarDeviceID(graph, param_grad[1]); PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", node->Op()->Type(), param_grad[0], param_grad[1]); return dev_id; } -int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { - auto got = var_name_on_devices_.find(varname); - return got == var_name_on_devices_.end() ? -1 : got->second; +int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, + const std::string &varname) const { + auto &sharded_var_device = graph.Get(kShardedVarDevice); + auto got = sharded_var_device.find(varname); + return got == sharded_var_device.end() ? -1 : got->second; } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { @@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); - result->Get("ops").emplace_back(op_handle); + result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. @@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get("ops").emplace_back( + result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } @@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new ReduceOpHandle( + result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back(new ReduceOpHandle( + result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_)); #endif - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->Get("vars")[i][og]; + auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } - auto &vars = result->Get("vars")[dst_dev_id][og]; + auto &vars = result->Get(kGraphVars)[dst_dev_id][og]; auto var = new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(), dst_dev_id, og, places_[dst_dev_id]); @@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, // on it. void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { - for (auto &prev_op : result->Get("ops")) { + for (auto &prev_op : result->Get(kGraphOps)) { if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(result->CreateControlDepVar()); prev_op->AddOutput(dep_var); - result->Get("dep_vars").emplace(dep_var); + result->Get(kGraphDepVars).emplace(dep_var); op->AddInput(dep_var); } } @@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, if (node->Op()->Type() == "split_byref" || node->Op()->Type() == "split_selected_rows") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(input_var_names[0]); + op_dev_id = GetVarDeviceID(*result, input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } else if (node->Op()->Type() == "concat") { - op_dev_id = GetVarDeviceID(input_var_names[0]); + op_dev_id = GetVarDeviceID(*result, input_var_names[0]); for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } else { PADDLE_ENFORCE( @@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, CreateComputationalOp(result, node, op_dev_id); if (node->Op()->Type() == "concat") { - ConnectOp(result, result->Get("ops").back().get(), + ConnectOp(result, result->Get(kGraphOps).back().get(), "fetch_barrier"); } } @@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); + op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by @@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } } else if (node->Op()->Type() == "recv") { @@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } op_dev_id = GetAppropriateDeviceID(output_var_names); for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } else { // send_barrier and fetch_barrier op can be scheduled on device 0 @@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", node->Op()->Type()); - result->Get("ops").emplace_back(new RPCOpHandle( + result->Get(kGraphOps).emplace_back(new RPCOpHandle( result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], node->Op()->Type(), places_[op_dev_id])); // TODO(panyx0718): This might not be needed anymore. if (node->Op()->Type() == "send_barrier") { - ConnectOp(result, result->Get("ops").back().get(), "send"); + ConnectOp(result, result->Get(kGraphOps).back().get(), "send"); } else if (node->Op()->Type() == "recv") { - ConnectOp(result, result->Get("ops").back().get(), + ConnectOp(result, result->Get(kGraphOps).back().get(), "send_barrier"); } else if (node->Op()->Type() == "fetch_barrier") { - ConnectOp(result, result->Get("ops").back().get(), "recv"); + ConnectOp(result, result->Get(kGraphOps).back().get(), "recv"); } else if (node->Op()->Type() == "send") { // do nothing } else { @@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_pass, + paddle::framework::details::MultiDevSSAGraphBuilder) + .RequirePassAttr(paddle::framework::details::kLossVarName) + .RequirePassAttr(paddle::framework::details::kPlaces) + .RequirePassAttr(paddle::framework::details::kParams) + .RequirePassAttr(paddle::framework::details::kLocalScopes) + .RequirePassAttr(paddle::framework::details::kStrategy); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 55076f227b5ab56d66b5053173c9e915da23b15f..099dbe5abef6458c4613c9f680440734f59cb6e2 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -31,39 +31,27 @@ class Scope; namespace details { class MultiDevSSAGraphBuilder : public SSAGraphBuilder { - public: -#ifdef PADDLE_WITH_CUDA - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs, - const BuildStrategy &strategy); -#else - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - const BuildStrategy &strategy); -#endif - std::unique_ptr Apply( + protected: + std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; - int GetVarDeviceID(const std::string &varname) const override; private: void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, size_t device_id) const; + void Init() const; private: - std::string loss_var_name_; - const std::vector &places_; - const std::vector &local_scopes_; - std::unordered_set grad_names_; + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; #ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap *nccl_ctxs_; + mutable platform::NCCLContextMap *nccl_ctxs_; #endif + int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; + bool IsScaleLossOp(ir::Node *node) const; void CreateRPCOp(ir::Graph *result, ir::Node *node) const; @@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID(ir::Node *node) const; + int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; @@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &var_names) const; private: - BuildStrategy strategy_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; - mutable std::unordered_map var_name_on_devices_; mutable std::vector balance_vars_; void SetCommunicationContext(OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index cbfbcb1c0cd24f16773f9633310166371600790c..1b188aec5995edb73835bcf5b851952db0f95f48 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ExecutionStrategy strategy, std::vector local_scopes, std::vector var_infos, std::vector places, std::unique_ptr&& underlying_executor); + + const ir::Graph& Graph() const { return underlying_executor_->Graph(); } + FeedFetchList Run(const std::vector& fetch_tensors) override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 506e7eb35cd977869424223cb863dd64dbaa9d30..575532540a624afde5f6dab25b11e9eac93c6448 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -18,7 +18,7 @@ namespace paddle { namespace framework { namespace details { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { - for (auto &var_map : graph->Get("vars")) { + for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { if (name_pair.second.size() <= 1) { continue; @@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); - graph->Get("dep_vars").emplace(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); } } } @@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { - auto &var_holders = graph->Get("vars")[place_offset]; + auto &var_holders = graph->Get(kGraphVars)[place_offset]; auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { @@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ir::Node *new_node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][new_node->Name()]; + auto &vars = + graph->Get(kGraphVars)[place_offset][new_node->Name()]; size_t version = vars.size(); auto var = new VarHandle(new_node, version, place_offset, new_node->Name(), place); @@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, } void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { - for (auto &op : graph->Get("ops")) { + for (auto &op : graph->Get(kGraphOps)) { if (!op->Outputs().empty()) { continue; } auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get("dep_vars").emplace(dummy_leaf); + graph->Get(kGraphDepVars).emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 2b4f31f2ff3444f909e3be5eb810ae6737e237b2..53a4ad003d51a27a044d7a142434545eca0d5965 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -39,21 +39,25 @@ namespace details { typedef std::vector< std::unordered_map>>> GraphVars; +const char kGraphVars[] = "vars"; // aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set> GraphDepVars; +const char kGraphDepVars[] = "dep_vars"; // all operators. NOTE that even we use a vector here, the operators is // unordered. typedef std::vector> GraphOps; +const char kGraphOps[] = "ops"; + +typedef std::unordered_map ShardedVarDevice; +const char kShardedVarDevice[] = "sharded_var_device"; class SSAGraphBuilder : public ir::Pass { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual int GetVarDeviceID(const std::string &var_name) const = 0; - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc deleted file mode 100644 index b4b49d3de6da2e5fd7836668619e42d10bb6b35a..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" -#include -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" -#include "paddle/fluid/framework/details/ssa_graph_checker.h" -#include "paddle/fluid/framework/details/ssa_graph_printer.h" - -namespace paddle { -namespace framework { -namespace details { -std::unique_ptr SSAGraphBuilderFactory::Create() { - std::unique_ptr res( -#ifdef PADDLE_WITH_CUDA - new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, - local_scopes_, nccl_ctxs_, strategy_) -#else - new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, - local_scopes_, strategy_) -#endif - ); // NOLINT - - if (!strategy_.debug_graphviz_path_.empty()) { - std::unique_ptr fout( - new std::ofstream(strategy_.debug_graphviz_path_)); - PADDLE_ENFORCE(fout->good()); - std::unique_ptr graphviz_printer( - new GraphvizSSAGraphPrinter()); - res.reset(new SSAGraghBuilderWithPrinter( - std::move(fout), std::move(graphviz_printer), std::move(res))); - } - res.reset(new SSAGraghBuilderWithChecker(std::move(res))); - - return res; -} -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h deleted file mode 100644 index 91a119de83ed3d1573803e48faf86c874eed98d6..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include "paddle/fluid/framework/details/build_strategy.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" -#include "paddle/fluid/platform/place.h" - -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/nccl_helper.h" -#endif - -namespace paddle { -namespace framework { -class Scope; -namespace details { - -class SSAGraphBuilderFactory { - public: - SSAGraphBuilderFactory(const std::vector& places, - const std::string& loss_var_name, - const std::unordered_set& param_names, - const std::vector& local_scopes, - const BuildStrategy& strategy) - : places_(places), - loss_var_name_(loss_var_name), - param_names_(param_names), - local_scopes_(local_scopes), - strategy_(strategy) { -#ifdef PADDLE_WITH_CUDA - nccl_ctxs_ = nullptr; -#endif - } - -#ifdef PADDLE_WITH_CUDA - void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) { - nccl_ctxs_ = nccl_ctxs; - } -#endif - - std::unique_ptr Create(); - - private: - std::vector places_; - std::string loss_var_name_; - std::unordered_set param_names_; - std::vector local_scopes_; - BuildStrategy strategy_; - -#ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap* nccl_ctxs_; -#endif -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index 0438b096109a287366610d06ef2bd14c765a8f43..b9e1cda1f24810009bc74a7abdf0156f723a1755 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } }; - for (auto &var_map : graph->Get("vars")) { + for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { insert_pending_var(version_pair.get()); @@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } } - for (auto &var : graph->Get("dep_vars")) { + for (auto &var : graph->Get(kGraphDepVars)) { insert_pending_var(var.get()); } - for (auto &op : graph->Get("ops")) { + for (auto &op : graph->Get(kGraphOps)) { if (op->Inputs().empty()) { ready_ops.insert(op.get()); } else { @@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_check_pass, + paddle::framework::details::SSAGraghBuilderWithChecker) + .RequireGraphAttr(paddle::framework::details::kGraphVars) + .RequireGraphAttr(paddle::framework::details::kGraphDepVars) + .RequireGraphAttr(paddle::framework::details::kGraphOps) + .RequireGraphAttr(paddle::framework::details::kShardedVarDevice); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 51ce6e5ecad755613551aa6525b5cfbe4a8933ef..0e861ecb236361992d9883e3bd0e679f7563b539 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -23,26 +23,14 @@ namespace framework { namespace details { class SSAGraghBuilderWithChecker : public SSAGraphBuilder { - public: - explicit SSAGraghBuilderWithChecker( - std::unique_ptr&& builder) - : builder_(std::move(builder)) {} - - std::unique_ptr Apply( + protected: + std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { - auto new_graph = builder_->Apply(std::move(graph)); - PADDLE_ENFORCE(IsValidGraph(new_graph.get())); - return new_graph; - } - - int GetVarDeviceID(const std::string& var_name) const override { - return builder_->GetVarDeviceID(var_name); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; } bool IsValidGraph(const ir::Graph* graph) const; - - private: - std::unique_ptr builder_; }; } // namespace details diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 8815ec89b23bc874471eefde5fa855cd2a4bde1f..96fffb7d9430cd00b3823ada9fbe9a65a6bd718c 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -32,7 +32,9 @@ class SSAGraphExecutor { virtual ~SSAGraphExecutor(); - virtual FeedFetchList Run(const std::vector &fetch_tensors) = 0; + virtual const ir::Graph& Graph() const = 0; + + virtual FeedFetchList Run(const std::vector& fetch_tensors) = 0; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 20aab1464400aa9bb1bd6af11c06269c688a8308..ec3f31ab8d135efd2c77018e90cec46b25ca5e66 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -22,7 +22,7 @@ namespace details { template static inline void IterAllVar(const ir::Graph &graph, Callback callback) { - for (auto &each : graph.Get("vars")) { + for (auto &each : graph.Get(kGraphVars)) { for (auto &pair1 : each) { for (auto &pair2 : pair1.second) { callback(*pair2); @@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) { } } - for (auto &var : graph.Get("dep_vars")) { + for (auto &var : graph.Get(kGraphDepVars)) { callback(*var); } } @@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, }); size_t op_id = 0; - for (auto &op : graph.Get("ops")) { + for (auto &op : graph.Get(kGraphOps)) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; @@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_print_pass, + paddle::framework::details::SSAGraghBuilderWithPrinter); diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index a77c1bad3f15bca9064ded860696eb68b033b090..5eafd1805c3102dbd3cdfa68ee1495631c182b51 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include #include "paddle/fluid/framework/details/ssa_graph_builder.h" @@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { }; class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { - public: - SSAGraghBuilderWithPrinter(std::ostream& sout, - std::unique_ptr&& printer, - std::unique_ptr&& builder) - : printer_(std::move(printer)), - builder_(std::move(builder)), - stream_ref_(sout) {} - - SSAGraghBuilderWithPrinter(std::unique_ptr&& sout, - std::unique_ptr&& printer, - std::unique_ptr&& builder) - : printer_(std::move(printer)), - builder_(std::move(builder)), - stream_ptr_(std::move(sout)), - stream_ref_(*stream_ptr_) {} - - std::unique_ptr Apply( + protected: + std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { - auto new_graph = builder_->Apply(std::move(graph)); - printer_->Print(*new_graph, stream_ref_); - return new_graph; + std::unique_ptr fout( + new std::ofstream(Get("debug_graphviz_path"))); + PADDLE_ENFORCE(fout->good()); + Get("graph_printer").Print(*graph, *fout); + return graph; } - - int GetVarDeviceID(const std::string& var_name) const override { - return builder_->GetVarDeviceID(var_name); - } - - private: - std::unique_ptr printer_; - std::unique_ptr builder_; - std::unique_ptr stream_ptr_; - std::ostream& stream_ref_; }; } // namespace details diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index c19f74476f9a1498a7d61f5faf204e9966aea155..eec405073377b2782d7636c08e6eb3a7bd41202d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( std::unordered_set delayed_ops; // Transform SSAGraph to pending_ops & pending_vars - for (auto &var_map : graph_->Get("vars")) { + for (auto &var_map : graph_->Get(details::kGraphVars)) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); } } } - for (auto &var : graph_->Get("dep_vars")) { + for (auto &var : graph_->Get(details::kGraphDepVars)) { InsertPendingVar(&pending_vars, &ready_vars, var.get()); } - for (auto &op : graph_->Get("ops")) { + for (auto &op : graph_->Get(details::kGraphOps)) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op.get()); } else { @@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> fetched_vars; for (auto &fetch_var_name : fetch_tensors) { - for (auto &var_map : graph_->Get("vars")) { + for (auto &var_map : graph_->Get(details::kGraphVars)) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 3d67daa45e20fdea52689684397ad01f2f4cd783..82d6b5272aba161bb19067ebef054bc4bbb8701c 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { const std::vector &places, std::unique_ptr &&graph); + const ir::Graph &Graph() const { return *graph_; } // Run a SSAGraph by a thread pool // Use topological sort algorithm FeedFetchList Run(const std::vector &fetch_tensors) override; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6447452ae58344273fe569c91168c7c95a901c8d..bf7d76a8a6e173e648cea5aaba9b7202d787173b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -1,6 +1,9 @@ cc_library(node SRCS node.cc DEPS proto_desc) cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) -cc_library(pass SRCS pass.cc DEPS graph node) -cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry) -cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry) +cc_library(pass SRCS pass.cc DEPS graph node graph_helper) +cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) + +cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) +cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) +cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 4f59ec82a7d1217621c95d9a4a433a9af43e95da..c9d55fbf525a1a476ac469e8e57462169a7db2da 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -40,14 +40,21 @@ class Graph { attr_dels_.clear(); } + bool Has(const std::string &attr_name) const { + return attrs_.find(attr_name) != attrs_.end(); + } + template AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.", + attr_name); return *boost::any_cast(attrs_.at(attr_name)); } template void Set(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cb812d1388bf74d173a4dc7a99561e730f8e95a --- /dev/null +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" + +namespace paddle { +namespace framework { +namespace ir { +static const char kGraphVizPath[] = "graph_viz_path"; + +std::unique_ptr GraphVizPass::ApplyImpl( + std::unique_ptr graph) const { + const std::string graph_viz_path = Get(kGraphVizPath); + std::unique_ptr fout(new std::ofstream(graph_viz_path)); + PADDLE_ENFORCE(fout->good()); + std::ostream& sout = *fout; + + size_t var_id = 0; + std::unordered_map vars; + + sout << "digraph G {\n"; + + for (const ir::Node* n : graph->Nodes()) { + if (n->NodeType() != ir::Node::Type::kVariable) continue; + size_t cur_var_id = var_id++; + vars[n] = cur_var_id; + + sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]" + << std::endl; + } + + size_t op_id = 0; + for (const ir::Node* n : graph->Nodes()) { + if (n->NodeType() != ir::Node::Type::kOperation) continue; + std::string op_name = "op_" + std::to_string(op_id++); + sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]" + << std::endl; + for (auto in : n->inputs) { + std::string var_name = "var_" + std::to_string(vars[in]); + sout << var_name << " -> " << op_name << std::endl; + } + + for (auto out : n->outputs) { + std::string var_name = "var_" + std::to_string(vars[out]); + sout << op_name << " -> " << var_name << std::endl; + } + } + + sout << "}\n"; + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass) + .RequirePassAttr(paddle::framework::ir::kGraphVizPath); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd8c8a26e9581ccf605d4271a49ec2e90d8b997 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_viz_pass.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class GraphVizPass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index c05d7d0bb54c8ba5938e08f7e8dace8f607d7b89..d7158eba62686be57499df697466797e4034ea8f 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -13,7 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { -namespace framework {} // namespace framework +namespace framework { +namespace ir { +std::unique_ptr Pass::Apply(std::unique_ptr graph) const { + PADDLE_ENFORCE(!applied_, "Pass can only Apply() once."); + PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); + for (const std::string& attr : required_pass_attrs_) { + PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), + "Required pass atrribute %s not set.", attr); + } + for (const std::string& attr : required_graph_attrs_) { + PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.", + attr); + } + auto applied_graph = ApplyImpl(std::move(graph)); + // TODO(panyx0718): Add more verifications. + PADDLE_ENFORCE(!HasCircle(*applied_graph), + "Illegal Pass. Generated graph shouldn't has cycle."); + applied_ = true; + return applied_graph; +} + +PassRegistry& PassRegistry::Instance() { + static PassRegistry g_pass_info_map; + return g_pass_info_map; +} +} // namespace ir +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index f52ba788d55ddb9ed27baa3f6ff0a97e52370fe0..0f14083d259172f5b5f1ed80c7d38312d711beb5 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -14,21 +14,187 @@ limitations under the License. */ #pragma once +#include +#include +#include + #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { namespace ir { +template +struct PassRegistrar; class Pass { public: Pass() = default; - virtual ~Pass() {} + virtual ~Pass() { + for (auto &attr : attrs_) { + if (attr_dels_.find(attr.first) != attr_dels_.end()) { + attr_dels_[attr.first](); + } + } + attrs_.clear(); + attr_dels_.clear(); + } + + std::unique_ptr Apply(std::unique_ptr graph) const; + + // Get a reference to the attributed previously set. + template + AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), + "%s attr not registered for pass.", attr_name); + return *boost::any_cast(attrs_.at(attr_name)); + } + + // Set a pointer to the attribute. Pass takes ownership of the attribute. + template + void Set(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; + } + + // Set a pointer to the attribute. Pass doesn't take ownership. Caller + // should delete the attribute. + template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + attrs_[attr_name] = attr; + } + + protected: + virtual std::unique_ptr ApplyImpl( + std::unique_ptr graph) const = 0; + + private: + template + friend struct PassRegistrar; + + void RegisterRequiredPassAttrs(const std::unordered_set &attrs) { + required_pass_attrs_.insert(attrs.begin(), attrs.end()); + } + + void RegisterRequiredGraphAttrs( + const std::unordered_set &attrs) { + required_graph_attrs_.insert(attrs.begin(), attrs.end()); + } + + mutable bool applied_{false}; + std::unordered_set required_pass_attrs_; + std::unordered_set required_graph_attrs_; + std::map attrs_; + std::map> attr_dels_; +}; + +using PassCreator = std::function()>; + +class Registrar { + public: + // In our design, various kinds of passes, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_PASS macros to + // call this method. So, as long as the callee code calls USE_PASS, the global + // registrar variable won't be removed by the linker. + void Touch() {} +}; - virtual std::unique_ptr Apply(std::unique_ptr graph) const = 0; +class PassRegistry { + public: + static PassRegistry &Instance(); + + bool Has(const std::string &pass_type) const { + return map_.find(pass_type) != map_.end(); + } + + void Insert(const std::string &pass_type, const PassCreator &pass_creator) { + PADDLE_ENFORCE(!Has(pass_type), "Pass %s has been registered", pass_type); + map_.insert({pass_type, pass_creator}); + } + + std::unique_ptr Get(const std::string &pass_type) const { + PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered", + pass_type); + return map_.at(pass_type)(); + } + + private: + PassRegistry() = default; + std::unordered_map map_; + + DISABLE_COPY_AND_ASSIGN(PassRegistry); }; + +template +struct PassRegistrar : public Registrar { + explicit PassRegistrar(const char *pass_type) { + PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), + "'%s' is registered more than once.", pass_type); + PassRegistry::Instance().Insert( + pass_type, [this]() -> std::unique_ptr { + std::unique_ptr pass(new PassType()); + pass->RegisterRequiredPassAttrs(this->required_pass_attrs_); + pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_); + return pass; + }); + } + + PassRegistrar &RequirePassAttr(const std::string &attr) { + required_pass_attrs_.insert(attr); + return *this; + } + + PassRegistrar &RequireGraphAttr(const std::string &attr) { + required_graph_attrs_.insert(attr); + return *this; + } + + private: + std::unordered_set required_pass_attrs_; + std::unordered_set required_graph_attrs_; +}; + +#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +// Register a new pass that can be applied on the IR. +#define REGISTER_PASS(pass_type, pass_class) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __reg_pass__##pass_type, \ + "REGISTER_PASS must be called in global namespace"); \ + static ::paddle::framework::ir::PassRegistrar \ + __pass_registrar_##pass_type##__(#pass_type); \ + int TouchPassRegistrar_##pass_type() { \ + __pass_registrar_##pass_type##__.Touch(); \ + return 0; \ + } \ + static ::paddle::framework::ir::PassRegistrar \ + &__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \ + __pass_registrar_##pass_type##__ + +#define USE_PASS(pass_type) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __use_pass_itself_##pass_type, \ + "USE_PASS must be called in global namespace"); \ + extern int TouchPassRegistrar_##pass_type(); \ + static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \ + TouchPassRegistrar_##pass_type() + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b5011412ed39e033a7a65921e9c64ce2d54c638 --- /dev/null +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/pass.h" +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { +void BuildCircleGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + + o2->outputs.push_back(v2); + o1->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o1); +} + +class TestPass : public Pass { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + graph->Set("copy_test_pass_attr", new int); + graph->Set("copy_test_graph_attr", new int); + + int test_pass_attr = this->Get("test_pass_attr"); + graph->Get("copy_test_pass_attr") = test_pass_attr + 1; + + int test_graph_attr = graph->Get("test_graph_attr"); + graph->Get("copy_test_graph_attr") = test_graph_attr + 1; + return graph; + } +}; + +TEST(PassTest, TestPassAttrCheck) { + ProgramDesc prog; + auto pass = PassRegistry::Instance().Get("test_pass"); + std::unique_ptr graph(new Graph(prog)); + std::string exception; + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("test_pass_attr not set") != exception.npos); + + int val = 1; + graph.reset(new Graph(prog)); + pass->SetNotOwned("test_pass_attr", &val); + + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("test_graph_attr not set") != exception.npos); + + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 1; + graph = pass->Apply(std::move(graph)); + ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); + ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); + + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->SetNotOwned("test_pass_attr", &val); + graph.reset(new Graph(prog)); + BuildCircleGraph(graph.get()); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 2; + try { + auto tmp = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(test_pass, paddle::framework::ir::TestPass) + .RequirePassAttr("test_pass_attr") + .RequireGraphAttr("test_graph_attr"); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d1dc5fcd97b77fb7707c7d48f6eaeef140d3f306..7c1c29fd9a81c558f7fd05abf52cd0a6dd522190 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -679,6 +679,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (var == nullptr) continue; if (var->IsType()) { CheckTensorNANOrInf(vname, var->Get()); + } else if (var->IsType()) { + CheckTensorNANOrInf(vname, var->Get().value()); } } } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 02c836bea194553bb9c4bc5677cc408dd302e9ce..b5f01a9a2b76472063658f1a051a2ee3c65559b7 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -19,19 +19,80 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace framework { +std::unique_ptr ApplyParallelExecutorPass( + const ProgramDesc &main_program, const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, const bool use_cuda, +#ifdef PADDLE_WITH_CUDA + const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) { +#else + const BuildStrategy &strategy) { +#endif + // Convert the program to graph. + std::unique_ptr graph(new ir::Graph(main_program)); + + // Apply a graph viz pass to record a graph. + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); + } + + // Convert graph to run on multi-devices. + auto multi_device_pass = + ir::PassRegistry::Instance().Get("multi_device_pass"); + multi_device_pass->SetNotOwned>("places", + &places); + multi_device_pass->SetNotOwned("loss_var_name", + &loss_var_name); + multi_device_pass->SetNotOwned>( + "params", ¶m_names); + multi_device_pass->SetNotOwned>("local_scopes", + &local_scopes); + multi_device_pass->SetNotOwned("strategy", &strategy); + +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; + multi_device_pass->SetNotOwned("nccl_ctxs", nctx); +#endif + graph = multi_device_pass->Apply(std::move(graph)); + + // Apply a graph print pass to record a graph with device info. + if (!strategy.debug_graphviz_path_.empty()) { + auto multi_device_print_pass = + ir::PassRegistry::Instance().Get("multi_device_print_pass"); + multi_device_print_pass->SetNotOwned( + "debug_graphviz_path", &strategy.debug_graphviz_path_); + multi_device_print_pass->Set( + "graph_printer", new details::GraphvizSSAGraphPrinter); + graph = multi_device_print_pass->Apply(std::move(graph)); + } + + // Verify that the graph is correct for multi-device executor. + auto multi_device_check_pass = + ir::PassRegistry::Instance().Get("multi_device_check_pass"); + graph = multi_device_check_pass->Apply(std::move(graph)); + return graph; +} + class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(const std::vector &places) @@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor( var_infos.back().persistable_ = var->Persistable(); } - // Step 3. Convert main_program to SSA form and dependency graph. Also, insert - // ncclOp - details::SSAGraphBuilderFactory builder_factory( - member_->places_, loss_var_name, params, member_->local_scopes_, - build_strategy); - if (member_->use_cuda_) { +// Step 3. Convert main_program to SSA form and dependency graph. Also, insert +// ncclOp #ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy, + member_->nccl_ctxs_.get()); #else - PADDLE_THROW("Not compiled with CUDA."); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy); #endif - } - builder_ = builder_factory.Create(); - std::unique_ptr graph(new ir::Graph(main_program)); - graph = builder_->Apply(std::move(graph)); + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( @@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices( // the initializing bcast, all vars would be bcast from device(0), // otherwise // bcast from the specified device. - bool initializing = builder_.get() == nullptr ? true : false; - + bool initializing = member_->executor_ ? false : true; for (auto &var : vars) { - int var_dev_id = - builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); + int var_dev_id = -1; + if (member_->executor_) { + auto &sharded_var_device = + member_->executor_->Graph().Get( + details::kShardedVarDevice); + if (sharded_var_device.find(var) != sharded_var_device.end()) { + var_dev_id = sharded_var_device.at(var); + } + } + if (!initializing && var_dev_id == -1) continue; framework::Variable *main_var = nullptr; @@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle + +USE_PASS(graph_viz_pass); +USE_PASS(multi_device_pass); +USE_PASS(multi_device_check_pass); +USE_PASS(multi_device_print_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index ffb9934a2d702b2bf6db7ad75a6bf9867e1e9901..d624956acde86cefc4ec1dec80df3738bcf1d8be 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -70,7 +70,6 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; - std::unique_ptr builder_; }; } // namespace framework diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index e0d7937ae2f3ce4bda12f3771727e2992d63cb9b..a6f68f8b0c0a9b07c326888e30c0c911e7861607 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) + +IF(WITH_GPU) + nv_test(cuda_helper_test SRCS cuda_helper_test.cu) +ENDIF() diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h index ecec4178f2d9937920e52eb74bf9068b84e741a0..23457ff5fe1ec27094113ba0dde26adc64c716b5 100644 --- a/paddle/fluid/platform/cuda_device_function.h +++ b/paddle/fluid/platform/cuda_device_function.h @@ -14,6 +14,10 @@ limitations under the License. */ #pragma once #include +// NOTE(): support float16 to half in header file. +#define PADDLE_CUDA_FP16 +#include +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace platform { @@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val, #endif } +// CUDA 9.0 have native compatible float16 shfl_down +#if CUDA_VERSION < 9000 +template <> +__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, + float16 val, int delta, + int width) { + half tmp = static_cast(val); + __shfl_down(tmp, static_cast(delta), width); + return float16(tmp); +} +#endif + template __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line, int width = 32) { @@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line, #endif } +template +HOSTDEVICE T Infinity() { + return INFINITY; +} + template __device__ T reduceSum(T val, int tid, int len) { // NOTE(zcd): The warp size should be taken from the diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..4a47ba5ccad4de338844e60f6fcbd6b7c11e891b --- /dev/null +++ b/paddle/fluid/platform/cuda_helper_test.cu @@ -0,0 +1,118 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#define PADDLE_CUDA_FP16 +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; +using paddle::platform::float16; + +#define CUDA_ATOMIC_KERNEL(op, T) \ + __global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \ + i += blockDim.x * gridDim.x) { \ + paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \ + } \ + } + +template +struct AddFunctor { + T operator()(const T& a, const T& b) { return a + b; } +}; + +template +struct SubFunctor { + T operator()(const T& a, const T& b) { return a - b; } +}; + +// NOTE(dzhwinter): the float16 add has small underflow/overflow +// so we use EXPECT_NEAR to check the result. +#define ARITHMETIC_KERNEL_LAUNCH(op, T) \ + void Test##T##op(size_t num) { \ + T *in1, *in2, *out; \ + T *d_in1, *d_in2; \ + size_t size = sizeof(T) * num; \ + cudaMalloc(reinterpret_cast(&d_in1), size); \ + cudaMalloc(reinterpret_cast(&d_in2), size); \ + in1 = reinterpret_cast(malloc(size)); \ + in2 = reinterpret_cast(malloc(size)); \ + out = reinterpret_cast(malloc(size)); \ + std::minstd_rand engine; \ + std::uniform_real_distribution dist(0.0, 1.0); \ + for (size_t i = 0; i < num; ++i) { \ + in1[i] = static_cast(dist(engine)); \ + in2[i] = static_cast(dist(engine)); \ + } \ + cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \ + cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \ + op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \ + cudaDeviceSynchronize(); \ + cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \ + cudaDeviceSynchronize(); \ + for (size_t i = 0; i < num; ++i) { \ + EXPECT_NEAR(static_cast(out[i]), \ + static_cast(op##Functor()(in1[i], in2[i])), \ + 0.001); \ + } \ + free(in1); \ + free(in2); \ + free(out); \ + cudaFree(d_in1); \ + cudaFree(d_in2); \ + } +CUDA_ATOMIC_KERNEL(Add, float); +CUDA_ATOMIC_KERNEL(Add, double); +CUDA_ATOMIC_KERNEL(Add, float16); + +ARITHMETIC_KERNEL_LAUNCH(Add, float); +ARITHMETIC_KERNEL_LAUNCH(Add, double); +ARITHMETIC_KERNEL_LAUNCH(Add, float16); + +namespace paddle { +namespace platform { +USE_CUDA_ATOMIC(Sub, int); +}; +}; +CUDA_ATOMIC_KERNEL(Sub, int); +ARITHMETIC_KERNEL_LAUNCH(Sub, int); + +// cuda primitives +TEST(CudaAtomic, Add) { + TestfloatAdd(static_cast(10)); + TestfloatAdd(static_cast(1024 * 1024)); + TestdoubleAdd(static_cast(10)); + TestdoubleAdd(static_cast(1024 * 1024)); +} + +TEST(CudaAtomic, Sub) { + TestintSub(static_cast(10)); + TestintSub(static_cast(1024 * 1024)); +} + +TEST(CudaAtomic, float16) { + using paddle::platform::float16; + Testfloat16Add(static_cast(1)); + Testfloat16Add(static_cast(2)); + Testfloat16Add(static_cast(3)); + + Testfloat16Add(static_cast(10)); + Testfloat16Add(static_cast(1024 * 1024)); +} diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index d535ed2f89df6a0b311ec068ecd92c8e3183cee7..94ce83975a7f13daa2b6a4d480cb22cc95811b9b 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -14,12 +14,14 @@ limitations under the License. */ #pragma once #include +#include +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace platform { #define CUDA_ATOMIC_WRAPPER(op, T) \ - __device__ __forceinline__ T CudaAtomic##op(T* address, const T val) + __device__ __forceinline__ T CudaAtomic##op(T *address, const T val) #define USE_CUDA_ATOMIC(op, T) \ CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } @@ -42,17 +44,17 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) { static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT "long long should be int64"); return CudaAtomicAdd( - reinterpret_cast(address), // NOLINT - static_cast(val)); // NOLINT + reinterpret_cast(address), // NOLINT + static_cast(val)); // NOLINT } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 USE_CUDA_ATOMIC(Add, double); #else CUDA_ATOMIC_WRAPPER(Add, double) { - unsigned long long int* address_as_ull = // NOLINT - reinterpret_cast(address); // NOLINT - unsigned long long int old = *address_as_ull, assumed; // NOLINT + unsigned long long int *address_as_ull = // NOLINT + reinterpret_cast(address); // NOLINT + unsigned long long int old = *address_as_ull, assumed; // NOLINT do { assumed = old; @@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) { return __longlong_as_double(old); } +#endif + +#ifdef PADDLE_CUDA_FP16 +// NOTE(dzhwinter): cuda do not have atomicCAS for half. +// Just use the half address as a unsigned value address and +// do the atomicCAS. According to the value store at high 16 bits +// or low 16 bits, then do a different sum and CAS. +// Given most warp-threads will failed on the atomicCAS, so this +// implemented should be avoided in high concurrency. It's will be +// slower than the way convert value into 32bits and do a full atomicCAS. + +// convert the value into float and do the add arithmetic. +// then store the result into a uint32. +inline __device__ uint32_t add_to_low_half(uint32_t val, float x) { + float16 low_half; + // the float16 in lower 16bits + low_half.x = static_cast(val & 0xffffu); + low_half = static_cast(static_cast(low_half) + x); + return (val & 0xffff0000u) | low_half.x; +} + +inline __device__ uint32_t add_to_high_half(uint32_t val, float x) { + float16 high_half; + // the float16 in higher 16bits + high_half.x = static_cast(val >> 16); + high_half = static_cast(static_cast(high_half) + x); + return (val & 0xffffu) | (static_cast(high_half.x) << 16); +} + +CUDA_ATOMIC_WRAPPER(Add, float16) { + // concrete packed float16 value may exsits in lower or higher 16bits + // of the 32bits address. + uint32_t *address_as_ui = + reinterpret_cast(reinterpret_cast(address) - + (reinterpret_cast(address) & 2)); + float val_f = static_cast(val); + uint32_t old = *address_as_ui; + uint32_t sum; + uint32_t newval; + uint32_t assumed; + if (((size_t)address & 2) == 0) { + // the float16 value stay at lower 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f)); + } while (old != assumed); + float16 ret; + ret.x = old & 0xffffu; + return ret; + } else { + // the float16 value stay at higher 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f)); + } while (old != assumed); + float16 ret; + ret.x = old >> 16; + return ret; + } +} + #endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index ffd183af68514dbb1a8b3de39000c9ca3f56ddc3..efb021c838e3680ab2cdd1c4b298cf7ec2186478 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -67,8 +67,11 @@ struct float16; } // namespace platform } // namespace paddle +// NOTE(): +// Do not move the eigen.h header, otherwise the eigen_vector will failed. #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/hostdevice.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { @@ -898,6 +901,30 @@ struct is_pod { is_standard_layout::value; }; +template <> +struct is_floating_point + : std::integral_constant< + bool, std::is_same::type>::value> {}; +template <> +struct is_signed { + static const bool value = true; +}; + +template <> +struct is_unsigned { + static const bool value = false; +}; + +inline bool isnan(const paddle::platform::float16& a) { + return paddle::platform::isnan(a); +} + +inline bool isinf(const paddle::platform::float16& a) { + return paddle::platform::isinf(a); +} + template <> struct numeric_limits { static const bool is_specialized = true; diff --git a/paddle/fluid/platform/float16_test.cc b/paddle/fluid/platform/float16_test.cc index ede294be1e2e26693bd3ead2ccd5e6a6c8a075bc..27e930e6e0a76982b3f27619f38a4a08d82cafa1 100644 --- a/paddle/fluid/platform/float16_test.cc +++ b/paddle/fluid/platform/float16_test.cc @@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) { } } +TEST(float16, floating) { + // compile time assert. + PADDLE_ASSERT(std::is_floating_point::value); +} + TEST(float16, print) { float16 a = float16(1.0f); std::cout << a << std::endl; } +// CPU test +TEST(float16, isinf) { + float16 a; + a.x = 0x7c00; + float16 b = float16(INFINITY); + float16 c = static_cast(INFINITY); + EXPECT_EQ(std::isinf(a), true); + EXPECT_EQ(std::isinf(b), true); + EXPECT_EQ(std::isinf(c), true); +} + +TEST(float16, isnan) { + float16 a; + a.x = 0x7fff; + float16 b = float16(NAN); + float16 c = static_cast(NAN); + EXPECT_EQ(std::isnan(a), true); + EXPECT_EQ(std::isnan(b), true); + EXPECT_EQ(std::isnan(c), true); +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/float16_test.cu b/paddle/fluid/platform/float16_test.cu index 1b9cf9b5d3fa2121b588c31d7cf2f4c50cb951bc..e2b7ca9b03809113c31af8ff4d3ad3713748f330 100644 --- a/paddle/fluid/platform/float16_test.cu +++ b/paddle/fluid/platform/float16_test.cu @@ -11,11 +11,13 @@ limitations under the License. */ #include "paddle/fluid/platform/float16.h" +#include #include +#include +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/legacy/utils/Logging.h" #define ARITHMETIC_KERNEL(op_type, sign) \ __global__ void op_type(const half* in1, const half* in2, half* out) { \ @@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) { } } +template +struct Functor { + bool operator()(const T& val) { + return std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16)); + } +}; + +TEST(float16, typeid) { + // the framework heavily used typeid hash + Functor functor; + float16 a = float16(.0f); + Functor functor2; + int b(0); + + // compile time assert + PADDLE_ASSERT(functor(a) == true); + PADDLE_ASSERT(functor2(b) == false); +} + +// GPU test +TEST(float16, isinf) { + float16 a; + a.x = 0x7c00; + float16 b = float16(INFINITY); + // underflow to 0 + float16 native_a(5e-40f); + // overflow to inf + float16 native_b(5e40f); + EXPECT_EQ(std::isinf(a), true); + EXPECT_EQ(std::isinf(b), true); + EXPECT_EQ(std::isinf(native_b), true); + EXPECT_EQ(native_a, float16(0)); +} + +TEST(float16, isnan) { + float16 a; + a.x = 0x7fff; + float16 b = float16(NAN); + float16 c = float16(5e40); + // inf * +-0 will get a nan + float16 d = c * float16(0); + EXPECT_EQ(std::isnan(a), true); + EXPECT_EQ(std::isnan(b), true); + EXPECT_EQ(std::isnan(d), true); +} + +TEST(float16, cast) { + float16 a; + a.x = 0x0070; + auto b = a; + { + // change semantic, keep the same value + float16 c = reinterpret_cast(reinterpret_cast(b)); + EXPECT_EQ(b, c); + } + + { + // use uint32 low 16 bit store float16 + uint32_t c = reinterpret_cast(b); + float16 d; + d.x = c; + EXPECT_EQ(b, d); + } +} + } // namespace platform } // namespace paddle #endif // PADDLE_CUDA_FP16 diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index f05ae6d5d1900560e37370121bf64f1fcab14357..3ee1c636ace504e14cf7d6c106df1bc3e864d660 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper, unique_name from ..initializer import force_init_on_cpu from ops import logical_and, logical_not, logical_or import numpy +import warnings __all__ = [ 'While', @@ -280,6 +281,9 @@ class ParallelDo(object): """ def __init__(self, places, use_nccl=False, name=None): + warnings.warn( + "API ParallelDo is deprecated since 0.15.0. Please use ParallelExecutor instead.", + Warning) self.helper = LayerHelper("parallel_do", name=name) self.inputs = [] self.places = places @@ -338,7 +342,7 @@ class ParallelDo(object): return [parent_block.var(name) for name in params] - def complete_op(self): + def _complete_op(self): main_program = self.helper.main_program current_block = main_program.current_block() parent_block = self.parent_block() @@ -394,7 +398,7 @@ class BlockGuardWithCompletion(BlockGuard): if exc_type is not None: return False self.rnn.status = StaticRNN.AFTER_RNN_BLOCK - self.rnn.complete_op() + self.rnn._complete_op() return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val, exc_tb) @@ -470,7 +474,7 @@ class StaticRNN(object): if shape is None or batch_ref is None: raise ValueError( "if init is None, memory at least need shape and batch_ref") - parent_block = self.parent_block() + parent_block = self._parent_block() var_name = unique_name.generate("@".join( [self.helper.name, "memory_boot"])) boot_var = parent_block.create_var( @@ -527,7 +531,7 @@ class StaticRNN(object): outputs={'Out': tmp_o}, attrs={'dtype': o.dtype}) - out_var = self.parent_block().create_var( + out_var = self._parent_block().create_var( name=tmp_o.name, shape=[self.seq_len] + list(tmp_o.shape), dtype=tmp_o.dtype) @@ -543,7 +547,7 @@ class StaticRNN(object): raise TypeError("update memory should take variables") self.memories[mem.name].mem = var - def parent_block(self): + def _parent_block(self): prog = self.helper.main_program parent_idx = prog.current_block().parent_idx assert parent_idx >= 0 @@ -560,10 +564,10 @@ class StaticRNN(object): else: return self.outputs - def complete_op(self): + def _complete_op(self): main_program = self.helper.main_program rnn_block = main_program.current_block() - parent_block = self.parent_block() + parent_block = self._parent_block() local_inputs = set() @@ -643,7 +647,7 @@ class WhileGuard(BlockGuard): if exc_type is not None: return False self.while_op.status = While.AFTER_WHILE_BLOCK - self.while_op.complete() + self.while_op._complete() return super(WhileGuard, self).__exit__(exc_type, exc_val, exc_tb) @@ -690,7 +694,7 @@ class While(object): def block(self): return WhileGuard(self) - def complete(self): + def _complete(self): main_program = self.helper.main_program while_block = main_program.current_block() parent_block = main_program.block(main_program.current_block() diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index 0871ad715fa6c939b9fb07d4dc963d91168de8bf..3b67b3f5ccd67f86f87f292d83a6039ff46260bd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -19,6 +19,7 @@ import math import unittest import os +import sys import signal import subprocess @@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase): except os.error: retry_times -= 1 - def no_test_with_place(self): + def test_with_place(self): # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN required_envs = { "PATH": os.getenv("PATH"), @@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase): local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \ (self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1) local_proc = subprocess.Popen( - local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local) + local_cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env_local) local_proc.wait() - local_ret = local_proc.stdout.read() + out, err = local_proc.communicate() + local_ret = out + sys.stderr.write('local_loss: %s\n' % local_ret) + sys.stderr.write('local_stderr: %s\n' % err) # Run dist train to compare with local results ps0, ps1 = self.start_pserver() @@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase): FNULL = open(os.devnull, 'w') tr0_proc = subprocess.Popen( - tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0) + tr0_cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env0) tr1_proc = subprocess.Popen( - tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1) + tr1_cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env1) tr0_proc.wait() tr1_proc.wait() - loss_data0 = tr0_proc.stdout.read() + out, err = tr0_proc.communicate() + sys.stderr.write('dist_stderr: %s\n' % err) + loss_data0 = out + sys.stderr.write('dist_loss: %s\n' % loss_data0) lines = loss_data0.split("\n") dist_first_loss = eval(lines[0].replace(" ", ","))[0] dist_last_loss = eval(lines[1].replace(" ", ","))[0] diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 4a9ea6af747c36e5817ede5fafbadeea79fb07ac..2d9c089c0b7667c875aae05cb4e6040b007f3d55 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -347,6 +347,7 @@ class DistributeTranspiler(object): # step1 pserver_program = Program() + pserver_program.random_seed = self.origin_program.random_seed # step2: Create vars to receive vars at parameter servers. recv_inputs = [] for v in self.param_grad_ep_mapping[endpoint]["params"]: @@ -544,6 +545,7 @@ class DistributeTranspiler(object): """ s_prog = Program() orig_s_prog = default_startup_program() + s_prog.random_seed = orig_s_prog.random_seed params = self.param_grad_ep_mapping[endpoint]["params"] def _get_splited_name_and_shape(varname):