diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index e0d23d7222345de18bdfd3acae777467249001d4..59c5b373d12ac7f7fffa787990730c9b3f4897e5 100755 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -77,6 +77,7 @@ int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) { return 0; } int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) { + ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Saving); auto fs = _context_ptr->file_system; auto* ps_client = _context_ptr->pslib->ps_client(); auto* environment = _context_ptr->environment.get(); @@ -154,6 +155,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) { if (!environment->is_master_node(EnvironmentRole::WORKER)) { return 0; } + VLOG(2) << "Start Load Model"; auto* fs = _context_ptr->file_system.get(); std::set loaded_table_set; auto model_dir = _context_ptr->epoch_accessor->checkpoint_path(); @@ -177,6 +179,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) { loaded_table_set.insert(itr.first); } } + VLOG(2) << "Finish Load Model"; return 0; } @@ -223,6 +226,7 @@ int LearnerProcess::run() { //Step2. 运行训练网络 { + ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Training); std::map> backup_input_map; for (auto& executor : _executors) { environment->barrier(EnvironmentRole::WORKER); diff --git a/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc b/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc index 10a61fcbdafb03466b07c11f8509badb0dfb3819..e4a0876fa727f86d1b00da8cb6b7fb536d0e7137 100644 --- a/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc +++ b/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc @@ -39,8 +39,8 @@ public: virtual int initialize(YAML::Node config, std::shared_ptr context_ptr) { Shuffler::initialize(config, context_ptr); - _max_concurrent_num = config["max_concurrent_num"].as(4); // 最大并发发送数 - _max_package_size = config["max_package_size"].as(1024); // 最大包个数,一次发送package个数据 + _max_concurrent_num = config["max_concurrent_num"].as(6); // 最大并发发送数 + _max_package_size = config["max_package_size"].as(256); // 最大包个数,一次发送package个数据 _shuffle_data_msg_type = config["shuffle_data_msg_type"].as(3); // c2c msg type _finish_msg_type = config["finish_msg_type"].as(4); // c2c msg type @@ -62,6 +62,8 @@ public: data_channel.swap(input_channel); set_channel(data_channel); + _item_send_count = 0; + _item_receive_count = 0; auto* environment = _trainer_context->environment.get(); auto worker_num = environment->node_num(EnvironmentRole::WORKER); std::vector>> waits(concurrent_num); @@ -86,8 +88,9 @@ public: status = 1; break; } + _item_send_count += read_size; for (int i = 0; i < worker_num; ++i) { - send_buffer_worker.clear(); + send_buffer_worker[i].clear(); } for (int i = 0; i < read_size; ++i) { auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num; @@ -119,19 +122,19 @@ public: } } } - VLOG(5) << "start send finish, worker_num: " << worker_num; + VLOG(2) << "start send finish, worker_num: " << worker_num; waits[0].clear(); for (int i = 0; i < worker_num; ++i) { waits[0].push_back(send_finish(i)); } - VLOG(5) << "wait all finish"; + VLOG(2) << "wait all finish"; for (int i = 0; i < worker_num; ++i) { if (waits[0][i].get() != 0) { LOG(WARNING) << "fail to send finish " << i; status = -1; } } - VLOG(5) << "finish shuffler, status: " << status; + VLOG(2) << "finish shuffler_send_channel, total_send:" << _item_send_count; return status < 0 ? status : 0; } @@ -174,6 +177,7 @@ private: // 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel if (wait_num == 0) { reset_channel(); + VLOG(2) << "finish shuffle_receive_channel, receive_count: " << _item_receive_count; _wait_num_mutex.unlock(); } else { std::lock_guard lock(_wait_num_mutex); @@ -182,7 +186,7 @@ private: } int32_t write_to_channel(std::vector&& items) { size_t items_size = items.size(); - VLOG(5) << "write_to_channel, items_size: " << items_size; + _item_receive_count += items_size; return _out_channel->Write(std::move(items)) == items_size ? 0 : -1; } @@ -207,6 +211,8 @@ private: } std::future send_shuffle_data(int to_client_id, std::vector& items) { + // server端也开启了client, worker节点为偶数编号 + to_client_id = 2 * to_client_id; VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size(); paddle::framework::BinaryArchive ar; ar << items; @@ -215,6 +221,8 @@ private: } std::future send_finish(int to_client_id) { + // server端也开启了client, worker节点为偶数编号 + to_client_id = 2 * to_client_id; VLOG(5) << "send_finish, to_client_id: " << to_client_id; static const std::string empty_str; return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str); @@ -230,6 +238,8 @@ private: bthread::Mutex _wait_num_mutex; std::atomic _wait_num; + std::atomic _item_send_count; + std::atomic _item_receive_count; }; REGIST_CLASS(Shuffler, GlobalShuffler); diff --git a/paddle/fluid/train/custom_trainer/feed/trainer_context.h b/paddle/fluid/train/custom_trainer/feed/trainer_context.h index f0af4e9354ca4b621be943dbca418f2ac3ce704e..6e61fbf5d62d3890542e9d0b8c072c9f4d4a7883 100755 --- a/paddle/fluid/train/custom_trainer/feed/trainer_context.h +++ b/paddle/fluid/train/custom_trainer/feed/trainer_context.h @@ -95,20 +95,23 @@ private: class TrainerContext { public: + TrainerContext() { + trainer_status.resize(2, 0); + } inline paddle::ps::PSClient* ps_client() { return pslib->ps_client(); } inline bool is_status(TrainerStatus status) { - auto bit_idx = static_cast(status); - return ((trainer_status >> bit_idx) & 1) > 0; + auto status_idx = static_cast(status); + return trainer_status[status_idx] > 0; } // 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性 inline void set_status(TrainerStatus status, bool on) { - auto bit_idx = static_cast(status); - trainer_status = trainer_status & (1L << bit_idx); + auto status_idx = static_cast(status); + trainer_status[status_idx] = on ? 1 : 0; } - uint32_t trainer_status; // trainer当前,由于可同时处于多种状态,这里分bit存储状态 + std::vector trainer_status; YAML::Node trainer_config; paddle::platform::CPUPlace cpu_place; @@ -122,6 +125,20 @@ public: std::shared_ptr cache_dict; //大模型cache词典 }; +class ContextStatusGurad { +public: + ContextStatusGurad(TrainerContext* context, TrainerStatus status) : + _context(context), _status(status) { + _context->set_status(_status, true); + } + virtual ~ContextStatusGurad() { + _context->set_status(_status, false); + } +private: + TrainerStatus _status; + TrainerContext* _context = nullptr; +}; + } // namespace feed } // namespace custom_trainer } // namespace paddle