提交 15c2f1ab 编写于 作者: S Superjomn

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fea/analysis-ssa

...@@ -263,7 +263,7 @@ function(cc_test TARGET_NAME) ...@@ -263,7 +263,7 @@ function(cc_test TARGET_NAME)
COMMAND ${TARGET_NAME} ${cc_test_ARGS} COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL}) if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY SERIAL 1) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
endif() endif()
endif() endif()
...@@ -328,7 +328,7 @@ function(nv_test TARGET_NAME) ...@@ -328,7 +328,7 @@ function(nv_test TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_test(${TARGET_NAME} ${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL) if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY SERIAL 1) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
endif() endif()
endif() endif()
......
...@@ -64,6 +64,41 @@ can also contain other things that describe some properties of ...@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
the `Graph` or `Graph` nodes. `Attribute` can be passed the `Graph` or `Graph` nodes. `Attribute` can be passed
across `Pass`. However, it should be used with care. across `Pass`. However, it should be used with care.
```cpp
class Graph {
public:
explicit Graph(const ProgramDesc &program);
bool Has(const std::string &attr_name) const;
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const;
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr);
const std::unordered_set<ir::Node *> &Nodes() const;
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc);
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc);
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir::Node *CreateControlDepVar();
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type);
// Clear all node information of the graph and return the ownership of the
// nodes.
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes();
};
```
#### Pass #### Pass
`Pass` represents a transformation of `Graph`. Its input `Pass` represents a transformation of `Graph`. Its input
...@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example, ...@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
a `Pass` can simply print out the `Graph`. A `Pass` a `Pass` can simply print out the `Graph`. A `Pass`
can also fuse some `Graph`'s `Node`s. can also fuse some `Graph`'s `Node`s.
```cpp
class Pass {
public:
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
// Some correctness check.
auto new_graph = ApplyImpl(std::move(graph));
// Some correctness check.
return new_graph;
}
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) ;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr);
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
};
// In my_pass.cc
class MyPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
// do something.
return graph;
}
}
REGISTER_PASS(my_pass, MyPass)
.RequirePassAttr("places")
.RequireGraphAttr("dep_vars");
// To use the pass.
auto my_pass = ir::PassRegistry::Instance().Get("my_pass");
graph = my_pass->Apply(std::move(graph));
// Note: to force link my_pass.cc, in the code:
USE_PASS(my_pass);
```
#### Optimize #### Optimize
`Optimize` contains a series of `Pass` with defined order. `Optimize` contains a series of `Pass` with defined order.
...@@ -86,4 +169,17 @@ maintaining the original modeling logic. ...@@ -86,4 +169,17 @@ maintaining the original modeling logic.
* Graph is transformed from raw model logic to a * Graph is transformed from raw model logic to a
form that is efficient to execute. form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor ```
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
auto graph = Graph(program);
graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
// For more complex Pass, Optimize Process can provide Pass attributes.
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
Executor exe;
exe.Run(graph);
```
...@@ -170,6 +170,7 @@ paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], var ...@@ -170,6 +170,7 @@ paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], var
paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
...@@ -201,7 +202,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs= ...@@ -201,7 +202,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
...@@ -225,17 +225,14 @@ paddle.fluid.layers.DynamicRNN.static_input ArgSpec(args=['self', 'x'], varargs= ...@@ -225,17 +225,14 @@ paddle.fluid.layers.DynamicRNN.static_input ArgSpec(args=['self', 'x'], varargs=
paddle.fluid.layers.DynamicRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.DynamicRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.DynamicRNN.update_memory ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.DynamicRNN.update_memory ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.StaticRNN.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.StaticRNN.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.memory ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)) paddle.fluid.layers.StaticRNN.memory ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1))
paddle.fluid.layers.StaticRNN.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.step ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step_output ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.step_output ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.update_memory ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.StaticRNN.update_memory ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.reorder_lod_tensor_by_rank ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.reorder_lod_tensor_by_rank ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.__init__ ArgSpec(args=['self', 'places', 'use_nccl', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.ParallelDo.__init__ ArgSpec(args=['self', 'places', 'use_nccl', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.ParallelDo.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.do ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.do ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.get_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.get_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.ParallelDo.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
......
...@@ -99,7 +99,7 @@ else() ...@@ -99,7 +99,7 @@ else()
endif() endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph) cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s ...@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -34,30 +34,22 @@ namespace paddle { ...@@ -34,30 +34,22 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
void MultiDevSSAGraphBuilder::Init() const {
loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>(kPlaces);
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
strategy_ = Get<const BuildStrategy>(kStrategy);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs),
strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
strategy_(strategy) {
#endif #endif
for (auto &p : params) {
for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0); balance_vars_.resize(places_.size(), 0);
...@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ...@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { ...@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return sorted_ret; return sorted_ret;
} }
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
Init();
// Give the topology sort order and rebuild the graph structure. // Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph); std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = graph->ReleaseNodes(); auto nodes = graph->ReleaseNodes();
...@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8 // We cannot invoke resize. It is a bug of GCC 4.8
result.Set("vars", new GraphVars(places_.size())); result.Set(kGraphVars, new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars); result.Set(kGraphDepVars, new GraphDepVars);
result.Set("ops", new GraphOps); result.Set(kGraphOps, new GraphOps);
result.Set(kShardedVarDevice, new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// related op in the place 0 // related op in the place 0
...@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(node); int op_dev_id = GetOpDeviceID(result, node);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id); CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
var_name_on_devices_.emplace(n->Name(), op_dev_id); graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(n->Name(), op_dev_id);
} }
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
...@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name}); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices_.emplace(g_name, cur_device_id); graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
...@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
local_scopes_, places_); local_scopes_, places_);
#endif #endif
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
auto *in = auto *in =
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get(); result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name); auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
auto *out_var = new VarHandle( auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
i, p_name, p); i, p_name, p);
...@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id])); local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
...@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ...@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars")[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
...@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
ir::Graph *result, const std::vector<std::string> &datas) const { ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) { for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>("vars")[i][d_name]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back().get());
auto var = new VarHandle( auto var = new VarHandle(
...@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
...@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { ...@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]); int dev_id = GetVarDeviceID(graph, param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]); node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id; return dev_id;
} }
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
auto got = var_name_on_devices_.find(varname); const std::string &varname) const {
return got == var_name_on_devices_.end() ? -1 : got->second; auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
...@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { ...@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], local_scopes_.size(), local_scopes_[i], places_[i],
communication_dev_ctx); communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators. // factor. So it does not depend on any other operators.
...@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, ...@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
CreateOpHandleIOs(result, node, scope_idx); CreateOpHandleIOs(result, node, scope_idx);
} }
...@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>("vars")[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.size(), dst_dev_id, og, places_[dst_dev_id]);
...@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it. // on it.
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) { for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar()); auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
prev_op->AddOutput(dep_var); prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var); result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
op->AddInput(dep_var); op->AddInput(dep_var);
} }
} }
...@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if (node->Op()->Type() == "split_byref" || if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") { node->Op()->Type() == "split_selected_rows") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0]); op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
} }
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
} }
} else if (node->Op()->Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(input_var_names[0]); op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
} }
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"fetch_barrier"); "fetch_barrier");
} }
} }
...@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix."); "This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
...@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
} }
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
} }
} }
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
...@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
} }
op_dev_id = GetAppropriateDeviceID(output_var_names); op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); result->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(varname, op_dev_id);
} }
} else { } else {
// send_barrier and fetch_barrier op can be scheduled on device 0 // send_barrier and fetch_barrier op can be scheduled on device 0
...@@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
node->Op()->Type()); node->Op()->Type());
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id])); node->Op()->Type(), places_[op_dev_id]));
// TODO(panyx0718): This might not be needed anymore. // TODO(panyx0718): This might not be needed anymore.
if (node->Op()->Type() == "send_barrier") { if (node->Op()->Type() == "send_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send"); ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
"send_barrier"); "send_barrier");
} else if (node->Op()->Type() == "fetch_barrier") { } else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv"); ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
} else if (node->Op()->Type() == "send") { } else if (node->Op()->Type() == "send") {
// do nothing // do nothing
} else { } else {
...@@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { ...@@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_device_pass,
paddle::framework::details::MultiDevSSAGraphBuilder)
.RequirePassAttr(paddle::framework::details::kLossVarName)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy);
...@@ -31,39 +31,27 @@ class Scope; ...@@ -31,39 +31,27 @@ class Scope;
namespace details { namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder { class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public: protected:
#ifdef PADDLE_WITH_CUDA std::unique_ptr<ir::Graph> ApplyImpl(
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs,
const BuildStrategy &strategy);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const; size_t device_id) const;
void Init() const;
private: private:
std::string loss_var_name_; mutable std::string loss_var_name_;
const std::vector<platform::Place> &places_; mutable std::vector<platform::Place> places_;
const std::vector<Scope *> &local_scopes_; mutable std::vector<Scope *> local_scopes_;
std::unordered_set<std::string> grad_names_; mutable std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(ir::Graph *result, ir::Node *node) const; void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
...@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(ir::Node *node) const; int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
...@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::string> &var_names) const; const std::vector<std::string> &var_names) const;
private: private:
BuildStrategy strategy_; mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_; mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_; mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
......
...@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy strategy, std::vector<Scope*> local_scopes, ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places, std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor); std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
const ir::Graph& Graph() const { return underlying_executor_->Graph(); }
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override; FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private: private:
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
continue; continue;
...@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { ...@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
} }
} }
} }
...@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { ...@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset]; auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[node->Name()]; auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
...@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ...@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()]; auto &vars =
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
size_t version = vars.size(); size_t version = vars.size();
auto var = auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place); new VarHandle(new_node, version, place_offset, new_node->Name(), place);
...@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ...@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
} }
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
} }
......
...@@ -39,21 +39,25 @@ namespace details { ...@@ -39,21 +39,25 @@ namespace details {
typedef std::vector< typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>> std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is // all operators. NOTE that even we use a vector here, the operators is
// unordered. // unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps; typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
class SSAGraphBuilder : public ir::Pass { class SSAGraphBuilder : public ir::Pass {
public: public:
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected: protected:
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(
#ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
#endif
); // NOLINT
if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout(
new std::ofstream(strategy_.debug_graphviz_path_));
PADDLE_ENFORCE(fout->good());
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
new GraphvizSSAGraphPrinter());
res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res)));
}
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
class Scope;
namespace details {
class SSAGraphBuilderFactory {
public:
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes,
const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
strategy_(strategy) {
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = nullptr;
#endif
}
#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
nccl_ctxs_ = nccl_ctxs;
}
#endif
std::unique_ptr<SSAGraphBuilder> Create();
private:
std::vector<platform::Place> places_;
std::string loss_var_name_;
std::unordered_set<std::string> param_names_;
std::vector<Scope*> local_scopes_;
BuildStrategy strategy_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap* nccl_ctxs_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} }
}; };
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair.get());
...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} }
} }
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) { for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get()); insert_pending_var(var.get());
} }
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
...@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
...@@ -23,26 +23,14 @@ namespace framework { ...@@ -23,26 +23,14 @@ namespace framework {
namespace details { namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder { class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public: protected:
explicit SSAGraghBuilderWithChecker( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); PADDLE_ENFORCE(IsValidGraph(graph.get()));
PADDLE_ENFORCE(IsValidGraph(new_graph.get())); return graph;
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
} }
bool IsValidGraph(const ir::Graph* graph) const; bool IsValidGraph(const ir::Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
}; };
} // namespace details } // namespace details
......
...@@ -32,7 +32,9 @@ class SSAGraphExecutor { ...@@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual ~SSAGraphExecutor(); virtual ~SSAGraphExecutor();
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0; virtual const ir::Graph& Graph() const = 0;
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -22,7 +22,7 @@ namespace details { ...@@ -22,7 +22,7 @@ namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) { static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) { for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(*pair2); callback(*pair2);
...@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) { ...@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
} }
} }
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) { for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
callback(*var); callback(*var);
} }
} }
...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>("ops")) { for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
...@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_device_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter);
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <fstream>
#include <iosfwd> #include <iosfwd>
#include <ostream>
#include <string> #include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
...@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { ...@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public: protected:
SSAGraghBuilderWithPrinter(std::ostream& sout, std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ref_(sout) {}
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); std::unique_ptr<std::ostream> fout(
printer_->Print(*new_graph, stream_ref_); new std::ofstream(Get<const std::string>("debug_graphviz_path")));
return new_graph; PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
} }
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
std::unique_ptr<std::ostream> stream_ptr_;
std::ostream& stream_ref_;
}; };
} // namespace details } // namespace details
......
...@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars // Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get()); InsertPendingVar(&pending_vars, &ready_vars, var.get());
} }
for (auto &op : graph_->Get<details::GraphOps>("ops")) { for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
...@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
......
...@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph); std::unique_ptr<ir::Graph> &&graph);
const ir::Graph &Graph() const { return *graph_; }
// Run a SSAGraph by a thread pool // Run a SSAGraph by a thread pool
// Use topological sort algorithm // Use topological sort algorithm
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
......
cc_library(node SRCS node.cc DEPS proto_desc) cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node) cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node) cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
...@@ -40,14 +40,21 @@ class Graph { ...@@ -40,14 +40,21 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType> template <typename AttrType>
AttrType &Get(const std::string &attr_name) const { AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name)); return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} }
template <typename AttrType> template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) { void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0); PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
attr_name);
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() { attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name; VLOG(3) << "deleting " << attr_name;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout;
size_t var_id = 0;
std::unordered_map<const ir::Node*, size_t> vars;
sout << "digraph G {\n";
for (const ir::Node* n : graph->Nodes()) {
if (n->NodeType() != ir::Node::Type::kVariable) continue;
size_t cur_var_id = var_id++;
vars[n] = cur_var_id;
sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
<< std::endl;
}
size_t op_id = 0;
for (const ir::Node* n : graph->Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : n->inputs) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}
for (auto out : n->outputs) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}
sout << "}\n";
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class GraphVizPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE(!applied_, "Pass can only Apply() once.");
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
"Required pass atrribute %s not set.", attr);
}
for (const std::string& attr : required_graph_attrs_) {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
attr);
}
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
applied_ = true;
return applied_graph;
}
PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
}
} // namespace ir
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,21 +14,187 @@ limitations under the License. */ ...@@ -14,21 +14,187 @@ limitations under the License. */
#pragma once #pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
template <typename PassType>
struct PassRegistrar;
class Pass { class Pass {
public: public:
Pass() = default; Pass() = default;
virtual ~Pass() {} virtual ~Pass() {
for (auto &attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first]();
}
}
attrs_.clear();
attr_dels_.clear();
}
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for pass.", attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
}
protected:
virtual std::unique_ptr<Graph> ApplyImpl(
std::unique_ptr<Graph> graph) const = 0;
private:
template <typename PassType>
friend struct PassRegistrar;
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
required_pass_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterRequiredGraphAttrs(
const std::unordered_set<std::string> &attrs) {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
mutable bool applied_{false};
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
using PassCreator = std::function<std::unique_ptr<Pass>()>;
class Registrar {
public:
// In our design, various kinds of passes,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_PASS macros to
// call this method. So, as long as the callee code calls USE_PASS, the global
// registrar variable won't be removed by the linker.
void Touch() {}
};
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; class PassRegistry {
public:
static PassRegistry &Instance();
bool Has(const std::string &pass_type) const {
return map_.find(pass_type) != map_.end();
}
void Insert(const std::string &pass_type, const PassCreator &pass_creator) {
PADDLE_ENFORCE(!Has(pass_type), "Pass %s has been registered", pass_type);
map_.insert({pass_type, pass_creator});
}
std::unique_ptr<Pass> Get(const std::string &pass_type) const {
PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered",
pass_type);
return map_.at(pass_type)();
}
private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> map_;
DISABLE_COPY_AND_ASSIGN(PassRegistry);
}; };
template <typename PassType>
struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(
pass_type, [this]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
return pass;
});
}
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
required_pass_attrs_.insert(attr);
return *this;
}
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
required_graph_attrs_.insert(attr);
return *this;
}
private:
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
void BuildCircleGraph(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
o2->outputs.push_back(v2);
o1->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o1);
}
class TestPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
graph->Set<int>("copy_test_pass_attr", new int);
graph->Set<int>("copy_test_graph_attr", new int);
int test_pass_attr = this->Get<int>("test_pass_attr");
graph->Get<int>("copy_test_pass_attr") = test_pass_attr + 1;
int test_graph_attr = graph->Get<int>("test_graph_attr");
graph->Get<int>("copy_test_graph_attr") = test_graph_attr + 1;
return graph;
}
};
TEST(PassTest, TestPassAttrCheck) {
ProgramDesc prog;
auto pass = PassRegistry::Instance().Get("test_pass");
std::unique_ptr<Graph> graph(new Graph(prog));
std::string exception;
try {
graph = pass->Apply(std::move(graph));
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("test_pass_attr not set") != exception.npos);
int val = 1;
graph.reset(new Graph(prog));
pass->SetNotOwned<int>("test_pass_attr", &val);
try {
graph = pass->Apply(std::move(graph));
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("test_graph_attr not set") != exception.npos);
graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 1;
graph = pass->Apply(std::move(graph));
ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);
try {
graph = pass->Apply(std::move(graph));
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos);
pass = PassRegistry::Instance().Get("test_pass");
pass->SetNotOwned<int>("test_pass_attr", &val);
graph.reset(new Graph(prog));
BuildCircleGraph(graph.get());
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 2;
try {
auto tmp = pass->Apply(std::move(graph));
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(test_pass, paddle::framework::ir::TestPass)
.RequirePassAttr("test_pass_attr")
.RequireGraphAttr("test_graph_attr");
...@@ -679,6 +679,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -679,6 +679,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (var == nullptr) continue; if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>()); CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
} }
} }
} }
......
...@@ -19,19 +19,80 @@ limitations under the License. */ ...@@ -19,19 +19,80 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes, const bool use_cuda,
#ifdef PADDLE_WITH_CUDA
const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
#else
const BuildStrategy &strategy) {
#endif
// Convert the program to graph.
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
// Apply a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
// Convert graph to run on multi-devices.
auto multi_device_pass =
ir::PassRegistry::Instance().Get("multi_device_pass");
multi_device_pass->SetNotOwned<const std::vector<platform::Place>>("places",
&places);
multi_device_pass->SetNotOwned<const std::string>("loss_var_name",
&loss_var_name);
multi_device_pass->SetNotOwned<const std::unordered_set<std::string>>(
"params", &param_names);
multi_device_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes);
multi_device_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
multi_device_pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
graph = multi_device_pass->Apply(std::move(graph));
// Apply a graph print pass to record a graph with device info.
if (!strategy.debug_graphviz_path_.empty()) {
auto multi_device_print_pass =
ir::PassRegistry::Instance().Get("multi_device_print_pass");
multi_device_print_pass->SetNotOwned<const std::string>(
"debug_graphviz_path", &strategy.debug_graphviz_path_);
multi_device_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
graph = multi_device_print_pass->Apply(std::move(graph));
}
// Verify that the graph is correct for multi-device executor.
auto multi_device_check_pass =
ir::PassRegistry::Instance().Get("multi_device_check_pass");
graph = multi_device_check_pass->Apply(std::move(graph));
return graph;
}
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
public: public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places) explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
...@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor( ...@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = var->Persistable(); var_infos.back().persistable_ = var->Persistable();
} }
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy);
if (member_->use_cuda_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy,
member_->nccl_ctxs_.get());
#else #else
PADDLE_THROW("Not compiled with CUDA."); std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy);
#endif #endif
}
builder_ = builder_factory.Create();
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
...@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices(
// the initializing bcast, all vars would be bcast from device(0), // the initializing bcast, all vars would be bcast from device(0),
// otherwise // otherwise
// bcast from the specified device. // bcast from the specified device.
bool initializing = builder_.get() == nullptr ? true : false; bool initializing = member_->executor_ ? false : true;
for (auto &var : vars) { for (auto &var : vars) {
int var_dev_id = int var_dev_id = -1;
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); if (member_->executor_) {
auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>(
details::kShardedVarDevice);
if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var);
}
}
if (!initializing && var_dev_id == -1) continue; if (!initializing && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr; framework::Variable *main_var = nullptr;
...@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(graph_viz_pass);
USE_PASS(multi_device_pass);
USE_PASS(multi_device_check_pass);
USE_PASS(multi_device_print_pass);
...@@ -70,7 +70,6 @@ class ParallelExecutor { ...@@ -70,7 +70,6 @@ class ParallelExecutor {
private: private:
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
}; };
} // namespace framework } // namespace framework
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
const auto &ids_rows = ids_selected_rows->rows(); const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size(); const size_t shard_num = outs.size();
for (auto &out : outs) {
out->mutable_rows()->clear();
}
// get rows for outputs // get rows for outputs
for (auto &id : ids_rows) { std::unordered_map<int64_t, size_t> id_to_index;
size_t shard_id = static_cast<size_t>(id) % shard_num; for (size_t i = 0; i < ids_rows.size(); ++i) {
outs[shard_id]->mutable_rows()->push_back(id); id_to_index[ids_rows[i]] = i;
size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
outs[shard_id]->mutable_rows()->push_back(ids_rows[i]);
} }
int64_t row_width = ids_dims[1]; int64_t row_width = ids_dims[1];
...@@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
{static_cast<int64_t>(out->rows().size()), row_width}); {static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place); T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (int64_t i = 0; i < ddim[0]; ++i) { for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width, memcpy(output + i * row_width,
ids + id_to_index[out->rows()[i]] * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
......
...@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) ...@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
IF(WITH_GPU)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
ENDIF()
...@@ -14,6 +14,10 @@ limitations under the License. */ ...@@ -14,6 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val, ...@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#endif #endif
} }
// CUDA 9.0 have native compatible float16 shfl_down
#if CUDA_VERSION < 9000
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
float16 val, int delta,
int width) {
half tmp = static_cast<half>(val);
__shfl_down(tmp, static_cast<unsigned>(delta), width);
return float16(tmp);
}
#endif
template <typename T> template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line, __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
int width = 32) { int width = 32) {
...@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line, ...@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
#endif #endif
} }
template <typename T>
HOSTDEVICE T Infinity() {
return INFINITY;
}
template <typename T> template <typename T>
__device__ T reduceSum(T val, int tid, int len) { __device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the // NOTE(zcd): The warp size should be taken from the
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include <random>
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::float16;
#define CUDA_ATOMIC_KERNEL(op, T) \
__global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \
i += blockDim.x * gridDim.x) { \
paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \
} \
}
template <typename T>
struct AddFunctor {
T operator()(const T& a, const T& b) { return a + b; }
};
template <typename T>
struct SubFunctor {
T operator()(const T& a, const T& b) { return a - b; }
};
// NOTE(dzhwinter): the float16 add has small underflow/overflow
// so we use EXPECT_NEAR to check the result.
#define ARITHMETIC_KERNEL_LAUNCH(op, T) \
void Test##T##op(size_t num) { \
T *in1, *in2, *out; \
T *d_in1, *d_in2; \
size_t size = sizeof(T) * num; \
cudaMalloc(reinterpret_cast<void**>(&d_in1), size); \
cudaMalloc(reinterpret_cast<void**>(&d_in2), size); \
in1 = reinterpret_cast<T*>(malloc(size)); \
in2 = reinterpret_cast<T*>(malloc(size)); \
out = reinterpret_cast<T*>(malloc(size)); \
std::minstd_rand engine; \
std::uniform_real_distribution<double> dist(0.0, 1.0); \
for (size_t i = 0; i < num; ++i) { \
in1[i] = static_cast<T>(dist(engine)); \
in2[i] = static_cast<T>(dist(engine)); \
} \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \
cudaDeviceSynchronize(); \
cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \
cudaDeviceSynchronize(); \
for (size_t i = 0; i < num; ++i) { \
EXPECT_NEAR(static_cast<float>(out[i]), \
static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
0.001); \
} \
free(in1); \
free(in2); \
free(out); \
cudaFree(d_in1); \
cudaFree(d_in2); \
}
CUDA_ATOMIC_KERNEL(Add, float);
CUDA_ATOMIC_KERNEL(Add, double);
CUDA_ATOMIC_KERNEL(Add, float16);
ARITHMETIC_KERNEL_LAUNCH(Add, float);
ARITHMETIC_KERNEL_LAUNCH(Add, double);
ARITHMETIC_KERNEL_LAUNCH(Add, float16);
namespace paddle {
namespace platform {
USE_CUDA_ATOMIC(Sub, int);
};
};
CUDA_ATOMIC_KERNEL(Sub, int);
ARITHMETIC_KERNEL_LAUNCH(Sub, int);
// cuda primitives
TEST(CudaAtomic, Add) {
TestfloatAdd(static_cast<size_t>(10));
TestfloatAdd(static_cast<size_t>(1024 * 1024));
TestdoubleAdd(static_cast<size_t>(10));
TestdoubleAdd(static_cast<size_t>(1024 * 1024));
}
TEST(CudaAtomic, Sub) {
TestintSub(static_cast<size_t>(10));
TestintSub(static_cast<size_t>(1024 * 1024));
}
TEST(CudaAtomic, float16) {
using paddle::platform::float16;
Testfloat16Add(static_cast<size_t>(1));
Testfloat16Add(static_cast<size_t>(2));
Testfloat16Add(static_cast<size_t>(3));
Testfloat16Add(static_cast<size_t>(10));
Testfloat16Add(static_cast<size_t>(1024 * 1024));
}
...@@ -14,12 +14,14 @@ limitations under the License. */ ...@@ -14,12 +14,14 @@ limitations under the License. */
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
#include <stdio.h>
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
#define CUDA_ATOMIC_WRAPPER(op, T) \ #define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val) __device__ __forceinline__ T CudaAtomic##op(T *address, const T val)
#define USE_CUDA_ATOMIC(op, T) \ #define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
...@@ -42,17 +44,17 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) { ...@@ -42,17 +44,17 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64"); "long long should be int64");
return CudaAtomicAdd( return CudaAtomicAdd(
reinterpret_cast<unsigned long long int*>(address), // NOLINT reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT static_cast<unsigned long long int>(val)); // NOLINT
} }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double); USE_CUDA_ATOMIC(Add, double);
#else #else
CUDA_ATOMIC_WRAPPER(Add, double) { CUDA_ATOMIC_WRAPPER(Add, double) {
unsigned long long int* address_as_ull = // NOLINT unsigned long long int *address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int*>(address); // NOLINT reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT unsigned long long int old = *address_as_ull, assumed; // NOLINT
do { do {
assumed = old; assumed = old;
...@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) { ...@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return __longlong_as_double(old); return __longlong_as_double(old);
} }
#endif
#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): cuda do not have atomicCAS for half.
// Just use the half address as a unsigned value address and
// do the atomicCAS. According to the value store at high 16 bits
// or low 16 bits, then do a different sum and CAS.
// Given most warp-threads will failed on the atomicCAS, so this
// implemented should be avoided in high concurrency. It's will be
// slower than the way convert value into 32bits and do a full atomicCAS.
// convert the value into float and do the add arithmetic.
// then store the result into a uint32.
inline __device__ uint32_t add_to_low_half(uint32_t val, float x) {
float16 low_half;
// the float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xffffu);
low_half = static_cast<float16>(static_cast<float>(low_half) + x);
return (val & 0xffff0000u) | low_half.x;
}
inline __device__ uint32_t add_to_high_half(uint32_t val, float x) {
float16 high_half;
// the float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(static_cast<float>(high_half) + x);
return (val & 0xffffu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Add, float16) {
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
uint32_t *address_as_ui =
reinterpret_cast<uint32_t *>(reinterpret_cast<char *>(address) -
(reinterpret_cast<size_t>(address) & 2));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t sum;
uint32_t newval;
uint32_t assumed;
if (((size_t)address & 2) == 0) {
// the float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xffffu;
return ret;
} else {
// the float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif #endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -67,8 +67,11 @@ struct float16; ...@@ -67,8 +67,11 @@ struct float16;
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
// NOTE():
// Do not move the eigen.h header, otherwise the eigen_vector<bool> will failed.
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> { ...@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> {
is_standard_layout<paddle::platform::float16>::value; is_standard_layout<paddle::platform::float16>::value;
}; };
template <>
struct is_floating_point<paddle::platform::float16>
: std::integral_constant<
bool, std::is_same<paddle::platform::float16,
typename std::remove_cv<
paddle::platform::float16>::type>::value> {};
template <>
struct is_signed<paddle::platform::float16> {
static const bool value = true;
};
template <>
struct is_unsigned<paddle::platform::float16> {
static const bool value = false;
};
inline bool isnan(const paddle::platform::float16& a) {
return paddle::platform::isnan(a);
}
inline bool isinf(const paddle::platform::float16& a) {
return paddle::platform::isinf(a);
}
template <> template <>
struct numeric_limits<paddle::platform::float16> { struct numeric_limits<paddle::platform::float16> {
static const bool is_specialized = true; static const bool is_specialized = true;
......
...@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) { ...@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
} }
} }
TEST(float16, floating) {
// compile time assert.
PADDLE_ASSERT(std::is_floating_point<float16>::value);
}
TEST(float16, print) { TEST(float16, print) {
float16 a = float16(1.0f); float16 a = float16(1.0f);
std::cout << a << std::endl; std::cout << a << std::endl;
} }
// CPU test
TEST(float16, isinf) {
float16 a;
a.x = 0x7c00;
float16 b = float16(INFINITY);
float16 c = static_cast<float16>(INFINITY);
EXPECT_EQ(std::isinf(a), true);
EXPECT_EQ(std::isinf(b), true);
EXPECT_EQ(std::isinf(c), true);
}
TEST(float16, isnan) {
float16 a;
a.x = 0x7fff;
float16 b = float16(NAN);
float16 c = static_cast<float16>(NAN);
EXPECT_EQ(std::isnan(a), true);
EXPECT_EQ(std::isnan(b), true);
EXPECT_EQ(std::isnan(c), true);
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -11,11 +11,13 @@ limitations under the License. */ ...@@ -11,11 +11,13 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/legacy/utils/Logging.h"
#define ARITHMETIC_KERNEL(op_type, sign) \ #define ARITHMETIC_KERNEL(op_type, sign) \
__global__ void op_type(const half* in1, const half* in2, half* out) { \ __global__ void op_type(const half* in1, const half* in2, half* out) { \
...@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) { ...@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) {
} }
} }
template <typename T>
struct Functor {
bool operator()(const T& val) {
return std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16));
}
};
TEST(float16, typeid) {
// the framework heavily used typeid hash
Functor<float16> functor;
float16 a = float16(.0f);
Functor<int> functor2;
int b(0);
// compile time assert
PADDLE_ASSERT(functor(a) == true);
PADDLE_ASSERT(functor2(b) == false);
}
// GPU test
TEST(float16, isinf) {
float16 a;
a.x = 0x7c00;
float16 b = float16(INFINITY);
// underflow to 0
float16 native_a(5e-40f);
// overflow to inf
float16 native_b(5e40f);
EXPECT_EQ(std::isinf(a), true);
EXPECT_EQ(std::isinf(b), true);
EXPECT_EQ(std::isinf(native_b), true);
EXPECT_EQ(native_a, float16(0));
}
TEST(float16, isnan) {
float16 a;
a.x = 0x7fff;
float16 b = float16(NAN);
float16 c = float16(5e40);
// inf * +-0 will get a nan
float16 d = c * float16(0);
EXPECT_EQ(std::isnan(a), true);
EXPECT_EQ(std::isnan(b), true);
EXPECT_EQ(std::isnan(d), true);
}
TEST(float16, cast) {
float16 a;
a.x = 0x0070;
auto b = a;
{
// change semantic, keep the same value
float16 c = reinterpret_cast<float16&>(reinterpret_cast<unsigned&>(b));
EXPECT_EQ(b, c);
}
{
// use uint32 low 16 bit store float16
uint32_t c = reinterpret_cast<uint32_t&>(b);
float16 d;
d.x = c;
EXPECT_EQ(b, d);
}
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
#endif // PADDLE_CUDA_FP16 #endif // PADDLE_CUDA_FP16
...@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper, unique_name ...@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper, unique_name
from ..initializer import force_init_on_cpu from ..initializer import force_init_on_cpu
from ops import logical_and, logical_not, logical_or from ops import logical_and, logical_not, logical_or
import numpy import numpy
import warnings
__all__ = [ __all__ = [
'While', 'While',
...@@ -280,6 +281,9 @@ class ParallelDo(object): ...@@ -280,6 +281,9 @@ class ParallelDo(object):
""" """
def __init__(self, places, use_nccl=False, name=None): def __init__(self, places, use_nccl=False, name=None):
warnings.warn(
"API ParallelDo is deprecated since 0.15.0. Please use ParallelExecutor instead.",
Warning)
self.helper = LayerHelper("parallel_do", name=name) self.helper = LayerHelper("parallel_do", name=name)
self.inputs = [] self.inputs = []
self.places = places self.places = places
...@@ -338,7 +342,7 @@ class ParallelDo(object): ...@@ -338,7 +342,7 @@ class ParallelDo(object):
return [parent_block.var(name) for name in params] return [parent_block.var(name) for name in params]
def complete_op(self): def _complete_op(self):
main_program = self.helper.main_program main_program = self.helper.main_program
current_block = main_program.current_block() current_block = main_program.current_block()
parent_block = self.parent_block() parent_block = self.parent_block()
...@@ -394,7 +398,7 @@ class BlockGuardWithCompletion(BlockGuard): ...@@ -394,7 +398,7 @@ class BlockGuardWithCompletion(BlockGuard):
if exc_type is not None: if exc_type is not None:
return False return False
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_op() self.rnn._complete_op()
return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val, return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val,
exc_tb) exc_tb)
...@@ -470,7 +474,7 @@ class StaticRNN(object): ...@@ -470,7 +474,7 @@ class StaticRNN(object):
if shape is None or batch_ref is None: if shape is None or batch_ref is None:
raise ValueError( raise ValueError(
"if init is None, memory at least need shape and batch_ref") "if init is None, memory at least need shape and batch_ref")
parent_block = self.parent_block() parent_block = self._parent_block()
var_name = unique_name.generate("@".join( var_name = unique_name.generate("@".join(
[self.helper.name, "memory_boot"])) [self.helper.name, "memory_boot"]))
boot_var = parent_block.create_var( boot_var = parent_block.create_var(
...@@ -527,7 +531,7 @@ class StaticRNN(object): ...@@ -527,7 +531,7 @@ class StaticRNN(object):
outputs={'Out': tmp_o}, outputs={'Out': tmp_o},
attrs={'dtype': o.dtype}) attrs={'dtype': o.dtype})
out_var = self.parent_block().create_var( out_var = self._parent_block().create_var(
name=tmp_o.name, name=tmp_o.name,
shape=[self.seq_len] + list(tmp_o.shape), shape=[self.seq_len] + list(tmp_o.shape),
dtype=tmp_o.dtype) dtype=tmp_o.dtype)
...@@ -543,7 +547,7 @@ class StaticRNN(object): ...@@ -543,7 +547,7 @@ class StaticRNN(object):
raise TypeError("update memory should take variables") raise TypeError("update memory should take variables")
self.memories[mem.name].mem = var self.memories[mem.name].mem = var
def parent_block(self): def _parent_block(self):
prog = self.helper.main_program prog = self.helper.main_program
parent_idx = prog.current_block().parent_idx parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0 assert parent_idx >= 0
...@@ -560,10 +564,10 @@ class StaticRNN(object): ...@@ -560,10 +564,10 @@ class StaticRNN(object):
else: else:
return self.outputs return self.outputs
def complete_op(self): def _complete_op(self):
main_program = self.helper.main_program main_program = self.helper.main_program
rnn_block = main_program.current_block() rnn_block = main_program.current_block()
parent_block = self.parent_block() parent_block = self._parent_block()
local_inputs = set() local_inputs = set()
...@@ -643,7 +647,7 @@ class WhileGuard(BlockGuard): ...@@ -643,7 +647,7 @@ class WhileGuard(BlockGuard):
if exc_type is not None: if exc_type is not None:
return False return False
self.while_op.status = While.AFTER_WHILE_BLOCK self.while_op.status = While.AFTER_WHILE_BLOCK
self.while_op.complete() self.while_op._complete()
return super(WhileGuard, self).__exit__(exc_type, exc_val, exc_tb) return super(WhileGuard, self).__exit__(exc_type, exc_val, exc_tb)
...@@ -690,7 +694,7 @@ class While(object): ...@@ -690,7 +694,7 @@ class While(object):
def block(self): def block(self):
return WhileGuard(self) return WhileGuard(self)
def complete(self): def _complete(self):
main_program = self.helper.main_program main_program = self.helper.main_program
while_block = main_program.current_block() while_block = main_program.current_block()
parent_block = main_program.block(main_program.current_block() parent_block = main_program.block(main_program.current_block()
......
...@@ -40,7 +40,7 @@ function(py_test_modules TARGET_NAME) ...@@ -40,7 +40,7 @@ function(py_test_modules TARGET_NAME)
${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES} ${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (py_test_modules_SERIAL) if (py_test_modules_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY SERIAL 1) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif() endif()
endif() endif()
endfunction() endfunction()
......
...@@ -19,6 +19,7 @@ import math ...@@ -19,6 +19,7 @@ import math
import unittest import unittest
import os import os
import sys
import signal import signal
import subprocess import subprocess
...@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
except os.error: except os.error:
retry_times -= 1 retry_times -= 1
def no_test_with_place(self): def test_with_place(self):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs = { required_envs = {
"PATH": os.getenv("PATH"), "PATH": os.getenv("PATH"),
...@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase):
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \ local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
(self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1) (self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
local_proc = subprocess.Popen( local_proc = subprocess.Popen(
local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local) local_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_local)
local_proc.wait() local_proc.wait()
local_ret = local_proc.stdout.read() out, err = local_proc.communicate()
local_ret = out
sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err)
# Run dist train to compare with local results # Run dist train to compare with local results
ps0, ps1 = self.start_pserver() ps0, ps1 = self.start_pserver()
...@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase):
FNULL = open(os.devnull, 'w') FNULL = open(os.devnull, 'w')
tr0_proc = subprocess.Popen( tr0_proc = subprocess.Popen(
tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0) tr0_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env0)
tr1_proc = subprocess.Popen( tr1_proc = subprocess.Popen(
tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1) tr1_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env1)
tr0_proc.wait() tr0_proc.wait()
tr1_proc.wait() tr1_proc.wait()
loss_data0 = tr0_proc.stdout.read() out, err = tr0_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err)
loss_data0 = out
sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n") lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0] dist_first_loss = eval(lines[0].replace(" ", ","))[0]
dist_last_loss = eval(lines[1].replace(" ", ","))[0] dist_last_loss = eval(lines[1].replace(" ", ","))[0]
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestSplitIdsOp(OpTest): class TestSplitIdsOp(OpTest):
...@@ -31,5 +33,55 @@ class TestSplitIdsOp(OpTest): ...@@ -31,5 +33,55 @@ class TestSplitIdsOp(OpTest):
self.check_output() self.check_output()
class TestSpliteIds(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
return places
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
def check_with_place(self, place):
scope = core.Scope()
rows = [0, 5, 7, 4, 9]
height = 20
row_numel = 2
# initialize input variable X
x = scope.var('X').get_selected_rows()
x.set_rows(rows)
x.set_height(height)
np_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
for j in range(row_numel):
np_array[i, j] = rows[i] + j
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
outs_name = ["out%d" % i for i in xrange(3)]
outs = [
scope.var(var_name).get_selected_rows() for var_name in outs_name
]
# expected output selected rows
expected_out_rows = [[0, 9], [7, 4], [5]]
op = Operator("split_ids", Ids="X", Out=outs_name)
for _ in range(3):
op.run(scope, place)
for i in range(len(outs)):
expected_rows = expected_out_rows[i]
self.assertEqual(outs[i].rows(), expected_rows)
for j in range(len(expected_rows)):
row = expected_rows[j]
self.assertAlmostEqual(
float(row), np.array(outs[i].get_tensor())[j, 0])
self.assertAlmostEqual(
float(row + 1), np.array(outs[i].get_tensor())[j, 1])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -347,6 +347,7 @@ class DistributeTranspiler(object): ...@@ -347,6 +347,7 @@ class DistributeTranspiler(object):
# step1 # step1
pserver_program = Program() pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
# step2: Create vars to receive vars at parameter servers. # step2: Create vars to receive vars at parameter servers.
recv_inputs = [] recv_inputs = []
for v in self.param_grad_ep_mapping[endpoint]["params"]: for v in self.param_grad_ep_mapping[endpoint]["params"]:
...@@ -544,6 +545,7 @@ class DistributeTranspiler(object): ...@@ -544,6 +545,7 @@ class DistributeTranspiler(object):
""" """
s_prog = Program() s_prog = Program()
orig_s_prog = default_startup_program() orig_s_prog = default_startup_program()
s_prog.random_seed = orig_s_prog.random_seed
params = self.param_grad_ep_mapping[endpoint]["params"] params = self.param_grad_ep_mapping[endpoint]["params"]
def _get_splited_name_and_shape(varname): def _get_splited_name_and_shape(varname):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册