提交 5ce1a960 编写于 作者: Y Yancey1989

move bcast op into pass

上级 b681537e
......@@ -140,5 +140,11 @@ def parse_args():
'--use_lars',
action='store_true',
help='If set, use lars for optimizers, ONLY support resnet module.')
parser.add_argument(
'--reduce_strategy',
type=str,
choices=['reduce', 'all_reduce'],
default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce')
args = parser.parse_args()
return args
......@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
strategy = fluid.ExecutionStrategy()
strategy.num_threads = args.cpus
strategy.allow_op_delay = False
build_strategy = fluid.BuildStrategy()
if args.reduce_strategy == "reduce":
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.Reduce
else:
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.AllReduce
avg_loss = train_args[0]
if args.update_method == "pserver":
......@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
avg_loss.name,
main_program=train_prog,
exec_strategy=strategy,
build_strategy=build_strategy,
num_trainers=num_trainers,
trainer_id=trainer_id)
......
......@@ -67,11 +67,14 @@ def cnn_model(data):
def get_model(args, is_train, main_prog, startup_prog):
# NOTE: mnist is small, we don't implement data sharding yet.
filelist = [
os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
]
opt = None
data_file_handle = None
with fluid.program_guard(main_prog, startup_prog):
if args.use_reader_op:
filelist = [
os.path.join(args.data_path, f)
for f in os.listdir(args.data_path)
]
data_file_handle = fluid.layers.open_files(
filenames=filelist,
shapes=[[-1, 1, 28, 28], (-1, 1)],
......@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
if is_train:
opt = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
opt.minimize()
opt.minimize(avg_cost)
if args.memory_optimize:
fluid.memory_optimize(main_prog)
......
......@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent r("all_reduce", nullptr);
if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1;
} else {
......
......@@ -15,12 +15,19 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
void BroadcastOpHandle::RunImpl() {
if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
if (places_.size() == 1) return;
// The input and output may have dummy vars.
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
}
void DataBalanceOpHandle::RunImpl() {
if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
PADDLE_ENFORCE_GT(places_.size(), 1,
"Data balance can only be enabled when the number of "
"places to run larger than 1.");
......
......@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
size_t cur_device_id = 0;
bool is_forwarding = true;
bool is_dist_train = false;
for (ir::Node *node : sorted_ops) {
if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
CreateRPCOp(&result, node);
int op_dev_id = CreateRPCOp(&result, node);
PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") {
auto recv_vars_attr =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
}
}
is_dist_train = true;
} else if (IsDistTrainOp(node, send_vars, recv_vars)) {
CreateDistTrainOp(&result, node);
int op_dev_id = CreateDistTrainOp(&result, node);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
}
} else if (IsScaleLossOp(node)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
......@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>(kShardedVarDevice)
.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
if (!is_dist_train) {
// will send gradients directly when distributed training
bcast_var_name_set[cur_device_id].emplace(p_name);
}
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(g_name)) {
......@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
}
bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA
use_gpu = nccl_ctxs_ != nullptr;
#endif
if (use_gpu ||
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
......@@ -676,8 +696,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var;
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const {
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = -1;
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
......@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type());
CreateComputationalOp(result, node, op_dev_id);
return op_dev_id;
}
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
......@@ -738,8 +759,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
// Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const {
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = -1;
if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
......@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
}
}
return op_dev_id;
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
......
......@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
/**
* Is this operator as the end-point operator before/after send operator.
......
......@@ -27,7 +27,11 @@ namespace framework {
namespace details {
void ReduceOpHandle::RunImpl() {
platform::RecordEvent r("reduce", nullptr);
if (dev_ctxes_.size() > 0UL) {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
} else {
platform::RecordEvent record_event(Name(), nullptr);
}
if (places_.size() == 1) return;
// the input and output may have dummy var.
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
......
......@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(1) << place_ << "RUN Scale loss grad op";
VLOG(10) << place_ << "RUN Scale loss grad op";
});
#endif
}
......
......@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
const std::string &, Scope *, std::vector<Scope *> &,
const ExecutionStrategy &, const BuildStrategy &, size_t,
size_t>())
.def("_bcast_params", &ParallelExecutor::BCastParamsToDevices)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
......@@ -279,21 +279,11 @@ class ParallelExecutor(object):
self.executor.run(fetch_list, fetch_var_name)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
if self.is_dist:
self._bcast_params()
if return_numpy:
return executor.as_numpy(arr)
return [arr[i] for i in range(len(arr))]
def _bcast_params(self):
"""
Broadcast the parameters to other devices. It is used during
distributed training.
"""
self.executor._bcast_params(set(self.persistable_vars))
@property
def device_count(self):
return len(self._act_places)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册