/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include // NOLINT #include #include // NOLINT #include #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/trainer_desc.pb.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/timer.h" namespace paddle { namespace framework { class PullDenseWorker { public: virtual ~PullDenseWorker() {} virtual void Initialize(const TrainerDesc& param); int Start(); void Stop(); void SetRootScope(Scope* scope) { root_scope_ = scope; } void IncreaseThreadVersion(int thread_id, uint64_t table_id); void ResetThreadVersion(uint64_t table_id); void Wait(std::vector<::std::future>* status_vec); static std::shared_ptr GetInstance() { if (NULL == s_instance_) { s_instance_.reset(new paddle::framework::PullDenseWorker()); } return s_instance_; } private: PullDenseWorker() : root_scope_(NULL) {} void Run(); bool CheckUpdateParam(uint64_t table_id); private: static std::shared_ptr s_instance_; std::shared_ptr fleet_ptr_; PullDenseWorkerParameter param_; DownpourWorkerParameter dwp_param_; Scope* root_scope_; bool running_; static std::map last_versions_; static std::map current_version_; static std::mutex mutex_for_version_; static std::map> training_versions_; static std::map> dense_value_names_; std::thread t_; int thread_num_; int sleep_time_ms_; int threshold_; std::vector<::std::future> pull_dense_status_; uint32_t pull_dense_fail_times_ = 0; std::vector base_norm_param_; std::vector mean_; std::vector scale_; float squared_sum_epsilon_ = 1e-4; std::mutex mutex_for_mean_scale_; float total_batch_num_ = 0; }; // should incorporate different type of device class DeviceWorker { public: DeviceWorker() {} virtual ~DeviceWorker() {} virtual void Initialize(const TrainerDesc& desc) = 0; virtual void SetDeviceIndex(int tid) = 0; virtual void TrainFiles() = 0; virtual void PrintFetchVars() = 0; virtual void TrainFilesWithProfiler() = 0; virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0; // will make this zero copy in the future virtual void BindingDataFeedMemory() = 0; virtual void SetRootScope(Scope* root_scope); virtual void SetDataFeed(const std::shared_ptr& data_feed); virtual void SetPlace(const paddle::platform::Place& place) { place_ = place; } protected: Scope* root_scope_; paddle::platform::Place place_; std::shared_ptr device_reader_; int64_t batch_num_; FetchConfig fetch_config_; }; class CPUWorkerBase : public DeviceWorker { public: CPUWorkerBase() {} virtual ~CPUWorkerBase() {} virtual void SetDeviceIndex(int tid) { thread_id_ = tid; } virtual void TrainFiles() = 0; virtual void TrainFilesWithProfiler() {} virtual void PrintFetchVars() {} virtual void CreateDeviceResource(const ProgramDesc& main_prog) {} protected: int thread_id_; }; class HogwildWorker : public CPUWorkerBase { public: HogwildWorker() {} virtual ~HogwildWorker() {} virtual void Initialize(const TrainerDesc& desc); virtual void TrainFiles(); virtual void TrainFilesWithProfiler(); virtual void PrintFetchVars(); virtual void CreateDeviceResource(const ProgramDesc& main_prog); virtual void BindingDataFeedMemory(); protected: void CreateThreadOperators(const ProgramDesc& program); void CreateThreadScope(const ProgramDesc& program); std::vector op_names_; std::vector ops_; Scope* thread_scope_; HogwildWorkerParameter param_; std::vector skip_ops_; }; class DownpourWorker : public HogwildWorker { public: DownpourWorker() {} virtual ~DownpourWorker() {} virtual void Initialize(const TrainerDesc& desc); virtual void TrainFiles(); virtual void TrainFilesWithProfiler(); protected: std::shared_ptr fleet_ptr_; std::shared_ptr pull_dense_worker_; void FillSparseValue(size_t table_id); void PushGradients(); void CollectLabelInfo(size_t table_id); private: bool need_to_push_dense_; bool need_to_push_sparse_; DownpourWorkerParameter param_; // just save the value in param_ for easy access std::map label_var_name_; std::map> sparse_key_names_; std::map> sparse_value_names_; std::map> sparse_grad_names_; std::map> dense_value_names_; std::map> dense_grad_names_; // feasign std::map> features_; // feasign stats std::map> feature_labels_; // feasign embedding std::map>> feature_values_; // feasign embedding gradient std::map>> feature_grads_; // skipped ops std::vector skip_ops_; std::shared_ptr _pull_dense_worker; std::vector<::std::future> push_sparse_status_; std::vector<::std::future> push_dense_status_; }; } // namespace framework } // namespace paddle