From 5646f7100ef153f5072b2014d947ebf1dd7a50bb Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Tue, 2 Mar 2021 19:46:05 +0800 Subject: [PATCH] add api for scheduler --- paddle/fluid/framework/device_worker.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 6ecc02bbae6..a74a57ed604 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -451,7 +451,7 @@ class HeterBoxWorker : public HogwildWorker { virtual void CacheProgram(const ProgramDesc& main_program) { new (&program_) ProgramDesc(main_program); } - virtual void ProduceTasks() override; + void ProduceTasks() override; virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void TrainFilesWithProfiler() {} @@ -550,7 +550,7 @@ class PSGPUWorker : public HogwildWorker { virtual void CacheProgram(const ProgramDesc& main_program) { new (&program_) ProgramDesc(main_program); } - virtual void ProduceTasks() override; + void ProduceTasks() override; virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void TrainFilesWithProfiler() {} @@ -654,6 +654,8 @@ class SectionWorker : public DeviceWorker { void SetDeviceIndex(int tid) override {} void SetThreadIndex(int thread_id) { thread_id_ = thread_id; } void SetMicrobatchNum(int num) { num_microbatches_ = num; } + void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; } + void SetPipelineStage(int stage) { pipeline_stage_ = stage; } void SetMicrobatchScopes(const std::vector& scope) { microbatch_scopes_ = scope; } @@ -666,6 +668,8 @@ class SectionWorker : public DeviceWorker { int section_id_; int thread_id_; int num_microbatches_; + int num_pipeline_stages_; + int pipeline_stage_; std::vector microbatch_scopes_; std::vector skip_vars_; const Scope* minibatch_scope_; -- GitLab