提交 5ce1a960 编写于 作者: Y Yancey1989

move bcast op into pass

上级 b681537e
...@@ -140,5 +140,11 @@ def parse_args(): ...@@ -140,5 +140,11 @@ def parse_args():
'--use_lars', '--use_lars',
action='store_true', action='store_true',
help='If set, use lars for optimizers, ONLY support resnet module.') help='If set, use lars for optimizers, ONLY support resnet module.')
parser.add_argument(
'--reduce_strategy',
type=str,
choices=['reduce', 'all_reduce'],
default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, ...@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
strategy = fluid.ExecutionStrategy() strategy = fluid.ExecutionStrategy()
strategy.num_threads = args.cpus strategy.num_threads = args.cpus
strategy.allow_op_delay = False strategy.allow_op_delay = False
build_strategy = fluid.BuildStrategy()
if args.reduce_strategy == "reduce":
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.Reduce
else:
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.AllReduce
avg_loss = train_args[0] avg_loss = train_args[0]
if args.update_method == "pserver": if args.update_method == "pserver":
...@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, ...@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
avg_loss.name, avg_loss.name,
main_program=train_prog, main_program=train_prog,
exec_strategy=strategy, exec_strategy=strategy,
build_strategy=build_strategy,
num_trainers=num_trainers, num_trainers=num_trainers,
trainer_id=trainer_id) trainer_id=trainer_id)
......
...@@ -67,11 +67,14 @@ def cnn_model(data): ...@@ -67,11 +67,14 @@ def cnn_model(data):
def get_model(args, is_train, main_prog, startup_prog): def get_model(args, is_train, main_prog, startup_prog):
# NOTE: mnist is small, we don't implement data sharding yet. # NOTE: mnist is small, we don't implement data sharding yet.
filelist = [ opt = None
os.path.join(args.data_path, f) for f in os.listdir(args.data_path) data_file_handle = None
]
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
if args.use_reader_op: if args.use_reader_op:
filelist = [
os.path.join(args.data_path, f)
for f in os.listdir(args.data_path)
]
data_file_handle = fluid.layers.open_files( data_file_handle = fluid.layers.open_files(
filenames=filelist, filenames=filelist,
shapes=[[-1, 1, 28, 28], (-1, 1)], shapes=[[-1, 1, 28, 28], (-1, 1)],
...@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
if is_train: if is_train:
opt = fluid.optimizer.AdamOptimizer( opt = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999) learning_rate=0.001, beta1=0.9, beta2=0.999)
opt.minimize() opt.minimize(avg_cost)
if args.memory_optimize: if args.memory_optimize:
fluid.memory_optimize(main_prog) fluid.memory_optimize(main_prog)
......
...@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, ...@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif #endif
void AllReduceOpHandle::RunImpl() { void AllReduceOpHandle::RunImpl() {
platform::RecordEvent r("all_reduce", nullptr); if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
if (NoDummyInputSize() == 1) { if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
......
...@@ -15,12 +15,19 @@ ...@@ -15,12 +15,19 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void BroadcastOpHandle::RunImpl() { void BroadcastOpHandle::RunImpl() {
if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
if (places_.size() == 1) return; if (places_.size() == 1) return;
// The input and output may have dummy vars. // The input and output may have dummy vars.
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/data_balance_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan( ...@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
} }
void DataBalanceOpHandle::RunImpl() { void DataBalanceOpHandle::RunImpl() {
if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
PADDLE_ENFORCE_GT(places_.size(), 1, PADDLE_ENFORCE_GT(places_.size(), 1,
"Data balance can only be enabled when the number of " "Data balance can only be enabled when the number of "
"places to run larger than 1."); "places to run larger than 1.");
......
...@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
size_t cur_device_id = 0; size_t cur_device_id = 0;
bool is_forwarding = true; bool is_forwarding = true;
bool is_dist_train = false;
for (ir::Node *node : sorted_ops) { for (ir::Node *node : sorted_ops) {
if (boost::get<int>( if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
CreateRPCOp(&result, node); int op_dev_id = CreateRPCOp(&result, node);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") {
auto recv_vars_attr =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
}
}
is_dist_train = true;
} else if (IsDistTrainOp(node, send_vars, recv_vars)) { } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
CreateDistTrainOp(&result, node); int op_dev_id = CreateDistTrainOp(&result, node);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
}
} else if (IsScaleLossOp(node)) { } else if (IsScaleLossOp(node)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
...@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>(kShardedVarDevice) graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(g_name, cur_device_id); .emplace(g_name, cur_device_id);
if (!is_dist_train) {
// will send gradients directly when distributed training
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
}
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) { if (IsSparseGradient(g_name)) {
...@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
} }
} }
} }
bool use_gpu = false; bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
use_gpu = nccl_ctxs_ != nullptr; use_gpu = nccl_ctxs_ != nullptr;
#endif #endif
if (use_gpu || if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
// Insert BCast Ops // Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id]; auto &to_bcast_set = bcast_var_name_set[dev_id];
...@@ -676,7 +696,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -676,7 +696,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var; return var;
} }
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
std::vector<std::string> input_var_names; std::vector<std::string> input_var_names;
...@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type()); node->Op()->Type());
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
return op_dev_id;
} }
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
...@@ -738,7 +759,7 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { ...@@ -738,7 +759,7 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
...@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id); CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
} }
} }
return op_dev_id;
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
......
...@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
......
...@@ -27,7 +27,11 @@ namespace framework { ...@@ -27,7 +27,11 @@ namespace framework {
namespace details { namespace details {
void ReduceOpHandle::RunImpl() { void ReduceOpHandle::RunImpl() {
platform::RecordEvent r("reduce", nullptr); if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
if (places_.size() == 1) return; if (places_.size() == 1) return;
// the input and output may have dummy var. // the input and output may have dummy var.
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
......
...@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() { ...@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
->stream(); ->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp, memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream); platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(1) << place_ << "RUN Scale loss grad op"; VLOG(10) << place_ << "RUN Scale loss grad op";
}); });
#endif #endif
} }
......
...@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
const std::string &, Scope *, std::vector<Scope *> &, const std::string &, Scope *, std::vector<Scope *> &,
const ExecutionStrategy &, const BuildStrategy &, size_t, const ExecutionStrategy &, const BuildStrategy &, size_t,
size_t>()) size_t>())
.def("_bcast_params", &ParallelExecutor::BCastParamsToDevices)
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element // We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope* // of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
...@@ -279,21 +279,11 @@ class ParallelExecutor(object): ...@@ -279,21 +279,11 @@ class ParallelExecutor(object):
self.executor.run(fetch_list, fetch_var_name) self.executor.run(fetch_list, fetch_var_name)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
if self.is_dist:
self._bcast_params()
if return_numpy: if return_numpy:
return executor.as_numpy(arr) return executor.as_numpy(arr)
return [arr[i] for i in range(len(arr))] return [arr[i] for i in range(len(arr))]
def _bcast_params(self):
"""
Broadcast the parameters to other devices. It is used during
distributed training.
"""
self.executor._bcast_params(set(self.persistable_vars))
@property @property
def device_count(self): def device_count(self):
return len(self._act_places) return len(self._act_places)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册