未验证 提交 6bf298bf 编写于 作者: X xujiaqi01 提交者: GitHub

support preload thread, optimize hdfs log, fix master+patch bug (#19695)

* support preload thread
* sleep before fleet wrapper exit for pslib core dump
* optimize hdfs log
* fix master+patch bug
上级 a0d80754
...@@ -50,6 +50,7 @@ DatasetImpl<T>::DatasetImpl() { ...@@ -50,6 +50,7 @@ DatasetImpl<T>::DatasetImpl() {
min_merge_size_ = 2; min_merge_size_ = 2;
parse_ins_id_ = false; parse_ins_id_ = false;
parse_content_ = false; parse_content_ = false;
preload_thread_num_ = 0;
} }
// set filelist, file_idx_ will reset to zero. // set filelist, file_idx_ will reset to zero.
...@@ -120,6 +121,7 @@ void DatasetImpl<T>::SetMergeByInsId( ...@@ -120,6 +121,7 @@ void DatasetImpl<T>::SetMergeByInsId(
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas, const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
int min_merge_size, bool keep_unmerged_ins) { int min_merge_size, bool keep_unmerged_ins) {
merge_by_insid_ = true; merge_by_insid_ = true;
parse_ins_id_ = true;
merge_slots_list_ = merge_slot_list; merge_slots_list_ = merge_slot_list;
erase_duplicate_feas_ = erase_duplicate_feas; erase_duplicate_feas_ = erase_duplicate_feas;
min_merge_size_ = min_merge_size; min_merge_size_ = min_merge_size;
...@@ -202,11 +204,22 @@ void DatasetImpl<T>::LoadIntoMemory() { ...@@ -202,11 +204,22 @@ void DatasetImpl<T>::LoadIntoMemory() {
template <typename T> template <typename T>
void DatasetImpl<T>::PreLoadIntoMemory() { void DatasetImpl<T>::PreLoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() begin"; VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() begin";
if (preload_thread_num_ != 0) {
CHECK(preload_thread_num_ == preload_readers_.size());
preload_threads_.clear();
for (int64_t i = 0; i < preload_thread_num_; ++i) {
preload_threads_.push_back(
std::thread(&paddle::framework::DataFeed::LoadIntoMemory,
preload_readers_[i].get()));
}
} else {
CHECK(thread_num_ == readers_.size());
preload_threads_.clear(); preload_threads_.clear();
for (int64_t i = 0; i < thread_num_; ++i) { for (int64_t i = 0; i < thread_num_; ++i) {
preload_threads_.push_back(std::thread( preload_threads_.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get())); &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
} }
}
VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() end"; VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() end";
} }
...@@ -420,6 +433,47 @@ void DatasetImpl<T>::DestroyReaders() { ...@@ -420,6 +433,47 @@ void DatasetImpl<T>::DestroyReaders() {
cur_channel_ = 1 - cur_channel_; cur_channel_ = 1 - cur_channel_;
} }
template <typename T>
void DatasetImpl<T>::SetPreLoadThreadNum(int thread_num) {
preload_thread_num_ = thread_num;
}
template <typename T>
void DatasetImpl<T>::CreatePreLoadReaders() {
VLOG(3) << "Begin CreatePreLoadReaders";
if (preload_thread_num_ == 0) {
preload_thread_num_ = thread_num_;
}
CHECK(preload_thread_num_ > 0) << "thread num should > 0";
CHECK(input_channel_ != nullptr);
preload_readers_.clear();
for (int i = 0; i < preload_thread_num_; ++i) {
preload_readers_.push_back(
DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
preload_readers_[i]->Init(data_feed_desc_);
preload_readers_[i]->SetThreadId(i);
preload_readers_[i]->SetThreadNum(preload_thread_num_);
preload_readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetInputChannel(input_channel_.get());
preload_readers_[i]->SetOutputChannel(nullptr);
preload_readers_[i]->SetConsumeChannel(nullptr);
}
VLOG(3) << "End CreatePreLoadReaders";
}
template <typename T>
void DatasetImpl<T>::DestroyPreLoadReaders() {
VLOG(3) << "Begin DestroyPreLoadReaders";
preload_readers_.clear();
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(
preload_readers_);
file_idx_ = 0;
VLOG(3) << "End DestroyPreLoadReaders";
}
template <typename T> template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() { int64_t DatasetImpl<T>::GetMemoryDataSize() {
return input_channel_->Size(); return input_channel_->Size();
......
...@@ -114,6 +114,12 @@ class Dataset { ...@@ -114,6 +114,12 @@ class Dataset {
virtual int64_t GetShuffleDataSize() = 0; virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id // merge by ins id
virtual void MergeByInsId() = 0; virtual void MergeByInsId() = 0;
// create preload readers
virtual void CreatePreLoadReaders() = 0;
// destroy preload readers after prelaod done
virtual void DestroyPreLoadReaders() = 0;
// set preload thread num
virtual void SetPreLoadThreadNum(int thread_num) = 0;
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
...@@ -172,11 +178,15 @@ class DatasetImpl : public Dataset { ...@@ -172,11 +178,15 @@ class DatasetImpl : public Dataset {
virtual int64_t GetMemoryDataSize(); virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize(); virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {} virtual void MergeByInsId() {}
virtual void CreatePreLoadReaders();
virtual void DestroyPreLoadReaders();
virtual void SetPreLoadThreadNum(int thread_num);
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg); const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_; std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_; paddle::framework::Channel<T> input_channel_;
int channel_num_; int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_; std::vector<paddle::framework::Channel<T>> multi_output_channel_;
...@@ -206,6 +216,7 @@ class DatasetImpl : public Dataset { ...@@ -206,6 +216,7 @@ class DatasetImpl : public Dataset {
int min_merge_size_; int min_merge_size_;
std::vector<std::string> merge_slots_list_; std::vector<std::string> merge_slots_list_;
bool slots_shuffle_fea_eval_ = false; bool slots_shuffle_fea_eval_ = false;
int preload_thread_num_;
}; };
// use std::vector<MultiSlotType> or Record as data type // use std::vector<MultiSlotType> or Record as data type
......
...@@ -188,6 +188,7 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -188,6 +188,7 @@ void FleetWrapper::PullSparseVarsSync(
auto status = t.get(); auto status = t.get();
if (status != 0) { if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
} }
...@@ -479,6 +480,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { ...@@ -479,6 +480,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
ret.wait(); ret.wait();
if (ret.get() != 0) { if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed"; LOG(ERROR) << "load model from path:" << path << " failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
#else #else
...@@ -508,6 +510,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { ...@@ -508,6 +510,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
int32_t feasign_cnt = ret.get(); int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed"; LOG(ERROR) << "save model failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
#else #else
...@@ -524,6 +527,7 @@ double FleetWrapper::GetCacheThreshold() { ...@@ -524,6 +527,7 @@ double FleetWrapper::GetCacheThreshold() {
ret.wait(); ret.wait();
if (cache_threshold < 0) { if (cache_threshold < 0) {
LOG(ERROR) << "get cache threshold failed"; LOG(ERROR) << "get cache threshold failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
return cache_threshold; return cache_threshold;
...@@ -542,6 +546,7 @@ void FleetWrapper::CacheShuffle(int table_id, const std::string& path, ...@@ -542,6 +546,7 @@ void FleetWrapper::CacheShuffle(int table_id, const std::string& path,
int32_t feasign_cnt = ret.get(); int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { if (feasign_cnt == -1) {
LOG(ERROR) << "cache shuffle failed"; LOG(ERROR) << "cache shuffle failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
#else #else
...@@ -557,6 +562,7 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path, ...@@ -557,6 +562,7 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
int32_t feasign_cnt = ret.get(); int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed"; LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
return feasign_cnt; return feasign_cnt;
...@@ -626,6 +632,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, ...@@ -626,6 +632,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
auto status = push_status.get(); auto status = push_status.get();
if (status != 0) { if (status != 0) {
LOG(FATAL) << "push shrink dense param failed, status[" << status << "]"; LOG(FATAL) << "push shrink dense param failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
exit(-1); exit(-1);
} }
#else #else
......
...@@ -55,7 +55,11 @@ namespace framework { ...@@ -55,7 +55,11 @@ namespace framework {
class FleetWrapper { class FleetWrapper {
public: public:
virtual ~FleetWrapper() {} virtual ~FleetWrapper() {}
FleetWrapper() { scale_sparse_gradient_with_batch_size_ = true; } FleetWrapper() {
scale_sparse_gradient_with_batch_size_ = true;
// trainer sleep some time for pslib core dump
sleep_seconds_before_fail_exit_ = 300;
}
// Pull sparse variables from server in Sync mode // Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys // Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values // Param<out>: fea_values
...@@ -195,6 +199,7 @@ class FleetWrapper { ...@@ -195,6 +199,7 @@ class FleetWrapper {
protected: protected:
static bool is_initialized_; static bool is_initialized_;
bool scale_sparse_gradient_with_batch_size_; bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper); DISABLE_COPY_AND_ASSIGN(FleetWrapper);
}; };
......
...@@ -80,9 +80,6 @@ void PullDenseWorker::Stop() { ...@@ -80,9 +80,6 @@ void PullDenseWorker::Stop() {
if (running_) { if (running_) {
running_ = false; running_ = false;
t_.join(); t_.join();
// pull dense when stop, to make sure local dense params are same as
// pserver, so save paddle model will save dense model same as pserver
PullDense(true);
} }
} }
......
...@@ -111,6 +111,13 @@ void BindDataset(py::module* m) { ...@@ -111,6 +111,13 @@ void BindDataset(py::module* m) {
.def("slots_shuffle", &framework::Dataset::SlotsShuffle, .def("slots_shuffle", &framework::Dataset::SlotsShuffle,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_fea_eval", &framework::Dataset::SetFeaEval, .def("set_fea_eval", &framework::Dataset::SetFeaEval,
py::call_guard<py::gil_scoped_release>())
.def("set_preload_thread_num", &framework::Dataset::SetPreLoadThreadNum,
py::call_guard<py::gil_scoped_release>())
.def("create_preload_readers", &framework::Dataset::CreatePreLoadReaders,
py::call_guard<py::gil_scoped_release>())
.def("destroy_preload_readers",
&framework::Dataset::DestroyPreLoadReaders,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
...@@ -423,10 +423,13 @@ class InMemoryDataset(DatasetBase): ...@@ -423,10 +423,13 @@ class InMemoryDataset(DatasetBase):
self._prepare_to_run() self._prepare_to_run()
self.dataset.load_into_memory() self.dataset.load_into_memory()
def preload_into_memory(self): def preload_into_memory(self, thread_num=None):
""" """
Load data into memory in async mode Load data into memory in async mode
Args:
thread_num(int): preload thread num
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -438,6 +441,10 @@ class InMemoryDataset(DatasetBase): ...@@ -438,6 +441,10 @@ class InMemoryDataset(DatasetBase):
dataset.wait_preload_done() dataset.wait_preload_done()
""" """
self._prepare_to_run() self._prepare_to_run()
if thread_num is None:
thread_num = self.thread_num
self.dataset.set_preload_thread_num(thread_num)
self.dataset.create_preload_readers()
self.dataset.preload_into_memory() self.dataset.preload_into_memory()
def wait_preload_done(self): def wait_preload_done(self):
...@@ -455,6 +462,7 @@ class InMemoryDataset(DatasetBase): ...@@ -455,6 +462,7 @@ class InMemoryDataset(DatasetBase):
dataset.wait_preload_done() dataset.wait_preload_done()
""" """
self.dataset.wait_preload_done() self.dataset.wait_preload_done()
self.dataset.destroy_preload_readers()
def local_shuffle(self): def local_shuffle(self):
""" """
......
...@@ -926,6 +926,9 @@ class FleetUtil(object): ...@@ -926,6 +926,9 @@ class FleetUtil(object):
if not client.is_exist(dest): if not client.is_exist(dest):
client.makedirs(dest) client.makedirs(dest)
if os.path.isdir(model_name):
client.upload_dir(dest, model_name)
else:
client.upload(dest, model_name) client.upload(dest, model_name)
fleet._role_maker._barrier_worker() fleet._role_maker._barrier_worker()
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""HDFS Utils""" """HDFS Utils."""
import os import os
import sys import sys
...@@ -84,7 +84,7 @@ class HDFSClient(object): ...@@ -84,7 +84,7 @@ class HDFSClient(object):
ret_code, ret_out, ret_err = proc.returncode, output, errors ret_code, ret_out, ret_err = proc.returncode, output, errors
_logger.info( _logger.info(
'Times: %d, Running command: %s. Return code: %d, Error: %s' % 'Times: %d, Running command: %s. Return code: %d, Msg: %s' %
(x, whole_commands, proc.returncode, errors)) (x, whole_commands, proc.returncode, errors))
if ret_code == 0: if ret_code == 0:
...@@ -93,6 +93,13 @@ class HDFSClient(object): ...@@ -93,6 +93,13 @@ class HDFSClient(object):
return ret_code, ret_out, ret_err return ret_code, ret_out, ret_err
def cat(self, hdfs_path=None): def cat(self, hdfs_path=None):
"""
cat hdfs file
Args:
hdfs_path(str): the hdfs file path
Returns:
file content
"""
if self.is_file(hdfs_path): if self.is_file(hdfs_path):
exist_cmd = ['-cat', hdfs_path] exist_cmd = ['-cat', hdfs_path]
returncode, output, errors = self.__run_hdfs_cmd( returncode, output, errors = self.__run_hdfs_cmd(
...@@ -101,8 +108,7 @@ class HDFSClient(object): ...@@ -101,8 +108,7 @@ class HDFSClient(object):
_logger.error("HDFS cat HDFS path: {} failed".format(hdfs_path)) _logger.error("HDFS cat HDFS path: {} failed".format(hdfs_path))
return "" return ""
else: else:
_logger.error("HDFS cat HDFS path: {} succeed".format( _logger.info("HDFS cat HDFS path: {} succeed".format(hdfs_path))
hdfs_path))
return output.strip() return output.strip()
else: else:
...@@ -190,7 +196,7 @@ class HDFSClient(object): ...@@ -190,7 +196,7 @@ class HDFSClient(object):
whether the remote HDFS path exists whether the remote HDFS path exists
Args: Args:
hdfs_path: HDFS path. hdfs_path(str): HDFS path.
Returns: Returns:
True or False True or False
...@@ -224,9 +230,10 @@ class HDFSClient(object): ...@@ -224,9 +230,10 @@ class HDFSClient(object):
Move a file or folder on HDFS. Move a file or folder on HDFS.
Args: Args:
hdfs_path(str): HDFS path. hdfs_src_path(str): HDFS path
overwrite(bool|False): If the path already exists and overwrite is False, will return False. hdfs_dst_path(str): HDFS path
overwrite(bool|False): If the path already exists and overwrite is
False, will return False.
Returns: Returns:
True or False True or False
""" """
...@@ -256,8 +263,9 @@ class HDFSClient(object): ...@@ -256,8 +263,9 @@ class HDFSClient(object):
def make_local_dirs(local_path): def make_local_dirs(local_path):
""" """
create a directiory local, is same to mkdir create a directiory local, is same to mkdir
Args: Args:
local_path: local path that wants to create a directiory. local_path(str): local path that wants to create a directiory.
""" """
try: try:
os.makedirs(local_path) os.makedirs(local_path)
...@@ -270,7 +278,8 @@ class HDFSClient(object): ...@@ -270,7 +278,8 @@ class HDFSClient(object):
Create a remote directory, recursively if necessary. Create a remote directory, recursively if necessary.
Args: Args:
hdfs_path(str): Remote path. Intermediate directories will be created appropriately. hdfs_path(str): Remote path. Intermediate directories will be
created appropriately.
Returns: Returns:
True or False True or False
...@@ -290,7 +299,7 @@ class HDFSClient(object): ...@@ -290,7 +299,7 @@ class HDFSClient(object):
_logger.error("HDFS mkdir path: {} failed".format(hdfs_path)) _logger.error("HDFS mkdir path: {} failed".format(hdfs_path))
return False return False
else: else:
_logger.error("HDFS mkdir path: {} successfully".format(hdfs_path)) _logger.info("HDFS mkdir path: {} successfully".format(hdfs_path))
return True return True
def ls(self, hdfs_path): def ls(self, hdfs_path):
...@@ -333,8 +342,7 @@ class HDFSClient(object): ...@@ -333,8 +342,7 @@ class HDFSClient(object):
Args: Args:
hdfs_path(str): Remote HDFS path. hdfs_path(str): Remote HDFS path.
only_file(bool|True): will discard folders. excludes(list): excludes
sort(bool|True): will be sorted by create time.
Returns: Returns:
List: a contents list about hdfs_path. List: a contents list about hdfs_path.
...@@ -373,7 +381,18 @@ class HDFSClient(object): ...@@ -373,7 +381,18 @@ class HDFSClient(object):
return ret_lines return ret_lines
@staticmethod @staticmethod
def split_flies(files, trainer_id, trainers): def split_files(files, trainer_id, trainers):
"""
split file list
Args:
files(list): file list
trainer_id(int): trainer mpi rank id
trainers(int): all trainers num
Returns:
fileist(list): file list of current trainer
"""
remainder = len(files) % trainers remainder = len(files) % trainers
blocksize = len(files) / trainers blocksize = len(files) / trainers
...@@ -402,6 +421,8 @@ class HDFSClient(object): ...@@ -402,6 +421,8 @@ class HDFSClient(object):
hdfs_path(str): path on hdfs hdfs_path(str): path on hdfs
local_path(str): path on local local_path(str): path on local
multi_processes(int|5): the download data process at the same time, default=5 multi_processes(int|5): the download data process at the same time, default=5
overwrite(bool): is overwrite
retry_times(int): retry times
Returns: Returns:
List: List:
...@@ -478,7 +499,7 @@ class HDFSClient(object): ...@@ -478,7 +499,7 @@ class HDFSClient(object):
local_path(str): path on local local_path(str): path on local
multi_processes(int|5): the upload data process at the same time, default=5 multi_processes(int|5): the upload data process at the same time, default=5
overwrite(bool|False): will overwrite file on HDFS or not overwrite(bool|False): will overwrite file on HDFS or not
sync(bool|True): upload files sync or not. retry_times(int): upload file max retry time.
Returns: Returns:
None None
...@@ -497,6 +518,15 @@ class HDFSClient(object): ...@@ -497,6 +518,15 @@ class HDFSClient(object):
return True return True
def get_local_files(path): def get_local_files(path):
"""
get local files
Args:
path(str): local path
Returns:
list of local files
"""
rlist = [] rlist = []
if not os.path.exists(path): if not os.path.exists(path):
...@@ -537,6 +567,32 @@ class HDFSClient(object): ...@@ -537,6 +567,32 @@ class HDFSClient(object):
_logger.info("Finish upload datas from {} to {}".format(local_path, _logger.info("Finish upload datas from {} to {}".format(local_path,
hdfs_path)) hdfs_path))
def upload_dir(self, dest_dir, local_dir, overwrite=False):
"""
upload dir to hdfs
Args:
dest_dir(str): hdfs dest dir
local_dir(str): hdfs local dir
overwrite(bool): is overwrite
Returns:
return code
"""
local_dir = local_dir.rstrip("/")
dest_dir = dest_dir.rstrip("/")
local_basename = os.path.basename(local_dir)
if self.is_exist(dest_dir + "/" + local_basename) and overwrite:
self.delete(dest_dir + "/" + local_basename)
if not self.is_exist(dest_dir):
self.makedirs(dest_dir)
put_command = ["-put", local_dir, dest_dir]
returncode, output, errors = self.__run_hdfs_cmd(put_command,
retry_times)
if returncode != 0:
_logger.error("Put local dir: {} to HDFS dir: {} failed".format(
local_dir, dest_dir))
return False
return True
if __name__ == "__main__": if __name__ == "__main__":
hadoop_home = "/home/client/hadoop-client/hadoop/" hadoop_home = "/home/client/hadoop-client/hadoop/"
......
...@@ -233,6 +233,14 @@ class TestDataset(unittest.TestCase): ...@@ -233,6 +233,14 @@ class TestDataset(unittest.TestCase):
except Exception as e: except Exception as e:
self.assertTrue(False) self.assertTrue(False)
dataset.set_merge_by_lineid(slots_vars)
dataset.preload_into_memory()
dataset.wait_preload_done()
dataset.release_memory()
dataset.preload_into_memory(1)
dataset.wait_preload_done()
fleet_ptr = fluid.core.Fleet()
os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt") os.remove("./test_in_memory_dataset_run_b.txt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册