Unverified commit 6bf298bf, authored by xujiaqi01, committed by GitHub

support preload thread, optimize hdfs log, fix master+patch bug (#19695)

* support preload thread
* sleep before fleet wrapper exit for pslib core dump
* optimize hdfs log
* fix master+patch bug
Parent commit: a0d80754
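For orientation, the new preload API added below can be driven from Python roughly as follows. This is a sketch reconstructed from the `dataset.py` and unit-test hunks in this commit; the filelist, slot setup, and thread counts are illustrative placeholders, not part of the commit.

```python
# Sketch of the asynchronous preload workflow this commit introduces.
# File names and thread counts are placeholders; a real setup would also
# configure set_use_var / set_pipe_command before loading.
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_filelist(["train_part_a.txt", "train_part_b.txt"])
dataset.set_thread(4)

# Start loading in the background. With thread_num=None the regular reader
# thread count is reused; an explicit int creates that many dedicated
# preload readers (set_preload_thread_num + create_preload_readers).
dataset.preload_into_memory(thread_num=2)

# ... build the rest of the training program while data loads ...

# Block until loading finishes; this now also destroys the preload readers.
dataset.wait_preload_done()
```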
......@@ -50,6 +50,7 @@ DatasetImpl<T>::DatasetImpl() {
min_merge_size_ = 2;
parse_ins_id_ = false;
parse_content_ = false;
preload_thread_num_ = 0;
}
// set filelist, file_idx_ will reset to zero.
......@@ -120,6 +121,7 @@ void DatasetImpl<T>::SetMergeByInsId(
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
int min_merge_size, bool keep_unmerged_ins) {
merge_by_insid_ = true;
parse_ins_id_ = true;
merge_slots_list_ = merge_slot_list;
erase_duplicate_feas_ = erase_duplicate_feas;
min_merge_size_ = min_merge_size;
......@@ -202,10 +204,21 @@ void DatasetImpl<T>::LoadIntoMemory() {
template <typename T>
void DatasetImpl<T>::PreLoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() begin";
-  preload_threads_.clear();
-  for (int64_t i = 0; i < thread_num_; ++i) {
-    preload_threads_.push_back(std::thread(
-        &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
-  }
+  if (preload_thread_num_ != 0) {
+    CHECK(preload_thread_num_ == preload_readers_.size());
+    preload_threads_.clear();
+    for (int64_t i = 0; i < preload_thread_num_; ++i) {
+      preload_threads_.push_back(
+          std::thread(&paddle::framework::DataFeed::LoadIntoMemory,
+                      preload_readers_[i].get()));
+    }
+  } else {
+    CHECK(thread_num_ == readers_.size());
+    preload_threads_.clear();
+    for (int64_t i = 0; i < thread_num_; ++i) {
+      preload_threads_.push_back(std::thread(
+          &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
+    }
+  }
VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() end";
}
......@@ -420,6 +433,47 @@ void DatasetImpl<T>::DestroyReaders() {
cur_channel_ = 1 - cur_channel_;
}
template <typename T>
void DatasetImpl<T>::SetPreLoadThreadNum(int thread_num) {
preload_thread_num_ = thread_num;
}
template <typename T>
void DatasetImpl<T>::CreatePreLoadReaders() {
VLOG(3) << "Begin CreatePreLoadReaders";
if (preload_thread_num_ == 0) {
preload_thread_num_ = thread_num_;
}
CHECK(preload_thread_num_ > 0) << "thread num should be > 0";
CHECK(input_channel_ != nullptr);
preload_readers_.clear();
for (int i = 0; i < preload_thread_num_; ++i) {
preload_readers_.push_back(
DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
preload_readers_[i]->Init(data_feed_desc_);
preload_readers_[i]->SetThreadId(i);
preload_readers_[i]->SetThreadNum(preload_thread_num_);
preload_readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetInputChannel(input_channel_.get());
preload_readers_[i]->SetOutputChannel(nullptr);
preload_readers_[i]->SetConsumeChannel(nullptr);
}
VLOG(3) << "End CreatePreLoadReaders";
}
template <typename T>
void DatasetImpl<T>::DestroyPreLoadReaders() {
VLOG(3) << "Begin DestroyPreLoadReaders";
preload_readers_.clear();
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(
preload_readers_);
file_idx_ = 0;
VLOG(3) << "End DestroyPreLoadReaders";
}
template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() {
return input_channel_->Size();
......
......@@ -114,6 +114,12 @@ class Dataset {
virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
// create preload readers
virtual void CreatePreLoadReaders() = 0;
// destroy preload readers after preload is done
virtual void DestroyPreLoadReaders() = 0;
// set preload thread num
virtual void SetPreLoadThreadNum(int thread_num) = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
......@@ -172,11 +178,15 @@ class DatasetImpl : public Dataset {
virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
virtual void CreatePreLoadReaders();
virtual void DestroyPreLoadReaders();
virtual void SetPreLoadThreadNum(int thread_num);
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_;
int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
......@@ -206,6 +216,7 @@ class DatasetImpl : public Dataset {
int min_merge_size_;
std::vector<std::string> merge_slots_list_;
bool slots_shuffle_fea_eval_ = false;
int preload_thread_num_;
};
// use std::vector<MultiSlotType> or Record as data type
......
......@@ -188,6 +188,7 @@ void FleetWrapper::PullSparseVarsSync(
auto status = t.get();
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
......@@ -479,6 +480,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
#else
......@@ -508,6 +510,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
#else
......@@ -524,6 +527,7 @@ double FleetWrapper::GetCacheThreshold() {
ret.wait();
if (cache_threshold < 0) {
LOG(ERROR) << "get cache threshold failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return cache_threshold;
......@@ -542,6 +546,7 @@ void FleetWrapper::CacheShuffle(int table_id, const std::string& path,
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "cache shuffle failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
#else
......@@ -557,6 +562,7 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
......@@ -626,6 +632,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push shrink dense param failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
#else
......
......@@ -55,7 +55,11 @@ namespace framework {
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
-  FleetWrapper() { scale_sparse_gradient_with_batch_size_ = true; }
+  FleetWrapper() {
+    scale_sparse_gradient_with_batch_size_ = true;
+    // trainer sleeps some time for pslib core dump
+    sleep_seconds_before_fail_exit_ = 300;
+  }
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values
......@@ -195,6 +199,7 @@ class FleetWrapper {
protected:
static bool is_initialized_;
bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
......
......@@ -80,9 +80,6 @@ void PullDenseWorker::Stop() {
if (running_) {
running_ = false;
t_.join();
-    // pull dense when stop, to make sure local dense params are same as
-    // pserver, so save paddle model will save dense model same as pserver
-    PullDense(true);
}
}
......
......@@ -111,6 +111,13 @@ void BindDataset(py::module* m) {
.def("slots_shuffle", &framework::Dataset::SlotsShuffle,
py::call_guard<py::gil_scoped_release>())
.def("set_fea_eval", &framework::Dataset::SetFeaEval,
py::call_guard<py::gil_scoped_release>())
.def("set_preload_thread_num", &framework::Dataset::SetPreLoadThreadNum,
py::call_guard<py::gil_scoped_release>())
.def("create_preload_readers", &framework::Dataset::CreatePreLoadReaders,
py::call_guard<py::gil_scoped_release>())
.def("destroy_preload_readers",
&framework::Dataset::DestroyPreLoadReaders,
py::call_guard<py::gil_scoped_release>());
}
......
......@@ -423,10 +423,13 @@ class InMemoryDataset(DatasetBase):
self._prepare_to_run()
self.dataset.load_into_memory()
-    def preload_into_memory(self):
+    def preload_into_memory(self, thread_num=None):
"""
Load data into memory in async mode
Args:
thread_num(int): preload thread num
Examples:
.. code-block:: python
......@@ -438,6 +441,10 @@ class InMemoryDataset(DatasetBase):
dataset.wait_preload_done()
"""
self._prepare_to_run()
if thread_num is None:
thread_num = self.thread_num
self.dataset.set_preload_thread_num(thread_num)
self.dataset.create_preload_readers()
self.dataset.preload_into_memory()
def wait_preload_done(self):
......@@ -455,6 +462,7 @@ class InMemoryDataset(DatasetBase):
dataset.wait_preload_done()
"""
self.dataset.wait_preload_done()
self.dataset.destroy_preload_readers()
def local_shuffle(self):
"""
......
......@@ -926,7 +926,10 @@ class FleetUtil(object):
if not client.is_exist(dest):
client.makedirs(dest)
-            client.upload(dest, model_name)
+            if os.path.isdir(model_name):
+                client.upload_dir(dest, model_name)
+            else:
+                client.upload(dest, model_name)
fleet._role_maker._barrier_worker()
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HDFS Utils"""
"""HDFS Utils."""
import os
import sys
......@@ -84,7 +84,7 @@ class HDFSClient(object):
ret_code, ret_out, ret_err = proc.returncode, output, errors
_logger.info(
-            'Times: %d, Running command: %s. Return code: %d, Error: %s' %
+            'Times: %d, Running command: %s. Return code: %d, Msg: %s' %
(x, whole_commands, proc.returncode, errors))
if ret_code == 0:
......@@ -93,6 +93,13 @@ class HDFSClient(object):
return ret_code, ret_out, ret_err
def cat(self, hdfs_path=None):
"""
cat hdfs file
Args:
hdfs_path(str): the hdfs file path
Returns:
file content
"""
if self.is_file(hdfs_path):
exist_cmd = ['-cat', hdfs_path]
returncode, output, errors = self.__run_hdfs_cmd(
......@@ -101,8 +108,7 @@ class HDFSClient(object):
_logger.error("HDFS cat HDFS path: {} failed".format(hdfs_path))
return ""
else:
_logger.error("HDFS cat HDFS path: {} succeed".format(
hdfs_path))
_logger.info("HDFS cat HDFS path: {} succeed".format(hdfs_path))
return output.strip()
else:
......@@ -190,7 +196,7 @@ class HDFSClient(object):
whether the remote HDFS path exists
Args:
-            hdfs_path: HDFS path.
+            hdfs_path(str): HDFS path.
Returns:
True or False
......@@ -224,9 +230,10 @@ class HDFSClient(object):
Move a file or folder on HDFS.
Args:
-            hdfs_path(str): HDFS path.
-            overwrite(bool|False): If the path already exists and overwrite is False, will return False.
+            hdfs_src_path(str): HDFS path
+            hdfs_dst_path(str): HDFS path
+            overwrite(bool|False): If the path already exists and overwrite is
+                False, will return False.
Returns:
True or False
"""
......@@ -256,8 +263,9 @@ class HDFSClient(object):
def make_local_dirs(local_path):
"""
create a local directory, same as mkdir
Args:
-            local_path: local path where the directory will be created.
+            local_path(str): local path where the directory will be created.
"""
try:
os.makedirs(local_path)
......@@ -270,7 +278,8 @@ class HDFSClient(object):
Create a remote directory, recursively if necessary.
Args:
-            hdfs_path(str): Remote path. Intermediate directories will be created appropriately.
+            hdfs_path(str): Remote path. Intermediate directories will be
+                created appropriately.
Returns:
True or False
......@@ -290,7 +299,7 @@ class HDFSClient(object):
_logger.error("HDFS mkdir path: {} failed".format(hdfs_path))
return False
else:
_logger.error("HDFS mkdir path: {} successfully".format(hdfs_path))
_logger.info("HDFS mkdir path: {} successfully".format(hdfs_path))
return True
def ls(self, hdfs_path):
......@@ -298,7 +307,7 @@ class HDFSClient(object):
ls directory contents about HDFS hdfs_path
Args:
            hdfs_path(str): Remote HDFS path will be ls.
Returns:
List: a contents list about hdfs_path.
......@@ -332,9 +341,8 @@ class HDFSClient(object):
list directory contents about HDFS hdfs_path recursively
Args:
-            hdfs_path(str): Remote HDFS path.
-            only_file(bool|True): will discard folders.
-            sort(bool|True): will be sorted by create time.
+            hdfs_path(str): Remote HDFS path.
+            excludes(list): excludes
Returns:
List: a contents list about hdfs_path.
......@@ -373,7 +381,18 @@ class HDFSClient(object):
return ret_lines
@staticmethod
-    def split_flies(files, trainer_id, trainers):
+    def split_files(files, trainer_id, trainers):
"""
split file list
Args:
files(list): file list
trainer_id(int): trainer mpi rank id
trainers(int): all trainers num
Returns:
filelist(list): file list of current trainer
"""
remainder = len(files) % trainers
blocksize = len(files) / trainers
......@@ -402,6 +421,8 @@ class HDFSClient(object):
hdfs_path(str): path on hdfs
local_path(str): path on local
multi_processes(int|5): the download data process at the same time, default=5
overwrite(bool): whether to overwrite existing files
retry_times(int): retry times
Returns:
List:
......@@ -478,7 +499,7 @@ class HDFSClient(object):
local_path(str): path on local
multi_processes(int|5): the upload data process at the same time, default=5
overwrite(bool|False): will overwrite file on HDFS or not
-            sync(bool|True): upload files sync or not.
+            retry_times(int): upload file max retry time.
Returns:
None
......@@ -497,6 +518,15 @@ class HDFSClient(object):
return True
def get_local_files(path):
"""
get local files
Args:
path(str): local path
Returns:
list of local files
"""
rlist = []
if not os.path.exists(path):
......@@ -537,6 +567,32 @@ class HDFSClient(object):
_logger.info("Finish upload datas from {} to {}".format(local_path,
hdfs_path))
    def upload_dir(self, dest_dir, local_dir, overwrite=False, retry_times=5):
        """
        upload dir to hdfs
        Args:
            dest_dir(str): hdfs dest dir
            local_dir(str): local dir to upload
            overwrite(bool): whether to overwrite an existing dir
            retry_times(int): max retry times for the put command
        Returns:
            True or False
        """
local_dir = local_dir.rstrip("/")
dest_dir = dest_dir.rstrip("/")
local_basename = os.path.basename(local_dir)
if self.is_exist(dest_dir + "/" + local_basename) and overwrite:
self.delete(dest_dir + "/" + local_basename)
if not self.is_exist(dest_dir):
self.makedirs(dest_dir)
put_command = ["-put", local_dir, dest_dir]
returncode, output, errors = self.__run_hdfs_cmd(put_command,
retry_times)
if returncode != 0:
_logger.error("Put local dir: {} to HDFS dir: {} failed".format(
local_dir, dest_dir))
return False
return True
if __name__ == "__main__":
hadoop_home = "/home/client/hadoop-client/hadoop/"
......
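For reference, the new `upload_dir` entry point is driven the same way the `__main__` block above drives the client. Below is a minimal sketch; the hadoop home, fs/ugi configs, and all paths are placeholder assumptions, not taken from the commit.

```python
# Hypothetical driver for the new HDFSClient.upload_dir; hadoop_home, the
# fs/ugi configs, and all paths are placeholders, not from the commit.
configs = {
    "fs.default.name": "hdfs://example.hadoop.com:54310",
    "hadoop.job.ugi": "user,passwd",
}
client = HDFSClient("/home/client/hadoop-client/hadoop/", configs)

# Uploads ./local_model as /user/paddle/models/local_model; with
# overwrite=True an existing copy is deleted first.
ok = client.upload_dir("/user/paddle/models", "./local_model", overwrite=True)
print("upload_dir succeeded:", ok)
```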
......@@ -233,6 +233,14 @@ class TestDataset(unittest.TestCase):
except Exception as e:
self.assertTrue(False)
dataset.set_merge_by_lineid(slots_vars)
dataset.preload_into_memory()
dataset.wait_preload_done()
dataset.release_memory()
dataset.preload_into_memory(1)
dataset.wait_preload_done()
fleet_ptr = fluid.core.Fleet()
os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt")
......