Discussion: multi-instance, multi-stream design for inference
Created by: Superjomn
During inference, we need to be able to control which stream the ops inside a single executor run on.
Currently, the stream is a process-level singleton, so all thread instances within one process run on the same stream. A preliminary design is outlined below:
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/feed_op.cc#L33
Here, the idea is to add a default parameter to Instance().Get(Place), turning it into Instance().Get(Place, OpBase*). If the OpBase* argument is given, fetch and return opbase->stream(); otherwise, return a thread_local stream created from the default stream. For now, this change can be guarded with the ON_INFER macro to limit its impact on training.
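The dispatch in the extended Get could look roughly as follows. This is a minimal sketch with stand-in types: stream_t, Place, OpBase, and the pool internals here are simplified placeholders, not the actual Paddle definitions, and explicit_stream_wins is a hypothetical helper for demonstration only.

```cpp
// Stand-in types; the real definitions live under paddle/fluid/platform.
struct stream_t { int id = 0; };
struct Place {};

struct OpBase {
  stream_t* local_stream_ = nullptr;
  stream_t& stream();  // defined below
};

// One lazily created stream per thread, derived from the default stream.
static stream_t& thread_local_stream() {
  static thread_local stream_t stream{/*id=*/1};
  return stream;
}

stream_t& OpBase::stream() {
  // Prefer the explicitly assigned stream; fall back to the per-thread one.
  return local_stream_ ? *local_stream_ : thread_local_stream();
}

struct DeviceContextPool {
  static DeviceContextPool& Instance() {
    static DeviceContextPool pool;
    return pool;
  }
  // New overload: when an OpBase* is supplied, use its stream;
  // otherwise keep the old behavior via the per-thread stream.
  stream_t& Get(const Place& /*place*/, OpBase* op = nullptr) {
    if (op) return op->stream();
    return thread_local_stream();
  }
};

// Hypothetical check: an op with its own stream bypasses the per-thread one.
inline bool explicit_stream_wins() {
  static stream_t s{/*id=*/2};
  OpBase op;
  op.local_stream_ = &s;
  return &DeviceContextPool::Instance().Get(Place{}, &op) == &s;
}
```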
Example change:
static stream_t& thread_local_stream() {
  // One stream per thread, lazily created from the default stream.
  static thread_local stream_t stream;
  // call-once initialization of the stream ...
  return stream;
}
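The per-thread behavior of this pattern can be verified in plain C++, independent of any device API. The sketch below uses a stand-in stream_t and a hypothetical streams_are_per_thread helper to show that each thread observes its own lazily constructed stream object.

```cpp
#include <thread>

// Stand-in for a device stream; no CUDA involved in this demo.
struct stream_t { int id = 0; };

// Each thread gets its own lazily constructed stream object.
static stream_t& thread_local_stream() {
  static thread_local stream_t stream;
  return stream;
}

// Returns true iff a second thread sees a different stream object than the
// calling thread, i.e. the singleton is per-thread, not per-process.
inline bool streams_are_per_thread() {
  stream_t* main_stream = &thread_local_stream();
  stream_t* other_stream = nullptr;
  std::thread t([&] { other_stream = &thread_local_stream(); });
  t.join();
  return other_stream != nullptr && other_stream != main_stream;
}
```

Within one thread, repeated calls return the same object, so ops scheduled on the same thread still share a stream by default.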
OpBase:
class OperatorBase {
 public:
  stream_t& stream() {
    // Fall back to the per-thread stream when none was set explicitly.
    if (!local_stream_) return thread_local_stream();
    return *local_stream_;
  }
  void SetStream(const std::shared_ptr<stream_t>& stream);  // ...
 private:
  std::shared_ptr<stream_t> local_stream_;
};
Usage:
class FeedOp : public OperatorBase {
  void RunImpl() {
    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place, this);
    // ...
  }
};
By adding a SetStream method to OperatorBase, the Executor can control, via some extra scheduling logic, which stream each operator runs on.
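One such Executor-side policy could assign streams round-robin across operators, so independent ops can overlap on the device. This is a hypothetical sketch: AssignStreamsRoundRobin and the flat vector-of-ops shape are illustrative, not existing Paddle APIs.

```cpp
#include <memory>
#include <vector>

// Stand-in for a device stream.
struct stream_t { int id = 0; };

class OperatorBase {
 public:
  void SetStream(const std::shared_ptr<stream_t>& stream) { local_stream_ = stream; }
  const stream_t* stream() const { return local_stream_.get(); }
 private:
  std::shared_ptr<stream_t> local_stream_;
};

// Hypothetical Executor policy: spread ops over a fixed pool of streams
// round-robin; ops sharing a pool slot share a stream.
inline void AssignStreamsRoundRobin(std::vector<OperatorBase>& ops, int num_streams) {
  std::vector<std::shared_ptr<stream_t>> pool;
  for (int i = 0; i < num_streams; ++i)
    pool.push_back(std::make_shared<stream_t>(stream_t{i}));
  for (std::size_t i = 0; i < ops.size(); ++i)
    ops[i].SetStream(pool[i % num_streams]);
}

// Demo: four ops over two streams land on streams 0, 1, 0, 1.
inline bool round_robin_demo() {
  std::vector<OperatorBase> ops(4);
  AssignStreamsRoundRobin(ops, 2);
  return ops[0].stream()->id == 0 && ops[1].stream()->id == 1 &&
         ops[2].stream()->id == 0 && ops[3].stream()->id == 1;
}
```

More elaborate policies (e.g. one stream per executor instance, or per dependency chain) would plug in at the same SetStream call site.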