diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 6ecc02bbae61697c9d8f1c4eea11fe7210884569..a74a57ed604a2bb497b2288c7b8431e5481a0cf3 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -451,7 +451,7 @@ class HeterBoxWorker : public HogwildWorker { virtual void CacheProgram(const ProgramDesc& main_program) { new (&program_) ProgramDesc(main_program); } - virtual void ProduceTasks() override; + void ProduceTasks() override; virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void TrainFilesWithProfiler() {} @@ -550,7 +550,7 @@ class PSGPUWorker : public HogwildWorker { virtual void CacheProgram(const ProgramDesc& main_program) { new (&program_) ProgramDesc(main_program); } - virtual void ProduceTasks() override; + void ProduceTasks() override; virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void TrainFilesWithProfiler() {} @@ -654,6 +654,8 @@ class SectionWorker : public DeviceWorker { void SetDeviceIndex(int tid) override {} void SetThreadIndex(int thread_id) { thread_id_ = thread_id; } void SetMicrobatchNum(int num) { num_microbatches_ = num; } + void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; } + void SetPipelineStage(int stage) { pipeline_stage_ = stage; } void SetMicrobatchScopes(const std::vector& scope) { microbatch_scopes_ = scope; } @@ -666,6 +668,8 @@ class SectionWorker : public DeviceWorker { int section_id_; int thread_id_; int num_microbatches_; + int num_pipeline_stages_; + int pipeline_stage_; std::vector microbatch_scopes_; std::vector skip_vars_; const Scope* minibatch_scope_;