未验证 提交 53a2b68f 编写于 作者: H hutuxian 提交者: GitHub

support customized download command in dataset (#22782)

* user can call dataset.set_download_cmd to set its customized download cmd
* add UT to cover this scenario
上级 1a2baa18
...@@ -95,6 +95,16 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name, ...@@ -95,6 +95,16 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
paddle::framework::hdfs_set_command(cmd); paddle::framework::hdfs_set_command(cmd);
} }
// Store a user-customized download command (e.g. an afs reader script) in the
// process-wide framework setting; fs.cc consults it when opening input files.
template <typename T>
void DatasetImpl<T>::SetDownloadCmd(const std::string& download_cmd) {
  paddle::framework::set_download_command(download_cmd);
}
// Return the currently configured customized download command
// (empty string when none has been set).
template <typename T>
std::string DatasetImpl<T>::GetDownloadCmd() {
  return paddle::framework::download_cmd();
}
template <typename T> template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) { void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
......
...@@ -55,6 +55,8 @@ class Dataset { ...@@ -55,6 +55,8 @@ class Dataset {
// set fs name and ugi // set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name, virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0; const std::string& fs_ugi) = 0;
// set customized download command, such as using afs api
virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
// set data feed desc, which contains: // set data feed desc, which contains:
// data feed name, batch size, slots // data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
...@@ -78,6 +80,8 @@ class Dataset { ...@@ -78,6 +80,8 @@ class Dataset {
virtual int64_t GetFleetSendBatchSize() = 0; virtual int64_t GetFleetSendBatchSize() = 0;
// get hdfs config // get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0; virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get download cmd
virtual std::string GetDownloadCmd() = 0;
// get data feed desc // get data feed desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0; virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get channel num // get channel num
...@@ -152,6 +156,7 @@ class DatasetImpl : public Dataset { ...@@ -152,6 +156,7 @@ class DatasetImpl : public Dataset {
virtual void SetFleetSendBatchSize(int64_t size); virtual void SetFleetSendBatchSize(int64_t size);
virtual void SetHdfsConfig(const std::string& fs_name, virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi); const std::string& fs_ugi);
virtual void SetDownloadCmd(const std::string& download_cmd);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num); virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id); virtual void SetParseInsId(bool parse_ins_id);
...@@ -167,6 +172,7 @@ class DatasetImpl : public Dataset { ...@@ -167,6 +172,7 @@ class DatasetImpl : public Dataset {
virtual std::pair<std::string, std::string> GetHdfsConfig() { virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_); return std::make_pair(fs_name_, fs_ugi_);
} }
virtual std::string GetDownloadCmd();
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() { virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_; return data_feed_desc_;
} }
......
...@@ -221,14 +221,30 @@ const std::string& hdfs_command() { return hdfs_command_internal(); } ...@@ -221,14 +221,30 @@ const std::string& hdfs_command() { return hdfs_command_internal(); }
void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; } void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; }
// Process-wide storage for the user-supplied download command, held in a
// function-local static (same accessor pattern as hdfs_command_internal()).
static std::string& customized_download_cmd_internal() {
  static std::string cmd;
  return cmd;
}

// Read accessor: the customized download command; empty when none was set.
const std::string& download_cmd() { return customized_download_cmd_internal(); }

// Write accessor: record a customized command for subsequent file reads.
void set_download_command(const std::string& x) {
  customized_download_cmd_internal() = x;
}
std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no, std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter) { const std::string& converter) {
if (fs_end_with_internal(path, ".gz")) { if (fs_end_with_internal(path, ".gz")) {
path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(), path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(),
path.c_str()); path.c_str());
} else { } else {
const std::string file_path = path;
path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(), path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(),
path.c_str()); file_path.c_str());
if (download_cmd() != "") { // use customized download command
path = string::format_string("%s \"%s\"", download_cmd().c_str(),
file_path.c_str());
}
} }
bool is_pipe = true; bool is_pipe = true;
......
...@@ -61,6 +61,10 @@ extern const std::string& hdfs_command(); ...@@ -61,6 +61,10 @@ extern const std::string& hdfs_command();
extern void hdfs_set_command(const std::string& x); extern void hdfs_set_command(const std::string& x);
extern const std::string& download_cmd();
extern void set_download_command(const std::string& x);
extern std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no, extern std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter); const std::string& converter);
......
...@@ -197,6 +197,8 @@ void BindDataset(py::module *m) { ...@@ -197,6 +197,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig, .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_download_cmd", &framework::Dataset::SetDownloadCmd,
py::call_guard<py::gil_scoped_release>())
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc, .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("get_filelist", &framework::Dataset::GetFileList, .def("get_filelist", &framework::Dataset::GetFileList,
...@@ -210,6 +212,8 @@ void BindDataset(py::module *m) { ...@@ -210,6 +212,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig, .def("get_hdfs_config", &framework::Dataset::GetHdfsConfig,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("get_download_cmd", &framework::Dataset::GetDownloadCmd,
py::call_guard<py::gil_scoped_release>())
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc, .def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("register_client2client_msg_handler", .def("register_client2client_msg_handler",
......
...@@ -236,6 +236,22 @@ class DatasetBase(object): ...@@ -236,6 +236,22 @@ class DatasetBase(object):
""" """
self.dataset.set_hdfs_config(fs_name, fs_ugi) self.dataset.set_hdfs_config(fs_name, fs_ugi)
def set_download_cmd(self, download_cmd):
    """
    Set a customized download command used to fetch data files
    (for example an afs reader) instead of the default file reader.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset()
            dataset.set_download_cmd("./read_from_afs")

    Args:
        download_cmd(str): customized download command
    """
    # Forward to the C++ Dataset binding ("set_download_cmd" in
    # BindDataset), which stores the command in a global setting.
    self.dataset.set_download_cmd(download_cmd)
def _prepare_to_run(self): def _prepare_to_run(self):
""" """
Set data_feed_desc before load or shuffle, Set data_feed_desc before load or shuffle,
......
...@@ -124,6 +124,7 @@ class TestDataset(unittest.TestCase): ...@@ -124,6 +124,7 @@ class TestDataset(unittest.TestCase):
dataset.set_filelist(["a.txt", "b.txt", "c.txt"]) dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
dataset.set_trainer_num(4) dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi") dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi")
thread_num = dataset.get_thread_num() thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12) self.assertEqual(thread_num, 12)
...@@ -141,6 +142,62 @@ class TestDataset(unittest.TestCase): ...@@ -141,6 +142,62 @@ class TestDataset(unittest.TestCase):
self.assertEqual(name, "my_fs_name") self.assertEqual(name, "my_fs_name")
self.assertEqual(ugi, "my_fs_ugi") self.assertEqual(ugi, "my_fs_ugi")
download_cmd = dataset.get_download_cmd()
self.assertEqual(download_cmd, "./read_from_afs my_fs_name my_fs_ugi")
def test_set_download_cmd(self):
    """
    Testcase for a customized download command: create an InMemoryDataset,
    set "cat" as the download command, load the files into memory and train.
    """
    # The "afs:" prefix keeps these fixtures distinct from other tests';
    # they are still plain local files, streamed via the "cat" command.
    filename1 = "afs:test_in_memory_dataset_run_a.txt"
    filename2 = "afs:test_in_memory_dataset_run_b.txt"
    with open(filename1, "w") as f:
        data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
        data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
        data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
        f.write(data)
    with open(filename2, "w") as f:
        data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
        data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
        data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
        data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
        f.write(data)

    try:
        slots = ["slot1", "slot2", "slot3", "slot4"]
        slots_vars = []
        for slot in slots:
            var = fluid.layers.data(
                name=slot, shape=[1], dtype="int64", lod_level=1)
            slots_vars.append(var)
        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_batch_size(32)
        dataset.set_thread(3)
        dataset.set_filelist([filename1, filename2])
        dataset.set_pipe_command("cat")
        # Exercise the customized download path; "cat" simply streams
        # each file, so the training result is unchanged.
        dataset.set_download_cmd("cat")
        dataset.set_use_var(slots_vars)
        dataset.load_into_memory()
        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())
        if self.use_data_loader:
            data_loader = fluid.io.DataLoader.from_dataset(dataset,
                                                           fluid.cpu_places(),
                                                           self.drop_last)
            for i in range(self.epoch_num):
                for data in data_loader():
                    exe.run(fluid.default_main_program(), feed=data)
        else:
            for i in range(self.epoch_num):
                try:
                    exe.train_from_dataset(fluid.default_main_program(),
                                           dataset)
                except Exception as e:
                    # Report the underlying error instead of a bare
                    # assertTrue(False), which hides the root cause.
                    self.fail("train_from_dataset failed: {}".format(e))
    finally:
        # Remove the fixture files even when the test body raises,
        # so a failure here cannot pollute later test runs.
        os.remove(filename1)
        os.remove(filename2)
def test_in_memory_dataset_run(self): def test_in_memory_dataset_run(self):
""" """
Testcase for InMemoryDataset from create to run. Testcase for InMemoryDataset from create to run.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册