提交 0a775cbc 编写于 作者: W willzhang4a58

fix bug: parallel_ctx init


Former-commit-id: 76cc22d3
上级 3ca374c8
...@@ -4,6 +4,9 @@ namespace oneflow { ...@@ -4,6 +4,9 @@ namespace oneflow {
void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) { void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
actor_id_ = task_proto.task_id(); actor_id_ = task_proto.task_id();
if (task_proto.has_parallel_ctx()) {
parallel_ctx_.reset(new ParallelContext(task_proto.parallel_ctx()));
}
for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) { for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) {
ExecKernel ek; ExecKernel ek;
ek.kernel = ek.kernel =
......
...@@ -36,7 +36,7 @@ class Actor { ...@@ -36,7 +36,7 @@ class Actor {
// Util // Util
Actor() = default; Actor() = default;
virtual const ParallelContext* parallel_ctx() const { return nullptr; } const ParallelContext* parallel_ctx() const { return parallel_ctx_.get(); }
DeviceType GetDeviceType() const; DeviceType GetDeviceType() const;
virtual void VirtualActorInit(const TaskProto&) {} virtual void VirtualActorInit(const TaskProto&) {}
int64_t RegstDescId4Name(const std::string& name) const; int64_t RegstDescId4Name(const std::string& name) const;
...@@ -90,6 +90,7 @@ class Actor { ...@@ -90,6 +90,7 @@ class Actor {
private: private:
int64_t actor_id_; int64_t actor_id_;
std::unique_ptr<ParallelContext> parallel_ctx_;
std::vector<ExecKernel> exec_kernel_vec_; std::vector<ExecKernel> exec_kernel_vec_;
HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_regsts_; HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_regsts_;
HashMap<std::string, int64_t> name2regst_desc_id_; HashMap<std::string, int64_t> name2regst_desc_id_;
......
...@@ -15,17 +15,10 @@ class CompActor : public Actor { ...@@ -15,17 +15,10 @@ class CompActor : public Actor {
virtual void VirtualCompActorInit(const TaskProto& task_proto) {} virtual void VirtualCompActorInit(const TaskProto& task_proto) {}
const ParallelContext* parallel_ctx() const override {
return &parallel_ctx_;
}
private: private:
void VirtualActorInit(const TaskProto& task_proto) override { void VirtualActorInit(const TaskProto& task_proto) override {
parallel_ctx_ = task_proto.parallel_ctx();
VirtualCompActorInit(task_proto); VirtualCompActorInit(task_proto);
} }
ParallelContext parallel_ctx_;
}; };
inline int64_t GetLastPieceIdForModelVersionId(int64_t model_version_id) { inline int64_t GetLastPieceIdForModelVersionId(int64_t model_version_id) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册