提交 46f3a39e 编写于 作者: X Xin Pan

polish and add comments.

上级 d0ac9253
...@@ -37,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -37,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
std::string Name() const override; std::string Name() const override;
// Delaying and buffering nccl_all_reduce calls together can significantly
// increase performance. Disable this feature by returning false.
bool IsDelayedOp() override { return true; }; bool IsDelayedOp() override { return true; };
protected: protected:
......
...@@ -45,7 +45,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -45,7 +45,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars; BlockingQueue<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops; std::unordered_set<OpHandleBase *> ready_ops;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// them together, since we currently cannot overlap computation and
// memcpy streams. Should revisit this if overlapping becomes available.
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> after_delayed_ops; std::unordered_set<OpHandleBase *> after_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars; std::unordered_set<VarHandleBase *> delayed_vars;
......
...@@ -16,7 +16,6 @@ import core ...@@ -16,7 +16,6 @@ import core
import multiprocessing import multiprocessing
import framework import framework
import executor import executor
import sys
__all__ = ['ParallelExecutor'] __all__ = ['ParallelExecutor']
...@@ -36,7 +35,12 @@ class ParallelExecutor(object): ...@@ -36,7 +35,12 @@ class ParallelExecutor(object):
places.append(p) places.append(p)
if num_threads is None: if num_threads is None:
num_threads = len(places) if use_cuda:
# Experiments on se-resnext show that too many threads hurt
# performance. Worth tuning for other models in the future.
num_threads = len(places)
else:
min(len(places) * 2, multiprocessing.cpu_count())
startup = framework.default_startup_program() startup = framework.default_startup_program()
main = framework.default_main_program() main = framework.default_main_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册