提交 45eb6f07 编写于 作者: D dongdaxiang

run pre-commit check files and fix code style problem

test=develop
上级 e57ac5ed
...@@ -246,8 +246,8 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() { ...@@ -246,8 +246,8 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_; VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
auto interval = GetMemoryDataInterval(); auto interval = GetMemoryDataInterval();
VLOG(3) << "memory data size=" << memory_data_->size() VLOG(3) << "memory data size=" << memory_data_->size()
<< ", fill data from [" << interval.first << ", " << ", fill data from [" << interval.first << ", " << interval.second
<< interval.second << "), thread_id=" << thread_id_; << "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) { for (int64_t i = interval.first; i < interval.second; ++i) {
T& t = (*memory_data_)[i]; T& t = (*memory_data_)[i];
shuffled_ins_->Push(std::move(t)); shuffled_ins_->Push(std::move(t));
...@@ -275,13 +275,13 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() { ...@@ -275,13 +275,13 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
channel->Pop(&local_vec[i]); channel->Pop(&local_vec[i]);
} }
VLOG(3) << "local_vec size=" << local_vec.size() VLOG(3) << "local_vec size=" << local_vec.size()
<<", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
{ {
std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_); std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_);
VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size() VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size()
<< ", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
memory_data_->insert(memory_data_->end(), local_vec.begin(), memory_data_->insert(memory_data_->end(), local_vec.begin(),
local_vec.end()); local_vec.end());
VLOG(3) << "after insert memory_data_ size=" << memory_data_->size() VLOG(3) << "after insert memory_data_ size=" << memory_data_->size()
<< ", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
} }
...@@ -308,8 +308,8 @@ void InMemoryDataFeed<T>::LoadIntoMemory() { ...@@ -308,8 +308,8 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
local_vec.push_back(instance); local_vec.push_back(instance);
} }
timeline.Pause(); timeline.Pause();
VLOG(3) << "LoadIntoMemory() read all lines, file=" VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
<< filename << ", cost time=" << timeline.ElapsedSec() << ", cost time=" << timeline.ElapsedSec()
<< " seconds, thread_id=" << thread_id_; << " seconds, thread_id=" << thread_id_;
{ {
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_); std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
...@@ -319,8 +319,7 @@ void InMemoryDataFeed<T>::LoadIntoMemory() { ...@@ -319,8 +319,7 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
std::make_move_iterator(local_vec.end())); std::make_move_iterator(local_vec.end()));
timeline.Pause(); timeline.Pause();
VLOG(3) << "LoadIntoMemory() memory_data insert, cost time=" VLOG(3) << "LoadIntoMemory() memory_data insert, cost time="
<< timeline.ElapsedSec() << " seconds, thread_id=" << timeline.ElapsedSec() << " seconds, thread_id=" << thread_id_;
<< thread_id_;
} }
local_vec.clear(); local_vec.clear();
} }
...@@ -358,8 +357,8 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -358,8 +357,8 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
std::string send_str; std::string send_str;
SerializeIns(send_vec[j], &send_str); SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() VLOG(3) << "send str_length=" << send_str.length()
<< ", ins num=" << send_vec[j].size() << " to node_id=" << ", ins num=" << send_vec[j].size() << " to node_id=" << j
<< j << ", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str); auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_; VLOG(3) << "end send, thread_id=" << thread_id_;
send_vec[j].clear(); send_vec[j].clear();
...@@ -371,8 +370,8 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -371,8 +370,8 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
if (send_vec[j].size() != 0) { if (send_vec[j].size() != 0) {
std::string send_str; std::string send_str;
SerializeIns(send_vec[j], &send_str); SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
<< " to node_id=" << j << ", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str); auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_; VLOG(3) << "end send, thread_id=" << thread_id_;
total_status.push_back(std::move(ret)); total_status.push_back(std::move(ret));
...@@ -888,15 +887,13 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -888,15 +887,13 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle // todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns( void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<std::vector<MultiSlotType>*>& ins, const std::vector<std::vector<MultiSlotType>*>& ins, std::string* str) {
std::string* str) {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str); fleet_ptr->Serialize(ins, str);
} }
// todo deserialize ins in global shuffle // todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns( void MultiSlotInMemoryDataFeed::DeserializeIns(
std::vector<std::vector<MultiSlotType>>* ins, std::vector<std::vector<MultiSlotType>>* ins, const std::string& str) {
const std::string& str) {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str); fleet_ptr->Deserialize(ins, str);
} }
......
...@@ -15,23 +15,23 @@ limitations under the License. */ ...@@ -15,23 +15,23 @@ limitations under the License. */
#pragma once #pragma once
#include <fstream> #include <fstream>
#include <future> // NOLINT
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <sstream>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector>
#include <sstream>
#include <future> // NOLINT
#include <utility> #include <utility>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -85,21 +85,19 @@ class DataFeed { ...@@ -85,21 +85,19 @@ class DataFeed {
virtual void AddFeedVar(Variable* var, const std::string& name); virtual void AddFeedVar(Variable* var, const std::string& name);
// This function will do nothing at default // This function will do nothing at default
virtual void SetMemoryData(void* memory_data) { } virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) { } virtual void SetMemoryDataMutex(std::mutex* mutex) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetThreadId(int thread_id) { } virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetThreadNum(int thread_num) { } virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) { } virtual void SetTrainerNum(int trainer_num) {}
virtual void SetFileListMutex(std::mutex* mutex) { virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex; mutex_for_pick_file_ = mutex;
} }
virtual void SetFileListIndex(size_t* file_index) { virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
file_idx_ = file_index;
}
virtual void LoadIntoMemory() { virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
} }
...@@ -110,11 +108,11 @@ class DataFeed { ...@@ -110,11 +108,11 @@ class DataFeed {
PADDLE_THROW("This function(GlobalShuffle) is not implemented."); PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
} }
// This function will do nothing at default // This function will do nothing at default
virtual void FillMemoryDataToChannel() { } virtual void FillMemoryDataToChannel() {}
// This function will do nothing at default // This function will do nothing at default
virtual void FillChannelToMemoryData() { } virtual void FillChannelToMemoryData() {}
// This function will do nothing at default // This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) { } virtual void PutInsToChannel(const std::string& ins_str) {}
protected: protected:
// The following three functions are used to check if it is executed in this // The following three functions are used to check if it is executed in this
...@@ -222,8 +220,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -222,8 +220,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void GlobalShuffle(); virtual void GlobalShuffle();
protected: protected:
virtual void AddInstanceToInsVec(T* vec_ins, virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
const T& instance,
int index) = 0; int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0; virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
...@@ -363,6 +360,7 @@ class MultiSlotInMemoryDataFeed ...@@ -363,6 +360,7 @@ class MultiSlotInMemoryDataFeed
MultiSlotInMemoryDataFeed() {} MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {} virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc); virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected: protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins, virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance, const std::vector<MultiSlotType>& instance,
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -248,8 +248,7 @@ template <typename T> ...@@ -248,8 +248,7 @@ template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id, int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) { const std::string& msg) {
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << ", client_id=" << client_id << ", msg length=" << msg.length();
<< msg.length();
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_; int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
VLOG(3) << "ramdom index=" << index; VLOG(3) << "ramdom index=" << index;
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
......
...@@ -25,24 +25,23 @@ typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)(); ...@@ -25,24 +25,23 @@ typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap; typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map; datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \ #define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \ namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \ std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \ return std::shared_ptr<Dataset>(new dataset_class); \
} \ } \
class __Registerer_##dataset_class { \ class __Registerer_##dataset_class { \
public: \ public: \
__Registerer_##dataset_class() { \ __Registerer_##dataset_class() { \
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \ g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
} \ } \
}; \ }; \
__Registerer_##dataset_class g_registerer_##dataset_class; \ __Registerer_##dataset_class g_registerer_##dataset_class; \
} // namespace } // namespace
std::string DatasetFactory::DatasetTypeList() { std::string DatasetFactory::DatasetTypeList() {
std::string dataset_types; std::string dataset_types;
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); ++iter) {
++iter) {
if (iter != g_dataset_map.begin()) { if (iter != g_dataset_map.begin()) {
dataset_types += ", "; dataset_types += ", ";
} }
......
...@@ -113,8 +113,7 @@ class Executor { ...@@ -113,8 +113,7 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope, void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset, Dataset* dataset, const std::string& trainer_desc_str);
const std::string& trainer_desc_str);
private: private:
const platform::Place place_; const platform::Place place_;
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#pragma once #pragma once
#include <stdio.h> #include <stdio.h>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/io/shell.h" #include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
......
...@@ -47,7 +47,7 @@ void PullDenseWorker::Initialize(const TrainerDesc& param) { ...@@ -47,7 +47,7 @@ void PullDenseWorker::Initialize(const TrainerDesc& param) {
int var_num = table.dense_value_name_size(); int var_num = table.dense_value_name_size();
dense_value_names_[tid].resize(var_num); dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) { for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = table.dense_value_name(j); dense_value_names_[tid][j] = table.dense_value_name(j);
} }
// setup training version for each table // setup training version for each table
training_versions_[tid].resize(thread_num_, 0); training_versions_[tid].resize(thread_num_, 0);
......
...@@ -21,9 +21,9 @@ limitations under the License. */ ...@@ -21,9 +21,9 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE #ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE #undef _XOPEN_SOURCE
#endif #endif
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
......
...@@ -19,21 +19,21 @@ limitations under the License. */ ...@@ -19,21 +19,21 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE #ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE #undef _XOPEN_SOURCE
#endif #endif
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/dataset_factory.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/framework/dataset_factory.h"
namespace py = pybind11; namespace py = pybind11;
namespace pd = paddle::framework; namespace pd = paddle::framework;
...@@ -42,8 +42,8 @@ namespace paddle { ...@@ -42,8 +42,8 @@ namespace paddle {
namespace pybind { namespace pybind {
void BindDataset(py::module* m) { void BindDataset(py::module* m) {
py::class_<framework::Dataset, py::class_<framework::Dataset, std::shared_ptr<framework::Dataset>>(*m,
std::shared_ptr<framework::Dataset>>(*m, "Dataset") "Dataset")
.def(py::init([](const std::string& name = "MultiSlotDataset") { .def(py::init([](const std::string& name = "MultiSlotDataset") {
return framework::DatasetFactory::CreateDataset(name); return framework::DatasetFactory::CreateDataset(name);
})) }))
...@@ -58,7 +58,7 @@ void BindDataset(py::module* m) { ...@@ -58,7 +58,7 @@ void BindDataset(py::module* m) {
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig) .def("get_hdfs_config", &framework::Dataset::GetHdfsConfig)
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc) .def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc)
.def("register_client2client_msg_handler", .def("register_client2client_msg_handler",
&framework::Dataset::RegisterClientToClientMsgHandler) &framework::Dataset::RegisterClientToClientMsgHandler)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory) .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("release_memory", &framework::Dataset::ReleaseMemory) .def("release_memory", &framework::Dataset::ReleaseMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle) .def("local_shuffle", &framework::Dataset::LocalShuffle)
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include <stdio.h> #include <stdio.h>
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "boost/lexical_cast.hpp" #include "boost/lexical_cast.hpp"
#include "glog/logging.h" #include "glog/logging.h"
......
...@@ -80,18 +80,20 @@ class TestDataset(unittest.TestCase): ...@@ -80,18 +80,20 @@ class TestDataset(unittest.TestCase):
data += "1 7 2 3 6 4 8 8 8 8 1 7\n" data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data) f.write(data)
slots = ["slot1","slot2","slot3","slot4"] slots = ["slot1", "slot2", "slot3", "slot4"]
slots_vars = [] slots_vars = []
for slot in slots: for slot in slots:
var = fluid.layers.data(name=slot, shape=[1], var = fluid.layers.data(
dtype="int64", lod_level=1) name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var) slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32) dataset.set_batch_size(32)
dataset.set_thread(3) dataset.set_thread(3)
dataset.set_filelist(["test_in_memory_dataset_run_a.txt", dataset.set_filelist([
"test_in_memory_dataset_run_b.txt"]) "test_in_memory_dataset_run_a.txt",
"test_in_memory_dataset_run_b.txt"
])
dataset.set_pipe_command("cat") dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars) dataset.set_use_var(slots_vars)
dataset.load_into_memory() dataset.load_into_memory()
...@@ -124,18 +126,18 @@ class TestDataset(unittest.TestCase): ...@@ -124,18 +126,18 @@ class TestDataset(unittest.TestCase):
data += "1 7 2 3 6 4 8 8 8 8 1 7\n" data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data) f.write(data)
slots = ["slot1","slot2","slot3","slot4"] slots = ["slot1", "slot2", "slot3", "slot4"]
slots_vars = [] slots_vars = []
for slot in slots: for slot in slots:
var = fluid.layers.data(name=slot, shape=[1], var = fluid.layers.data(
dtype="int64", lod_level=1) name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var) slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("QueueDataset") dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_batch_size(32) dataset.set_batch_size(32)
dataset.set_thread(3) dataset.set_thread(3)
dataset.set_filelist(["test_queue_dataset_run_a.txt", dataset.set_filelist(
"test_queue_dataset_run_b.txt"]) ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.set_pipe_command("cat") dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars) dataset.set_use_var(slots_vars)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册