未验证 提交 49313d40 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #9548 from panyx0718/group_nccl_all_reduce

Group nccl all reduce and improve performance (~14% for 4 device resnext)
......@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
}
std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; }
std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -34,6 +37,10 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return true; };
protected:
void RunImpl() override;
};
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -53,6 +55,10 @@ class OpHandleBase {
void AddOutput(VarHandleBase *out);
// If the Op involves data transfer of multiple devices that
// will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; }
protected:
virtual void RunImpl() = 0;
};
......
......@@ -23,22 +23,36 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph)
std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
: SSAGraphExecutor(std::move(graph)),
pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
use_event_(use_event) {}
use_event_(use_event),
running_ops_(0),
allow_op_delay_(allow_op_delay) {}
void ThreadedSSAGraphExecutor::RunDelayedOps(
const std::unordered_set<OpHandleBase *> &delayed_ops) {
for (auto op : delayed_ops) {
op->Run(use_event_);
}
}
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
......@@ -106,7 +120,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
RunOp(ready_vars, op);
if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_);
continue;
}
running_ops_++;
RunOp(&ready_vars, op);
}
ready_ops.clear();
};
......@@ -118,13 +139,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
// Step 3. Execution
while (!pending_vars.empty()) {
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
// 1. Run All Ready ops
run_all_ready_ops();
// 2. Find ready variable
bool timeout;
auto cur_ready_vars = ready_vars.PopAll(1000, &timeout);
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) {
if (exception_) {
......@@ -141,13 +162,29 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
ready_ops.insert(op);
if (delayed_vars.find(ready_var) != delayed_vars.end()) {
blocked_by_delayed_ops.insert(op);
} else {
ready_ops.insert(op);
}
}
}
}
// When there are no other ops to schedule, schedule buffered delayed
// ops and unblock other ops.
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops);
delayed_ops.clear();
for (auto *op : blocked_by_delayed_ops) {
ready_ops.insert(op);
}
blocked_by_delayed_ops.clear();
}
// Keep loop until all vars are ready.
}
PADDLE_ENFORCE(ready_ops.empty());
PADDLE_ENFORCE(delayed_ops.empty());
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
++computation_count_;
auto sync_computation = [&] {
......@@ -182,12 +219,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
auto op_run = [&ready_var_q, op, this] {
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] {
try {
VLOG(10) << op->Name() << " : " << op->DebugString();
op->Run(use_event_);
ready_var_q.Extend(op->outputs_);
running_ops_--;
ready_var_q->Extend(op->outputs_);
} catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
......
......@@ -14,7 +14,12 @@
#pragma once
#include <chrono>
#include <deque>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
......@@ -70,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph);
std::unique_ptr<SSAGraph> &&graph,
bool allow_op_delay);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
......@@ -79,9 +85,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {}
private:
void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q,
void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
details::OpHandleBase *op);
void RunDelayedOps(const std::unordered_set<OpHandleBase *> &delayed_ops);
private:
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
......@@ -89,6 +97,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform::DeviceContextPool fetch_ctxs_;
const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
bool allow_op_delay_;
size_t computation_count_{0};
size_t max_async_computation{100};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include <vector>
......@@ -47,7 +48,7 @@ ParallelExecutor::ParallelExecutor(
const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope)
const std::string &loss_var_name, Scope *scope, bool allow_op_delay)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
......@@ -82,8 +83,8 @@ ParallelExecutor::ParallelExecutor(
auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
num_threads, use_event, member_->local_scopes_, places,
std::move(graph)));
num_threads, use_event, member_->local_scopes_, places, std::move(graph),
allow_op_delay));
// Step 3. Create vars in each scope;
for (auto *scope : member_->local_scopes_) {
......@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
platform::RecordBlock b(0);
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
......
......@@ -14,8 +14,9 @@ limitations under the License. */
#pragma once
#include <future>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -37,7 +38,8 @@ class ParallelExecutor {
const std::unordered_set<std::string>& params,
const ProgramDesc& startup_program,
const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope);
const std::string& loss_var_name, Scope* scope,
bool allow_op_delay);
void Run(const std::vector<std::string>& fetch_tensors,
const std::string& fetched_var_name = "fetched_var");
......
......@@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope) {
Scope *scope, bool allow_op_delay) {
new (&self) ParallelExecutor(num_threads, use_event, places,
params, startup_program, main_program,
loss_var_name, scope);
loss_var_name, scope, allow_op_delay);
})
.def("run", &ParallelExecutor::Run);
......
......@@ -21,7 +21,11 @@ __all__ = ['ParallelExecutor']
class ParallelExecutor(object):
def __init__(self, loss_name, use_cuda, num_threads=None):
def __init__(self,
loss_name,
use_cuda,
num_threads=None,
allow_op_delay=False):
places = []
if use_cuda:
for i in xrange(core.get_cuda_device_count()):
......@@ -35,7 +39,12 @@ class ParallelExecutor(object):
places.append(p)
if num_threads is None:
num_threads = min(len(places) * 2, multiprocessing.cpu_count())
if use_cuda:
# Experiments on se-resnext shows that too many threads hurt
# performance. Worth tunning for other models in the future.
num_threads = len(places)
else:
min(len(places) * 2, multiprocessing.cpu_count())
startup = framework.default_startup_program()
main = framework.default_main_program()
......@@ -52,7 +61,8 @@ class ParallelExecutor(object):
startup.desc,
main.desc,
loss_name,
scope)
scope,
allow_op_delay)
self.scope = scope
def run(self, fetch_list):
......
......@@ -29,6 +29,7 @@ function(py_test_modules TARGET_NAME)
endfunction()
# test time consuming OPs in a separate process for expliot parallism
list(REMOVE_ITEM TEST_OPS test_parallel_executor)
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dyn_rnn)
list(REMOVE_ITEM TEST_OPS test_mul_op)
......@@ -64,6 +65,7 @@ else()
endif(WITH_FAST_BUNDLE_TEST)
# tests with high overhead
py_test_modules(test_parallel_executor MODULES test_parallel_executor)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
py_test_modules(test_train_dyn_rnn MODULES test_dyn_rnn)
py_test_modules(test_mul_op MODULES test_mul_op)
......
......@@ -135,18 +135,18 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def SE_ResNeXt152(batch_size=4):
def SE_ResNeXt152Small(batch_size=2):
img = fluid.layers.fill_constant(
shape=[batch_size, 3, 224, 224], dtype='float32', value=0.0)
label = fluid.layers.fill_constant(
shape=[batch_size, 1], dtype='int64', value=0.0)
conv = conv_bn_layer(
input=img, num_filters=64, filter_size=3, stride=2, act='relu')
input=img, num_filters=16, filter_size=3, stride=2, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
input=conv, num_filters=16, filter_size=3, stride=1, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
input=conv, num_filters=16, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d(
input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
......@@ -184,7 +184,8 @@ class TestParallelExecutorBase(unittest.TestCase):
method,
memory_opt=True,
iter=10,
batch_size=None):
batch_size=None,
allow_op_delay=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -194,7 +195,10 @@ class TestParallelExecutorBase(unittest.TestCase):
if memory_opt:
fluid.memory_optimize(main)
exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
exe = fluid.ParallelExecutor(
loss_name=loss.name,
use_cuda=True,
allow_op_delay=allow_op_delay)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count()
begin = time.time()
......@@ -222,7 +226,7 @@ class TestMNIST(TestParallelExecutorBase):
def setUpClass(cls):
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
reader = paddle.batch(mnist.train(), batch_size=4)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
......@@ -236,9 +240,11 @@ class TestMNIST(TestParallelExecutorBase):
def test_simple_fc(self):
self.check_network_convergence(simple_fc_net)
self.check_network_convergence(simple_fc_net, allow_op_delay=True)
def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm)
self.check_network_convergence(fc_with_batchnorm, allow_op_delay=True)
class TestResnet(TestParallelExecutorBase):
......@@ -262,10 +268,10 @@ class TestResnet(TestParallelExecutorBase):
def test_resnet(self):
import functools
batch_size = 4
batch_size = 2
self.check_network_convergence(
functools.partial(
SE_ResNeXt152, batch_size=batch_size),
SE_ResNeXt152Small, batch_size=batch_size),
iter=20,
batch_size=batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册