未验证 提交 fe8495a7 编写于 作者: C chengduo 提交者: GitHub

[WIP] Refine MultiDevSSAGraph (#15040)

* refine parallel_exe
test=develop

* rename shared_var_device

* code refine

* add test_weight_decay

* remove Sort
test=develop

* Add SortForReduce
test=develop

* code refine
test=develop

* follow comment
test=develop
上级 85471533
......@@ -45,7 +45,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
#endif
int GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const;
......@@ -57,12 +57,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars(
const std::vector<ir::Node *> &nodes) const;
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const;
......@@ -77,7 +71,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
int dev_id) const;
int GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......@@ -100,6 +94,15 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
std::vector<ir::Node *> SortForReduceMode(
const std::vector<ir::Node *> &) const;
int GetOpDeviceID(
ir::Node *node,
const std::unordered_map<std::string, int> &shared_var_device,
std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops)
const;
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
......
......@@ -23,66 +23,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
namespace {
void CheckProgram(const ProgramDesc &program) {
#define _INT(role) static_cast<int>(role)
std::map<int, bool> visit;
for (OpDesc *op : program.Block(0).AllOps()) {
// For backward compatibility, some program doesn't have role added.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id =
boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true;
switch (role_id) {
case _INT(OpRole::kForward):
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
LOG(ERROR) << "Cannot add backward operator before forward operator "
<< op->Type();
}
break;
case _INT(OpRole::kBackward):
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator %s after optimize operator.",
op->Type());
break;
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
_INT(OpRole::kLoss)) == visit.end(),
"Cannot add backward|loss operator before "
"forward|loss operator %s.",
op->Type());
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add forward|loss operator %s after optimize operator.",
op->Type());
break;
case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators %s must follow backward operator.",
op->Type());
break;
case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist):
case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified):
break;
default:
LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing.";
}
}
#undef _INT
}
} // namespace
Graph::Graph(const ProgramDesc &program) : program_(program) {
CheckProgram(program_);
auto var_nodes = InitFromProgram(program_);
ResolveHazard(var_nodes);
}
......
......@@ -320,6 +320,7 @@ void ParallelExecutor::BCastParamsToDevices(
if (paddle::platform::is_gpu_place(main_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::vector<void *> buffers;
buffers.reserve(member_->places_.size());
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
for (size_t i = 0; i < member_->places_.size(); ++i) {
......@@ -353,9 +354,7 @@ void ParallelExecutor::BCastParamsToDevices(
#endif
} else {
platform::CPUPlace cpu;
for (size_t i = 0; i < member_->places_.size(); ++i) {
if (i == 0) continue;
for (size_t i = 1; i < member_->places_.size(); ++i) {
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
......
......@@ -148,7 +148,7 @@ class ParallelExecutor(object):
trainers_endpoints), "num_trainers == len(end_points)"
build_strategy.trainers_endpoints = trainers_endpoints
# step5: get persistable_vars, parameter_vars, places. persistable_vars
# step6: get persistable_vars, places. persistable_vars
# need be broadcast to other local_scope.
persistable_vars = set([
cpt.to_text(v.name) for v in [
......@@ -164,7 +164,7 @@ class ParallelExecutor(object):
places = list(map(place_obj, self._places))
# step6: init ParallelExecutor
# step7: init ParallelExecutor
self.executor = core.ParallelExecutor(
places, persistable_vars, main.desc,
cpt.to_text(loss_name)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
from functools import partial
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
def get_places():
places = []
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
@contextlib.contextmanager
def prog_scope_guard(main_prog, startup_prog):
scope = fluid.core.Scope()
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
with fluid.program_guard(main_prog, startup_prog):
yield
def bow_net(data,
label,
dict_dim,
is_sparse=False,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
emb = fluid.layers.embedding(
input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim])
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class TestWeightDecay(unittest.TestCase):
def setUp(self):
self.word_dict = paddle.dataset.imdb.word_dict()
reader = paddle.batch(
paddle.dataset.imdb.train(self.word_dict), batch_size=4)()
self.train_data = [next(reader) for _ in range(5)]
self.learning_rate = .5
def run_executor(self, place, feed_list, loss):
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
exe.run(fluid.default_startup_program())
main_prog = fluid.default_main_program()
loss_set = []
for data in self.train_data:
out = exe.run(main_prog,
feed=feeder.feed(data),
fetch_list=[loss.name])
print("loss %s" % (np.average(out)))
loss_set.append(np.average(out))
return loss_set
def run_parallel_exe(self,
place,
feed_list,
loss,
use_cuda=True,
use_reduce=False,
use_fast_executor=False,
use_ir_memory_optimize=False):
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
exe.run(fluid.default_startup_program())
exec_strategy = fluid.ExecutionStrategy()
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.memory_optimize = use_ir_memory_optimize
parallel_exe = fluid.ParallelExecutor(
use_cuda,
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
loss_set = []
for data in self.train_data:
out = parallel_exe.run(feed=feeder.feed(data),
fetch_list=[loss.name])
print("loss %s" % (np.average(out)))
loss_set.append(np.average(out))
return loss_set
def check_weight_decay(self,
place,
model,
use_parallel_exe=False,
use_reduce=False):
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
startup_prog.random_seed = 1
with prog_scope_guard(main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost = model(data, label, len(self.word_dict))
param_list = [(var, var * self.learning_rate)
for var in main_prog.block(0).all_parameters()]
optimizer = fluid.optimizer.Adagrad(
learning_rate=self.learning_rate)
optimizer.minimize(avg_cost)
for params in param_list:
updated_p = fluid.layers.elementwise_sub(
x=params[0], y=params[1])
fluid.layers.assign(input=updated_p, output=params[0])
if use_parallel_exe:
loss = self.run_parallel_exe(
place, [data, label],
loss=avg_cost,
use_cuda=True,
use_reduce=use_reduce)
else:
loss = self.run_executor(place, [data, label], loss=avg_cost)
return loss
def test_weight_decay(self):
model = partial(bow_net, is_sparse=False)
for place in get_places():
loss = self.check_weight_decay(place, model, use_parallel_exe=False)
loss2 = self.check_weight_decay(
place, model, use_parallel_exe=True, use_reduce=False)
for i in range(len(loss)):
assert np.isclose(a=loss[i], b=loss2[i], rtol=5e-5)
loss3 = self.check_weight_decay(
place, model, use_parallel_exe=True, use_reduce=True)
for i in range(len(loss)):
assert np.isclose(a=loss[i], b=loss3[i], rtol=5e-5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册