提交 e36bbcc8 编写于 作者: D dongdaxiang

fix some typo and CMakefile.txt

上级 824b84d1
......@@ -18,15 +18,14 @@
namespace paddle {
namespace framework {
Dataset::Dataset() {
thread_num_ = 1;
}
Dataset::Dataset() { thread_num_ = 1; }
void Dataset::SetFileList(const std::vector<std::string>& filelist) {
filelist_ = filelist;
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt
VLOG(1) << "DataSet thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
}
......@@ -35,22 +34,23 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
void Dataset::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt
VLOG(1) << "DataSet thread num = " << thread_num
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}
thread_num_ = thread_num;
}
void Dataset::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
void Dataset::SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc) {
void Dataset::SetDataFeedDesc(
const paddle::framework::DataFeedDesc& data_feed_desc) {
data_feed_desc_ = data_feed_desc;
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>> Dataset::GetReaders() {
std::vector<std::shared_ptr<paddle::framework::DataFeed>>
Dataset::GetReaders() {
return readers_;
}
......@@ -60,8 +60,8 @@ void Dataset::LoadIntoMemory() {
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(&paddle::framework::DataFeed::LoadIntoMemory,
readers_[i].get()));
load_threads.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
}
for (std::thread& t : load_threads) {
t.join();
......@@ -74,8 +74,8 @@ void Dataset::LocalShuffle() {
}
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::LocalShuffle,
readers_[i].get()));
local_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::LocalShuffle, readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
......@@ -115,14 +115,14 @@ void Dataset::CreateReaders() {
readers_[0]->SetFileList(filelist_);
}
int Dataset::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) {
int Dataset::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
// can also use hash
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
// todo
int64_t index = 0;
readers_[index]->PutInsToChannel(msg);
return 0;
}
}
}
} // end namespace framework
} // end namespace paddle
......@@ -34,29 +34,27 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual void SetDataFeedDesc(
const paddle::framework::DataFeedDesc& data_feed_desc);
virtual const std::vector<std::string>& GetFileList() {
return filelist_;
}
virtual int GetThreadNum() {
return thread_num_;
}
virtual int GetTrainerNum() {
return trainer_num_;
}
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>> GetReaders();
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>
GetReaders();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// todo global shuffle
virtual void GlobalShuffle();
virtual void CreateReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg);
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
int thread_num_;
std::string fs_name_;
......@@ -66,5 +64,5 @@ class Dataset {
int trainer_num_;
};
}
}
} // end namespace framework
} // end namespace paddle
......@@ -115,6 +115,10 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
void Executor::RunFromDataset(const ProgramDesc& pdesc, const Dataset& dataset,
const std::string& trainer_desc_str,
const bool debug) {}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars,
......
......@@ -19,13 +19,13 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
......@@ -112,11 +112,7 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program);
void RunFromTrainerDesc(const ProgramDesc& main_program,
const std::string& trainer_desc_str,
const bool debug);
void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
void RunFromDataset(const ProgramDesc& main_program, const Dataset& dataset,
const std::string& trainer_desc_str, const bool debug);
public:
......
......@@ -5,11 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
<<<<<<< HEAD
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc)
=======
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc imperative.cc ir.cc inference_api.cc)
>>>>>>> add pybind for fleet
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc)
if(WITH_PYTHON)
if(WITH_AMD_GPU)
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
......@@ -29,12 +27,12 @@ limitations under the License. */
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/pybind/data_set_py.h"
namespace py = pybind11;
namespace pd = paddle::framework;
......@@ -43,10 +41,9 @@ namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::DataSet>(*m, "Dataset")
py::class_<framework::Dataset>(*m, "Dataset")
.def(py::init([]() {
return std::unique_ptr<framework::Dataset>(
new framework::Dataset());
return std::unique_ptr<framework::Dataset>(new framework::Dataset());
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
......@@ -54,7 +51,7 @@ void BindDataset(py::module* m) {
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GLobalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle);
}
} // end namespace pybind
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册