未验证 提交 53a2b68f 编写于 作者: H hutuxian 提交者: GitHub

support customized download command in dataset (#22782)

* user can call dataset.set_download_cmd to set its customized download cmd
* add UT to cover this scenario
上级 1a2baa18
......@@ -95,6 +95,16 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
paddle::framework::hdfs_set_command(cmd);
}
// Install a user-customized download command (e.g. an afs client binary);
// forwarded to the process-wide fs config so file reads can use it instead
// of the default hdfs "-cat" invocation.
template <typename T>
void DatasetImpl<T>::SetDownloadCmd(const std::string& download_cmd) {
  paddle::framework::set_download_command(download_cmd);
}
// Return the currently configured customized download command
// (empty string when none has been set).
template <typename T>
std::string DatasetImpl<T>::GetDownloadCmd() {
  return paddle::framework::download_cmd();
}
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
......
......@@ -55,6 +55,8 @@ class Dataset {
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
// set customized download command, such as using afs api
virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
// set data feed desc, which contains:
// data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
......@@ -78,6 +80,8 @@ class Dataset {
virtual int64_t GetFleetSendBatchSize() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get download cmd
virtual std::string GetDownloadCmd() = 0;
// get data feed desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get channel num
......@@ -152,6 +156,7 @@ class DatasetImpl : public Dataset {
virtual void SetFleetSendBatchSize(int64_t size);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDownloadCmd(const std::string& download_cmd);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
......@@ -167,6 +172,7 @@ class DatasetImpl : public Dataset {
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
virtual std::string GetDownloadCmd();
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
......
......@@ -221,14 +221,30 @@ const std::string& hdfs_command() { return hdfs_command_internal(); }
void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; }
// Process-wide storage for the user-customized download command.
// An empty string means "not set" and the default read path is used.
static std::string& customized_download_cmd_internal() {
  static std::string command = "";
  return command;
}

// Read accessor for the customized download command.
const std::string& download_cmd() { return customized_download_cmd_internal(); }

// Install a customized download command; when non-empty it is preferred
// over the default hdfs "-cat" invocation when opening data files.
void set_download_command(const std::string& cmd) {
  customized_download_cmd_internal() = cmd;
}
std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter) {
if (fs_end_with_internal(path, ".gz")) {
path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(),
path.c_str());
} else {
const std::string file_path = path;
path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(),
path.c_str());
file_path.c_str());
if (download_cmd() != "") { // use customized download command
path = string::format_string("%s \"%s\"", download_cmd().c_str(),
file_path.c_str());
}
}
bool is_pipe = true;
......
......@@ -61,6 +61,10 @@ extern const std::string& hdfs_command();
extern void hdfs_set_command(const std::string& x);
extern const std::string& download_cmd();
extern void set_download_command(const std::string& x);
extern std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter);
......
......@@ -197,6 +197,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig,
py::call_guard<py::gil_scoped_release>())
.def("set_download_cmd", &framework::Dataset::SetDownloadCmd,
py::call_guard<py::gil_scoped_release>())
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc,
py::call_guard<py::gil_scoped_release>())
.def("get_filelist", &framework::Dataset::GetFileList,
......@@ -210,6 +212,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig,
py::call_guard<py::gil_scoped_release>())
.def("get_download_cmd", &framework::Dataset::GetDownloadCmd,
py::call_guard<py::gil_scoped_release>())
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc,
py::call_guard<py::gil_scoped_release>())
.def("register_client2client_msg_handler",
......
......@@ -236,6 +236,22 @@ class DatasetBase(object):
"""
self.dataset.set_hdfs_config(fs_name, fs_ugi)
def set_download_cmd(self, download_cmd):
    """
    Set a customized download command (for example an afs client) that is
    forwarded to the underlying C++ dataset via ``set_download_cmd``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset()
          dataset.set_download_cmd("./read_from_afs")

    Args:
        download_cmd(str): customized download command
    """
    self.dataset.set_download_cmd(download_cmd)
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
......
......@@ -124,6 +124,7 @@ class TestDataset(unittest.TestCase):
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi")
thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12)
......@@ -141,6 +142,62 @@ class TestDataset(unittest.TestCase):
self.assertEqual(name, "my_fs_name")
self.assertEqual(ugi, "my_fs_ugi")
download_cmd = dataset.get_download_cmd()
self.assertEqual(download_cmd, "./read_from_afs my_fs_name my_fs_ugi")
def test_set_download_cmd(self):
    """
    Testcase for set_download_cmd: files listed with an "afs:" prefix are
    fetched through the customized download command ("cat" here), then fed
    to an InMemoryDataset and trained end to end.
    """
    filename1 = "afs:test_in_memory_dataset_run_a.txt"
    filename2 = "afs:test_in_memory_dataset_run_b.txt"
    with open(filename1, "w") as f:
        data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
        data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
        data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
        f.write(data)
    with open(filename2, "w") as f:
        data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
        data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
        data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
        data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
        f.write(data)

    # Ensure the temp files are removed even if any step below raises;
    # the original version leaked them on failure.
    try:
        slots = ["slot1", "slot2", "slot3", "slot4"]
        slots_vars = []
        for slot in slots:
            var = fluid.layers.data(
                name=slot, shape=[1], dtype="int64", lod_level=1)
            slots_vars.append(var)

        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_batch_size(32)
        dataset.set_thread(3)
        dataset.set_filelist([filename1, filename2])
        dataset.set_pipe_command("cat")
        # Customized download command replaces the default file reader.
        dataset.set_download_cmd("cat")
        dataset.set_use_var(slots_vars)
        dataset.load_into_memory()

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())
        if self.use_data_loader:
            data_loader = fluid.io.DataLoader.from_dataset(dataset,
                                                           fluid.cpu_places(),
                                                           self.drop_last)
            for i in range(self.epoch_num):
                for data in data_loader():
                    exe.run(fluid.default_main_program(), feed=data)
        else:
            for i in range(self.epoch_num):
                try:
                    exe.train_from_dataset(fluid.default_main_program(),
                                           dataset)
                except Exception as e:
                    # Surface the real error instead of the original bare
                    # assertTrue(False), which hid the failure cause.
                    self.fail("train_from_dataset raised: {}".format(e))
    finally:
        os.remove(filename1)
        os.remove(filename2)
def test_in_memory_dataset_run(self):
"""
Testcase for InMemoryDataset from create to run.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册