diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 7c5f9351d220491020773ba33ef3c4fa5584a9d6..33390ed137487dfb2467be9b0895d3a0759ef2b7 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -95,6 +95,16 @@ void DatasetImpl::SetHdfsConfig(const std::string& fs_name, paddle::framework::hdfs_set_command(cmd); } +template +void DatasetImpl::SetDownloadCmd(const std::string& download_cmd) { + paddle::framework::set_download_command(download_cmd); +} + +template +std::string DatasetImpl::GetDownloadCmd() { + return paddle::framework::download_cmd(); +} + template void DatasetImpl::SetDataFeedDesc(const std::string& data_feed_desc_str) { google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index f244cd76f6a34c1b2fca80ac6a4b2b2323e0a08c..df8bbc33e7aef1a7b51d487b1721f1a031590edc 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -55,6 +55,8 @@ class Dataset { // set fs name and ugi virtual void SetHdfsConfig(const std::string& fs_name, const std::string& fs_ugi) = 0; + // set customized download command, such as using afs api + virtual void SetDownloadCmd(const std::string& download_cmd) = 0; // set data fedd desc, which contains: // data feed name, batch size, slots virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; @@ -78,6 +80,8 @@ class Dataset { virtual int64_t GetFleetSendBatchSize() = 0; // get hdfs config virtual std::pair GetHdfsConfig() = 0; + // get download cmd + virtual std::string GetDownloadCmd() = 0; // get data fedd desc virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0; // get channel num @@ -152,6 +156,7 @@ class DatasetImpl : public Dataset { virtual void SetFleetSendBatchSize(int64_t size); virtual void SetHdfsConfig(const std::string& fs_name, const std::string& fs_ugi); + virtual void 
SetDownloadCmd(const std::string& download_cmd); virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual void SetChannelNum(int channel_num); virtual void SetParseInsId(bool parse_ins_id); @@ -167,6 +172,7 @@ class DatasetImpl : public Dataset { virtual std::pair GetHdfsConfig() { return std::make_pair(fs_name_, fs_ugi_); } + virtual std::string GetDownloadCmd(); virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() { return data_feed_desc_; } diff --git a/paddle/fluid/framework/io/fs.cc b/paddle/fluid/framework/io/fs.cc index 91b67be4602e609d693510a418f1eba195452ce7..c0d477b1003d88c5ea1aa7207da5a3b591bef8aa 100644 --- a/paddle/fluid/framework/io/fs.cc +++ b/paddle/fluid/framework/io/fs.cc @@ -221,14 +221,30 @@ const std::string& hdfs_command() { return hdfs_command_internal(); } void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; } +static std::string& customized_download_cmd_internal() { + static std::string x = ""; + return x; +} + +const std::string& download_cmd() { return customized_download_cmd_internal(); } + +void set_download_command(const std::string& x) { + customized_download_cmd_internal() = x; +} + std::shared_ptr hdfs_open_read(std::string path, int* err_no, const std::string& converter) { if (fs_end_with_internal(path, ".gz")) { path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(), path.c_str()); } else { + const std::string file_path = path; path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(), - path.c_str()); + file_path.c_str()); + if (download_cmd() != "") { // use customized download command + path = string::format_string("%s \"%s\"", download_cmd().c_str(), + file_path.c_str()); + } } bool is_pipe = true; diff --git a/paddle/fluid/framework/io/fs.h b/paddle/fluid/framework/io/fs.h index 06ec11f5d19e0e15d791cb222f1a3a229b3edf31..c88636e267422c27696341df145346e33018cdaf 100644 --- a/paddle/fluid/framework/io/fs.h +++ b/paddle/fluid/framework/io/fs.h @@ 
-61,6 +61,10 @@ extern const std::string& hdfs_command(); extern void hdfs_set_command(const std::string& x); +extern const std::string& download_cmd(); + +extern void set_download_command(const std::string& x); + extern std::shared_ptr hdfs_open_read(std::string path, int* err_no, const std::string& converter); diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 6435aea8a8811bc446b658841108d1b4ea0f00c4..bd3aa4e4989380cdcddcecb5ac725b28d2eacbf8 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -197,6 +197,8 @@ void BindDataset(py::module *m) { py::call_guard()) .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig, py::call_guard()) + .def("set_download_cmd", &framework::Dataset::SetDownloadCmd, + py::call_guard()) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc, py::call_guard()) .def("get_filelist", &framework::Dataset::GetFileList, @@ -210,6 +212,8 @@ void BindDataset(py::module *m) { py::call_guard()) .def("get_hdfs_config", &framework::Dataset::GetHdfsConfig, py::call_guard()) + .def("get_download_cmd", &framework::Dataset::GetDownloadCmd, + py::call_guard()) .def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc, py::call_guard()) .def("register_client2client_msg_handler", diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index ea83ba40e79bfd1c84c3fcda6dfd90beb786e30b..f4c17fb7858ca0e37aa7265d603e0fb019995174 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -236,6 +236,22 @@ class DatasetBase(object): """ self.dataset.set_hdfs_config(fs_name, fs_ugi) + def set_download_cmd(self, download_cmd): + """ + Set customized download cmd: download_cmd + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_download_cmd("./read_from_afs") + + Args: + download_cmd(str): customized download command + """ + self.dataset.set_download_cmd(download_cmd) + def _prepare_to_run(self): """ Set data_feed_desc before load or shuffle, diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 6f13fd7220b8bd464a5797ba898e520ac329547e..9f6673124c338e962a4103314452c842a3e91bf7 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -124,6 +124,7 @@ class TestDataset(unittest.TestCase): dataset.set_filelist(["a.txt", "b.txt", "c.txt"]) dataset.set_trainer_num(4) dataset.set_hdfs_config("my_fs_name", "my_fs_ugi") + dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi") thread_num = dataset.get_thread_num() self.assertEqual(thread_num, 12) @@ -141,6 +142,62 @@ class TestDataset(unittest.TestCase): self.assertEqual(name, "my_fs_name") self.assertEqual(ugi, "my_fs_ugi") + download_cmd = dataset.get_download_cmd() + self.assertEqual(download_cmd, "./read_from_afs my_fs_name my_fs_ugi") + + def test_set_download_cmd(self): + """ + Testcase for set_download_cmd: run InMemoryDataset with a customized download command. 
+ """ + filename1 = "afs:test_in_memory_dataset_run_a.txt" + filename2 = "afs:test_in_memory_dataset_run_b.txt" + with open(filename1, "w") as f: + data = "1 1 2 3 3 4 5 5 5 5 1 1\n" + data += "1 2 2 3 4 4 6 6 6 6 1 2\n" + data += "1 3 2 3 5 4 7 7 7 7 1 3\n" + f.write(data) + with open(filename2, "w") as f: + data = "1 4 2 3 3 4 5 5 5 5 1 4\n" + data += "1 5 2 3 4 4 6 6 6 6 1 5\n" + data += "1 6 2 3 5 4 7 7 7 7 1 6\n" + data += "1 7 2 3 6 4 8 8 8 8 1 7\n" + f.write(data) + + slots = ["slot1", "slot2", "slot3", "slot4"] + slots_vars = [] + for slot in slots: + var = fluid.layers.data( + name=slot, shape=[1], dtype="int64", lod_level=1) + slots_vars.append(var) + + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_batch_size(32) + dataset.set_thread(3) + dataset.set_filelist([filename1, filename2]) + dataset.set_pipe_command("cat") + dataset.set_download_cmd("cat") + dataset.set_use_var(slots_vars) + dataset.load_into_memory() + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + if self.use_data_loader: + data_loader = fluid.io.DataLoader.from_dataset(dataset, + fluid.cpu_places(), + self.drop_last) + for i in range(self.epoch_num): + for data in data_loader(): + exe.run(fluid.default_main_program(), feed=data) + else: + for i in range(self.epoch_num): + try: + exe.train_from_dataset(fluid.default_main_program(), + dataset) + except Exception as e: + self.assertTrue(False) + + os.remove(filename1) + os.remove(filename2) + def test_in_memory_dataset_run(self): """ Testcase for InMemoryDataset from create to run.