Unverified commit 6382b62f, authored by Wu Yi, committed by GitHub

Collective ops (#15572)

* wip allreduce in op

* wip

* wip

* wip

* wip adding test

* wip for conflict with mp mode

* fix tests test=develop

* fix cpu build test=develop

* fix travis clang format test=develop

* fix cpu build test=develop

* update api.spec test=develop

* delete comment test=develop

* fix cpplint test=develop

* fix test=develop

* follow comment test=develop

* add file test=develop

* fix build test=develop

* update test=develop

* to be compatible with sync_bn, and fix mp mode in develop test=develop
Parent: b9fc80a1
@@ -166,7 +166,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
   result.Set(kGraphOps, new GraphOps);
 
   bool is_forwarding = true;
-  bool insert_collection_ops = NeedCollectiveOps();
 
   for (ir::Node *node : sorted_ops) {
     if (DealWithSpecialOp(&result, node)) {
@@ -185,8 +184,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
         CreateComputationalOps(&result, node, places_.size());
       }
 
-      // Insert collection ops
-      if (!is_forwarding && insert_collection_ops) {
+      // Insert collective ops if nranks > 1
+      if (!is_forwarding && Get<size_t>(kNRanks) > 1) {
         try {
           bool is_bk_op =
               static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
@@ -205,8 +204,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
             auto &p_name = backward_vars[i];
             auto &g_name = backward_vars[i + 1];
             VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
-            InsertCollectiveOp(&result, p_name, g_name);
+            if (NeedCollectiveForGrad(g_name, sorted_ops)) {
+              InsertCollectiveOp(&result, p_name, g_name);
+            }
           }
         } catch (boost::bad_get e) {
         }
@@ -271,8 +271,20 @@ bool MultiDevSSAGraphBuilderBase::UseGPU() const {
   return use_gpu;
 }
 
-bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const {
-  return Get<size_t>(kNRanks) > 1;
+bool MultiDevSSAGraphBuilderBase::NeedCollectiveForGrad(
+    const std::string &grad_name, std::vector<ir::Node *> ops) const {
+  // if we have allreduce_op for current gradient variable in the graph,
+  // then we don't need to add allreduce_op_handle for this gradient
+  // NOTE: This is for the case that all gradients should add collective ops
+  for (auto *node : ops) {
+    if (node->Op()->Type() != "allreduce") continue;
+    for (auto in_name : node->Op()->InputArgumentNames()) {
+      if (in_name == grad_name) {
+        return false;
+      }
+    }
+  }
+  return true;
 }
 
 void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
......
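The new NeedCollectiveForGrad check above is a plain linear scan over the sorted graph ops. A rough Python sketch of the same logic (illustration only, not PaddlePaddle API; `op.type` and `op.input_arg_names` stand in for `Op()->Type()` and `Op()->InputArgumentNames()`):

def need_collective_for_grad(grad_name, ops):
    # Skip inserting an allreduce op-handle when the user graph already
    # contains an `allreduce` op that consumes this gradient.
    for op in ops:
        if op.type != "allreduce":
            continue
        if grad_name in op.input_arg_names:
            return False  # gradient is already allreduced by the user graph
    return True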
@@ -14,7 +14,10 @@
 #pragma once
 
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
@@ -55,7 +58,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
   bool UseGPU() const;
 
-  bool NeedCollectiveOps() const;
+  bool NeedCollectiveForGrad(const std::string &grad_name,
+                             std::vector<ir::Node *> ops) const;
 
   bool IsScaleLossOp(ir::Node *node) const;
......
@@ -254,18 +254,29 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
         member_->places_, nccl_id, build_strategy.num_trainers_,
         build_strategy.trainer_id_));
 
-    std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs;
-    dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_));
-    // Initialize device context's nccl comm
-    // Note, more than one ParallelExecutor with same place, the nccl comm will
+    // Initialize device context's nccl comm, will be used by normal
+    // Operators like sync_batch_norm, and collective ops.
+    // NOTE: more than one ParallelExecutor with same place, the nccl comm will
     // be rewrite and there will be some problem.
+    // NOTE: NCCL group-calls and non-group-calls can not use the same
+    // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
+    // same communicators.
+    std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs;
+    if (nccl_id == nullptr) {
+      dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_));
+    }
     for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
-      auto &nccl_ctx = dev_nccl_ctxs->at(dev_id);
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
       auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
           pool.Get(member_->places_[dev_id]));
-      dev_ctx->set_nccl_comm(nccl_ctx.comm());
+      if (nccl_id != nullptr) {
+        auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[dev_id]);
+        dev_ctx->set_nccl_comm(nccl_ctx.comm());
+      } else {
+        auto &nccl_ctx = dev_nccl_ctxs->at(member_->places_[dev_id]);
+        dev_ctx->set_nccl_comm(nccl_ctx.comm());
+      }
     }
 #else
     PADDLE_THROW("Not compiled with CUDA");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future>  // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

struct MutableDataFunctor {
  MutableDataFunctor(void** data, framework::LoDTensor* tensor,
                     const platform::Place& place)
      : data_(data), tensor_(tensor), place_(place) {}

  template <typename T>
  void apply() {
    *data_ = tensor_->mutable_data<T>(place_);
  }

  void** data_;
  framework::LoDTensor* tensor_;
  platform::Place place_;
};

class AllReduceOp : public framework::OperatorBase {
  using OperatorBase::OperatorBase;

  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    PADDLE_ENFORCE(is_gpu_place(place),
                   "AllReduce op can run on gpu place only for now.");
#ifdef PADDLE_WITH_CUDA
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* ctx = pool.Get(place);
    auto in_names = Inputs("X");
    auto out_names = Outputs("Out");
    PADDLE_ENFORCE_EQ(in_names.size(), 1, "Only support one input");
    PADDLE_ENFORCE_EQ(out_names.size(), 1, "Only support one output");

    auto* in = scope.FindVar(in_names[0]);
    auto* out = scope.FindVar(out_names[0]);

    PADDLE_ENFORCE(in->IsType<framework::LoDTensor>() ||
                       out->IsType<framework::LoDTensor>(),
                   "Only support allreduce LoDTensors");

    int dtype = -1;
    auto in_tensor = in->Get<framework::LoDTensor>();
    dtype = platform::ToNCCLDataType(in_tensor.type());

    int64_t numel = in_tensor.numel();
    auto* sendbuff = in_tensor.data<void>();
    auto* out_tensor = out->GetMutable<framework::LoDTensor>();
    out_tensor->Resize(in_tensor.dims());
    void* recvbuff = nullptr;
    framework::VisitDataType(in_tensor.type(),
                             MutableDataFunctor(&recvbuff, out_tensor, place));

    auto cuda_ctx = static_cast<platform::CUDADeviceContext*>(ctx);
    auto* comm = cuda_ctx->nccl_comm();
    // FIXME(typhoonzero): should use nccl stream here.
    auto stream = cuda_ctx->stream();

    int reduce_type = Attr<int>("reduce_type");
    ncclRedOp_t red_type = ncclSum;
    switch (reduce_type) {
      case 0:
        red_type = ncclSum;
        break;
      case 1:
        red_type = ncclProd;
        break;
      case 2:
        red_type = ncclMax;
        break;
      case 3:
        red_type = ncclMin;
        break;
    }

    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
        comm, stream));
#endif
  }
};

class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor), tensor to be allreduced.");
    AddOutput("Out", "(Tensor) the result of allreduced.");
    AddAttr<int>("reduce_type", "(int) determine the reduce type.")
        .SetDefault(0);
    AddComment(R"DOC(
***AllReduce Operator***

Call NCCL AllReduce internally. Note that this op must be used when one
thread is managing one GPU device.

For speed reasons, reduce_type should be an integer:

0: sum
1: prod
2: max
3: min

If input and output are the same variable, in-place allreduce will be used.
)DOC");
  }
};

class AllReduceOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(allreduce, ops::AllReduceOp,
                  paddle::framework::EmptyGradOpMaker, ops::AllReduceOpMaker,
                  ops::AllReduceOpShapeInference);
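The operator above can be driven from Python through the `fluid.layers.collective._allreduce` wrapper added later in this change. A minimal sketch with hypothetical variable names, where the string reduce_type maps onto the op's integer attribute (sum=0, prod=1, max=2, min=3):

import paddle.fluid as fluid

# Sketch only: build a program that allreduces a tensor across the devices
# managed by the executor (one thread per GPU, as the op comment requires).
x = fluid.layers.data(name='x', shape=[32], dtype='float32')
x_max = fluid.layers.collective._allreduce(x, reduce_type="max")  # attr reduce_type=2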
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/cuda_helper.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
+#include "paddle/fluid/platform/dynload/nccl.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #endif
......
@@ -22,6 +22,7 @@
 #include <typeindex>
 #include <unordered_map>
 #include <vector>
 
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -79,7 +80,6 @@ struct NCCLContext {
       : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
 
   cudaStream_t stream() const { return ctx_->stream(); }
   ncclComm_t comm() const { return comm_; }
 
   int device_id() const {
@@ -105,9 +105,6 @@ struct NCCLContextMap {
         order_.size(), contexts_.size(),
         "NCCL Context Map does not support contain two or more same device");
 
-    if (places.size() <= 1 && num_trainers == 1) {
-      return;
-    }
     std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
     // if num_trainers == 1, should create a new nccl id for local comms.
     if (num_trainers == 1 && nccl_id == nullptr) {
@@ -127,8 +124,8 @@ struct NCCLContextMap {
       } else {
         rank = trainer_id;
       }
-      VLOG(30) << "init nccl rank: " << rank << " nranks: " << nranks
-               << "gpu id: " << gpu_id;
+      VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks
+              << " gpu id: " << gpu_id;
       PADDLE_ENFORCE(cudaSetDevice(gpu_id));
       PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
           comms.get() + i, nranks, *nccl_id, rank));
......
@@ -33,6 +33,7 @@ from .detection import *
 from . import metric_op
 from .metric_op import *
 from .learning_rate_scheduler import *
+from .collective import *
 
 __all__ = []
 __all__ += nn.__all__
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from ..layer_helper import LayerHelper, unique_name


def _allreduce(x, out=None, reduce_type="sum"):
    helper = LayerHelper("allreduce", **locals())
    # Convert string reduce type to op int type
    red_typ_int = 0
    if reduce_type == "sum":
        red_typ_int = 0
    elif reduce_type == "prod":
        red_typ_int = 1
    elif reduce_type == "max":
        red_typ_int = 2
    elif reduce_type == "min":
        red_typ_int = 3
    else:
        raise TypeError("reduce type can only be [sum|prod|max|min]")

    if out is None:
        out = helper.create_variable(
            name=unique_name.generate(".".join([x.name, 'tmp'])),
            shape=x.shape,
            dtype=x.dtype,
            type=x.type,
            persistable=x.persistable,
            stop_gradient=True)
    helper.append_op(
        type='allreduce',
        inputs={'X': [x]},
        outputs={'Out': [out]},
        attrs={"reduce_type": red_typ_int})
    return out
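As the dist_allreduce_op.py test further down uses it, this helper slots into a manual data-parallel setup. A minimal sketch, assuming `opt` and `avg_cost` come from an existing model (see the test below for the full version):

# Sketch only: allreduce each gradient across ranks before applying it.
# Passing out=None creates a fresh output variable; passing out=g would
# request the in-place path mentioned in the op's comment.
params_grads = opt.backward(avg_cost)
data_parallel_param_grads = []
for p, g in params_grads:
    reduced_g = fluid.layers.collective._allreduce(g, reduce_type="sum")
    data_parallel_param_grads.append([p, reduced_g])
opt.apply_gradients(data_parallel_param_grads)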
@@ -28,21 +28,9 @@ import six
 from functools import reduce
 
 __all__ = [
-    'While',
-    'Switch',
-    'increment',
-    'array_write',
-    'create_array',
-    'less_than',
-    'equal',
-    'array_read',
-    'array_length',
-    'IfElse',
-    'DynamicRNN',
-    'StaticRNN',
-    'reorder_lod_tensor_by_rank',
-    'Print',
-    'is_empty',
+    'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than',
+    'equal', 'array_read', 'array_length', 'IfElse', 'DynamicRNN', 'StaticRNN',
+    'reorder_lod_tensor_by_rank', 'Print', 'is_empty'
 ]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import time
import math

import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main

DTYPE = "float32"
paddle.dataset.mnist.fetch()

# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1


def cnn_model(data):
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=data,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        act="relu",
        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
            value=0.01)))
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        act="relu",
        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
            value=0.01)))

    SIZE = 10
    input_shape = conv_pool_2.shape
    param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
    scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5

    predict = fluid.layers.fc(
        input=conv_pool_2,
        size=SIZE,
        act="softmax",
        param_attr=fluid.param_attr.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01)))
    return predict


class TestDistMnist2x2(TestDistRunnerBase):
    def get_model(self, batch_size=2, single_device=False):
        # Input data
        images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')

        # Train program
        predict = cnn_model(images)
        cost = fluid.layers.cross_entropy(input=predict, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        # Evaluator
        batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
        batch_acc = fluid.layers.accuracy(
            input=predict, label=label, total=batch_size_tensor)

        inference_program = fluid.default_main_program().clone()

        # Reader
        train_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)

        # Optimization
        # TODO(typhoonzero): fix distributed adam optimizer
        # opt = fluid.optimizer.AdamOptimizer(
        #     learning_rate=0.001, beta1=0.9, beta2=0.999)
        opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9)

        if single_device:
            opt.minimize(avg_cost)
        else:
            # multi device or distributed multi device
            params_grads = opt.backward(avg_cost)
            data_parallel_param_grads = []
            for p, g in params_grads:
                # NOTE: scale will be done on loss scale in multi_devices_graph_pass using nranks.
                grad_reduce = fluid.layers.collective._allreduce(g)
                data_parallel_param_grads.append([p, grad_reduce])
            opt.apply_gradients(data_parallel_param_grads)

        return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict


if __name__ == "__main__":
    runtime_main(TestDistMnist2x2)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
from test_dist_base import TestDistBase


class TestDistMnistNCCL2(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_reduce = False
        self._use_reader_alloc = False
        self._nccl2_mode = True
        self._nccl2_reduce_layer = True

    def test_dist_train(self):
        import paddle.fluid as fluid
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place("dist_allreduce_op.py", delta=1e-5)


if __name__ == '__main__':
    unittest.main()
@@ -33,7 +33,10 @@ DEFAULT_BATCH_SIZE = 2
 
 class TestDistRunnerBase(object):
-    def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
+    def get_model(self,
+                  batch_size=DEFAULT_BATCH_SIZE,
+                  lr=0.1,
+                  single_device=False):
         raise NotImplementedError(
             "get_model should be implemented by child classes.")
@@ -76,8 +79,12 @@ class TestDistRunnerBase(object):
     def run_trainer(self, args):
         self.lr = args.lr
-        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
-            self.get_model(batch_size=args.batch_size)
+        if args.nccl2_reduce_layer_local_run:
+            test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
+                self.get_model(batch_size=args.batch_size, single_device=True)
+        else:
+            test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
+                self.get_model(batch_size=args.batch_size)
 
         if args.mem_opt:
             fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
@@ -87,7 +94,7 @@ class TestDistRunnerBase(object):
                 args.endpoints, args.trainers,
                 args.sync_mode, args.dc_asgd)
             trainer_prog = t.get_trainer_program()
-        elif args.update_method == "nccl2":
+        elif args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":
             # transpile for nccl2
             config = fluid.DistributeTranspilerConfig()
             config.mode = "nccl2"
@@ -110,9 +117,9 @@ class TestDistRunnerBase(object):
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
 
-        strategy = fluid.ExecutionStrategy()
-        strategy.num_threads = 1
-        strategy.allow_op_delay = False
+        exec_strategy = fluid.ExecutionStrategy()
+        exec_strategy.num_threads = 1
+        exec_strategy.allow_op_delay = False
 
         build_stra = fluid.BuildStrategy()
         # FIXME force disable enable_inplace and memory_optimize
@@ -124,23 +131,25 @@ class TestDistRunnerBase(object):
         else:
             build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
 
+        pass_builder = None
         if args.batch_merge_repeat > 1:
             pass_builder = build_stra._finalize_strategy_and_create_passes()
             mypass = pass_builder.insert_pass(
                 len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
             mypass.set("num_repeats", args.batch_merge_repeat)
 
-        if args.update_method == "nccl2":
+        if args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":
             build_stra.num_trainers = len(args.endpoints.split(","))
             build_stra.trainer_id = args.trainer_id
         else:
+            # case args.update_method == "nccl2_reduce_layer":
             build_stra.num_trainers = 1
             build_stra.trainer_id = 0
 
         binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
             loss_name=avg_cost.name,
             build_strategy=build_stra,
-            exec_strategy=strategy)
+            exec_strategy=exec_strategy)
 
         feed_var_list = [
             var for var in trainer_prog.global_block().vars.values()
@@ -182,7 +191,7 @@ def runtime_main(test_class):
         '--update_method',
         type=str,
         default="local",
-        choices=["pserver", "nccl2", "local"])
+        choices=["pserver", "nccl2", "local", "nccl2_reduce_layer"])
     parser.add_argument('--trainer_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
     parser.add_argument(
@@ -198,6 +207,11 @@ def runtime_main(test_class):
     parser.add_argument('--lr', required=False, type=float, default=0.001)
     parser.add_argument(
         '--batch_merge_repeat', required=False, type=int, default=1)
+    parser.add_argument(
+        '--nccl2_reduce_layer_local_run',
+        required=False,
+        type=bool,
+        default=False)
 
     args = parser.parse_args()
@@ -242,6 +256,11 @@ class TestDistBase(unittest.TestCase):
         self._dc_asgd = False  # must use with async mode
         self._use_reader_alloc = True
         self._nccl2_mode = False
+        # FIXME(typhoonzero): I added this stupid argument to enable
+        # testing allreduce layers, which users can call layers.allreduce
+        # to accumulate tensors at anywhere. Find a better way to do this
+        # test, reduce check this argument everywhere.
+        self._nccl2_reduce_layer = False
         self._lr = 0.001
         self._setup_config()
         self._after_setup_config()
@@ -307,10 +326,16 @@ class TestDistBase(unittest.TestCase):
            cmd += " --batch_size %d" % batch_size
         if batch_merge_repeat > 1:
             cmd += " --batch_merge_repeat %d" % batch_merge_repeat
+        if self._nccl2_reduce_layer:
+            cmd += " --nccl2_reduce_layer_local_run 1"
 
         if self.__use_cuda:
             cmd += " --use_cuda"
-            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
+            env_local = {
+                "CUDA_VISIBLE_DEVICES": "0",
+                "PADDLE_TRAINERS_NUM": "1",
+                "PADDLE_TRAINER_ID": "0"
+            }
         else:
             env_local = {'CPU_NUM': '1'}
@@ -427,29 +452,30 @@ class TestDistBase(unittest.TestCase):
             sys.stderr.write("ps1 stderr: %s\n" % fn.read())
 
         # print log
+        if stat0 == 0:
+            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
         with open("/tmp/tr0_err.log", "r") as fn:
             sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
+        if stat1 == 0:
+            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
         with open("/tmp/tr1_err.log", "r") as fn:
             sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())
 
         return pickle.loads(tr0_out), pickle.loads(tr1_out)
 
-    def _run_cluster_nccl2(self, model, envs, check_error_log):
+    def _run_cluster_nccl2(self, model, envs, nccl2_reduce_layer,
+                           check_error_log):
         # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
         worker_endpoints = self._ps_endpoints.split(",")
         w0_ep, w1_ep = worker_endpoints
+        if nccl2_reduce_layer:
+            update_method = "nccl2_reduce_layer"
+        else:
+            update_method = "nccl2"
 
-        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
+        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f"
         tr0_cmd = tr_cmd % \
                   (self._python_interp, model, self._ps_endpoints,
-                   0, w0_ep, self._lr)
+                   0, w0_ep, update_method, self._lr)
         tr1_cmd = tr_cmd % \
                   (self._python_interp, model, self._ps_endpoints,
-                   1, w1_ep, self._lr)
+                   1, w1_ep, update_method, self._lr)
 
         if self._mem_opt:
             tr0_cmd += " --mem_opt"
@@ -463,8 +489,17 @@ class TestDistBase(unittest.TestCase):
         if self.__use_cuda:
             tr0_cmd += " --use_cuda"
             tr1_cmd += " --use_cuda"
-            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
-            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
+            env0 = {
+                "CUDA_VISIBLE_DEVICES": "0",
+                # for test nccl2 layer
+                "PADDLE_TRAINERS_NUM": "2",
+                "PADDLE_TRAINER_ID": "0"
+            }
+            env1 = {
+                "CUDA_VISIBLE_DEVICES": "1",
+                "PADDLE_TRAINERS_NUM": "2",
+                "PADDLE_TRAINER_ID": "1"
+            }
         else:
             env0 = {'CPU_NUM': '1'}
             env1 = {'CPU_NUM': '1'}
@@ -498,8 +533,6 @@ class TestDistBase(unittest.TestCase):
         # print log
         sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
         sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
-        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
-        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)
 
         return pickle.loads(tr0_out), pickle.loads(tr1_out)
@@ -528,10 +561,14 @@ class TestDistBase(unittest.TestCase):
         local_losses\
             = self._run_local(model_file, required_envs,
                               check_error_log)
         if self._nccl2_mode:
-            tr0_losses, tr1_losses = self._run_cluster_nccl2(
-                model_file, required_envs, check_error_log)
+            if self._nccl2_reduce_layer:
+                tr0_losses, tr1_losses = self._run_cluster_nccl2(
+                    model_file, required_envs, True, check_error_log)
+            else:
+                tr0_losses, tr1_losses = self._run_cluster_nccl2(
+                    model_file, required_envs, False, check_error_log)
         else:
             tr0_losses, tr1_losses = self._run_cluster(
                 model_file, required_envs, check_error_log)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......