提交 e95cafd9 编写于 作者: X xjqbest 提交者: dongdaxiang

fix code style & add dataset testcase

test=develop
上级 39362a84
......@@ -62,6 +62,8 @@ void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
fs_name_ = fs_name;
fs_ugi_ = fs_ugi;
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
......
......@@ -20,6 +20,7 @@
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
......@@ -58,6 +59,8 @@ class Dataset {
virtual int GetThreadNum() = 0;
// get worker num
virtual int GetTrainerNum() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get readers, the reader num depend both on thread num
......@@ -102,6 +105,9 @@ class DatasetImpl : public Dataset {
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
......@@ -128,6 +134,8 @@ class DatasetImpl : public Dataset {
std::vector<std::string> filelist_;
size_t file_idx_;
std::mutex mutex_for_pick_file_;
std::string fs_name_;
std::string fs_ugi_;
};
// use std::vector<MultiSlotType> as data type
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor_thread_worker.h"
#include <algorithm>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/io/fs.h"
#include <memory>
namespace paddle {
namespace framework {
......
......@@ -17,6 +17,7 @@
#include <stdio.h>
#include <string>
#include <vector>
#include <memory>
#include "glog/logging.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
......
......@@ -52,6 +52,11 @@ void BindDataset(py::module* m) {
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("get_filelist", &framework::Dataset::GetFileList)
.def("get_thread_num", &framework::Dataset::GetThreadNum)
.def("get_trainer_num", &framework::Dataset::GetTrainerNum)
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig)
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc)
.def("register_client2client_msg_handler",
&framework::Dataset::RegisterClientToClientMsgHandler)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
......
......@@ -19,6 +19,7 @@
#include <cstring>
#include <string>
#include <vector>
#include <utility>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
......
......@@ -78,6 +78,12 @@ class AsyncExecutor(object):
"""
def __init__(self, place=None, run_mode=""):
"""
Init.
Args:
place(Place): CPUPlace or GPUPlace.
run_mode(str): default is empty string.
"""
if place is None:
place = core.CPUPlace()
if not isinstance(place, core.CPUPlace):
......@@ -91,6 +97,18 @@ class AsyncExecutor(object):
self.instance = None
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
"""
Run program by this AsyncExecutor.
Args:
program(Program): the program that need to run, if not provied,
then default_main_program will be used.
data_feed(DataFeedDesc): A DataFeedDesc object
filelist(str|list): a file or a list of files
thread_num(int): number of concurrent training threads.
fetch(str|list): the var name or a list of var names to inspect
debug(bool): When set to True, fetch vars will be printed to
standard output after each minibatch
"""
if program is None:
program = default_main_program()
program_desc = program.desc
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import numpy as np
import os
import shutil
import unittest
class TestDataset(unittest.TestCase):
def test_dataset_create(self):
try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except:
self.assertTrue(False)
try:
dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
except:
self.assertTrue(False)
try:
dataset = fluid.DatasetFactory().create_dataset("MyOwnDataset")
self.assertTrue(False)
except:
self.assertTrue(True)
def test_dataset_config(self):
dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12)
filelist = dataset.get_filelist()
self.assertEqual(len(filelist), 3)
self.assertEqual(filelist[0], "a.txt")
self.assertEqual(filelist[1], "b.txt")
self.assertEqual(filelist[2], "c.txt")
trainer_num = dataset.get_trainer_num()
self.assertEqual(trainer_num, 4)
name, ugi = dataset.get_hdfs_config()
self.assertEqual(name, "my_fs_name")
self.assertEqual(ugi, "my_fs_ugi")
def test_in_memory_dataset_run(self):
with open("test_dataset_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_dataset_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots = ["slot1","slot2","slot3","slot4"]
slots_vars = []
for slot in slots:
var = fluid.layers.data(name=slot, shape=[1],
dtype="int64", lod_level=1)
slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(["test_dataset_a.txt", "test_dataset_b.txt"])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.local_shuffle()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
self.assertTrue(False)
os.remove("./test_dataset_a.txt")
os.remove("./test_dataset_b.txt")
def test_queue_dataset_run(self):
with open("test_dataset_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_dataset_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots = ["slot1","slot2","slot3","slot4"]
slots_vars = []
for slot in slots:
var = fluid.layers.data(name=slot, shape=[1],
dtype="int64", lod_level=1)
slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(["test_dataset_a.txt", "test_dataset_b.txt"])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
self.assertTrue(False)
os.remove("./test_dataset_a.txt")
os.remove("./test_dataset_b.txt")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册