Commit e36bbcc8 authored by D dongdaxiang

fix some typos and CMakeLists.txt

Parent 824b84d1
@@ -18,15 +18,14 @@
 namespace paddle {
 namespace framework {
 
-Dataset::Dataset() {
-  thread_num_ = 1;
-}
+Dataset::Dataset() { thread_num_ = 1; }
 
 void Dataset::SetFileList(const std::vector<std::string>& filelist) {
   filelist_ = filelist;
   int file_cnt = filelist_.size();
   if (thread_num_ > file_cnt) {
-    VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt
+    VLOG(1) << "DataSet thread num = " << thread_num_
+            << ", file num = " << file_cnt
             << ". Changing DataSet thread num = " << file_cnt;
     thread_num_ = file_cnt;
   }
@@ -35,22 +34,23 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
 void Dataset::SetThreadNum(int thread_num) {
   int file_cnt = filelist_.size();
   if (file_cnt != 0 && thread_num > file_cnt) {
-    VLOG(1) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt
+    VLOG(1) << "DataSet thread num = " << thread_num
+            << ", file num = " << file_cnt
             << ". Changing DataSet thread num = " << file_cnt;
     thread_num = file_cnt;
   }
   thread_num_ = thread_num;
 }
 
-void Dataset::SetTrainerNum(int trainer_num) {
-  trainer_num_ = trainer_num;
-}
+void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
 
-void Dataset::SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc) {
+void Dataset::SetDataFeedDesc(
+    const paddle::framework::DataFeedDesc& data_feed_desc) {
   data_feed_desc_ = data_feed_desc;
 }
 
-std::vector<std::shared_ptr<paddle::framework::DataFeed>> Dataset::GetReaders() {
+std::vector<std::shared_ptr<paddle::framework::DataFeed>>
+Dataset::GetReaders() {
   return readers_;
 }
@@ -60,8 +60,8 @@ void Dataset::LoadIntoMemory() {
   }
   std::vector<std::thread> load_threads;
   for (int64_t i = 0; i < thread_num_; ++i) {
-    load_threads.push_back(std::thread(&paddle::framework::DataFeed::LoadIntoMemory,
-                                       readers_[i].get()));
+    load_threads.push_back(std::thread(
+        &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
   }
   for (std::thread& t : load_threads) {
     t.join();
@@ -74,8 +74,8 @@ void Dataset::LocalShuffle() {
   }
   std::vector<std::thread> local_shuffle_threads;
   for (int64_t i = 0; i < thread_num_; ++i) {
-    local_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::LocalShuffle,
-                                                readers_[i].get()));
+    local_shuffle_threads.push_back(std::thread(
+        &paddle::framework::DataFeed::LocalShuffle, readers_[i].get()));
   }
   for (std::thread& t : local_shuffle_threads) {
     t.join();
@@ -115,14 +115,14 @@ void Dataset::CreateReaders() {
   readers_[0]->SetFileList(filelist_);
 }
 
-int Dataset::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) {
+int Dataset::ReceiveFromClient(int msg_type, int client_id,
+                               const std::string& msg) {
   // can also use hash
   // int64_t index = paddle::ps::local_random_engine()() % thread_num_;
-  // todo
   int64_t index = 0;
   readers_[index]->PutInsToChannel(msg);
   return 0;
 }
-}
-}
+}  // end namespace framework
+}  // end namespace paddle
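The LoadIntoMemory and LocalShuffle hunks above use the same fan-out/join threading idiom: spawn one std::thread per reader with a member-function pointer plus the raw reader pointer, then join them all. A minimal standalone sketch of that idiom follows; the Worker type is a hypothetical stand-in for DataFeed and is not part of this commit.

// Minimal sketch of the fan-out/join pattern used by Dataset::LoadIntoMemory
// and Dataset::LocalShuffle above. Worker stands in for DataFeed; the real
// readers_ vector holds std::shared_ptr<DataFeed>.
#include <iostream>
#include <memory>
#include <thread>
#include <vector>

struct Worker {
  explicit Worker(int id) : id_(id) {}
  // Stand-in for DataFeed::LoadIntoMemory.
  void LoadIntoMemory() { std::cout << "worker " << id_ << " loading\n"; }
  int id_;
};

int main() {
  const int thread_num = 4;
  std::vector<std::shared_ptr<Worker>> readers;
  for (int i = 0; i < thread_num; ++i) {
    readers.push_back(std::make_shared<Worker>(i));
  }
  std::vector<std::thread> load_threads;
  for (int i = 0; i < thread_num; ++i) {
    // Same call shape as the diff: member-function pointer + raw object pointer.
    load_threads.push_back(
        std::thread(&Worker::LoadIntoMemory, readers[i].get()));
  }
  for (std::thread& t : load_threads) {
    t.join();  // block until every per-reader load finishes
  }
  return 0;
}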
@@ -34,29 +34,27 @@ class Dataset {
   virtual void SetFileList(const std::vector<std::string>& filelist);
   virtual void SetThreadNum(int thread_num);
   virtual void SetTrainerNum(int trainer_num);
-  virtual void SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc);
-  virtual const std::vector<std::string>& GetFileList() {
-    return filelist_;
-  }
-  virtual int GetThreadNum() {
-    return thread_num_;
-  }
-  virtual int GetTrainerNum() {
-    return trainer_num_;
-  }
+  virtual void SetDataFeedDesc(
+      const paddle::framework::DataFeedDesc& data_feed_desc);
+  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
+  virtual int GetThreadNum() { return thread_num_; }
+  virtual int GetTrainerNum() { return trainer_num_; }
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
     return data_feed_desc_;
   }
-  virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>> GetReaders();
+  virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>
+  GetReaders();
   virtual void LoadIntoMemory();
   virtual void LocalShuffle();
   // todo global shuffle
   virtual void GlobalShuffle();
   virtual void CreateReaders();
 
  protected:
-  virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg);
+  virtual int ReceiveFromClient(int msg_type, int client_id,
+                                const std::string& msg);
   std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
   int thread_num_;
   std::string fs_name_;
@@ -66,5 +64,5 @@ class Dataset {
   int trainer_num_;
 };
 
-}
-}
+}  // end namespace framework
+}  // end namespace paddle
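Read as a whole, the header above defines the call surface a trainer would drive. A hedged sketch of the expected call order follows; the method names come from the header in this diff, while the include path, thread count, and overall flow are assumptions for illustration.

// Hedged usage sketch of the Dataset API declared above; not part of the commit.
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_set.h"

void PrepareDataset(paddle::framework::Dataset* dataset,
                    const std::vector<std::string>& files,
                    const paddle::framework::DataFeedDesc& desc) {
  dataset->SetFileList(files);   // thread num is clamped to the file count
  dataset->SetThreadNum(4);      // assumed value, any positive int
  dataset->SetTrainerNum(1);
  dataset->SetDataFeedDesc(desc);
  dataset->CreateReaders();      // one DataFeed reader per thread
  dataset->LoadIntoMemory();     // fan-out load shown in the .cc diff above
  dataset->LocalShuffle();       // per-reader shuffle
}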
@@ -115,6 +115,10 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
   }
 }
 
+void Executor::RunFromDataset(const ProgramDesc& pdesc, const Dataset& dataset,
+                              const std::string& trainer_desc_str,
+                              const bool debug) {}
+
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope, bool create_vars,
                    const std::vector<std::string>& skip_ref_cnt_vars,
......
@@ -19,13 +19,13 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/framework/data_set.h"
 
 namespace paddle {
 namespace framework {
@@ -112,11 +112,7 @@ class Executor {
   void EnableMKLDNN(const ProgramDesc& program);
 
-  void RunFromTrainerDesc(const ProgramDesc& main_program,
-                          const std::string& trainer_desc_str,
-                          const bool debug);
-
-  void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
+  void RunFromDataset(const ProgramDesc& main_program, const Dataset& dataset,
                       const std::string& trainer_desc_str, const bool debug);
 
  public:
......
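Given the new declaration above, a call site passes the program, the dataset, and a serialized trainer description. A hedged sketch follows; the place, the helper name, and the debug flag are assumptions, and the executor.cc hunk earlier shows the body is still an empty stub.

// Hedged sketch of invoking Executor::RunFromDataset with the signature above.
#include <string>

#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"

void TrainOnce(const paddle::framework::ProgramDesc& main_program,
               const paddle::framework::Dataset& dataset,
               const std::string& trainer_desc_str) {
  paddle::framework::Executor executor(paddle::platform::CPUPlace());
  // Matches the declaration in this diff:
  // RunFromDataset(main_program, dataset, trainer_desc_str, debug)
  executor.RunFromDataset(main_program, dataset, trainer_desc_str,
                          /*debug=*/false);
}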
@@ -5,11 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr
 if(WITH_PYTHON)
   list(APPEND PYBIND_DEPS py_func_op)
 endif()
-<<<<<<< HEAD
-set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc)
-=======
-set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc imperative.cc ir.cc inference_api.cc)
->>>>>>> add pybind for fleet
+set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc)
 if(WITH_PYTHON)
   if(WITH_AMD_GPU)
......
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include <fcntl.h>
-// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
 #ifdef _POSIX_C_SOURCE
 #undef _POSIX_C_SOURCE
 #endif
@@ -29,12 +27,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/async_executor.h"
 #include "paddle/fluid/framework/data_feed.h"
 #include "paddle/fluid/framework/data_feed.pb.h"
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/variant.h"
-#include "paddle/fluid/pybind/async_executor_py.h"
-#include "paddle/fluid/framework/data_set.h"
+#include "paddle/fluid/pybind/data_set_py.h"
 
 namespace py = pybind11;
 namespace pd = paddle::framework;
@@ -43,18 +41,17 @@ namespace paddle {
 namespace pybind {
 
 void BindDataset(py::module* m) {
-  py::class_<framework::DataSet>(*m, "Dataset")
+  py::class_<framework::Dataset>(*m, "Dataset")
       .def(py::init([]() {
-            return std::unique_ptr<framework::Dataset>(
-                new framework::Dataset());
-          }))
+        return std::unique_ptr<framework::Dataset>(new framework::Dataset());
+      }))
       .def("set_filelist", &framework::Dataset::SetFileList)
       .def("set_thread_num", &framework::Dataset::SetThreadNum)
       .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
       .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
       .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
       .def("local_shuffle", &framework::Dataset::LocalShuffle)
-      .def("global_shuffle", &framework::Dataset::GLobalShuffle)
+      .def("global_shuffle", &framework::Dataset::GlobalShuffle);
 }
 }  // end namespace pybind
......
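For readers unfamiliar with the py::class_ pattern used in BindDataset above, here is a minimal self-contained pybind11 module using the same idiom of a lambda-based init plus method bindings. The Counter class and module name are hypothetical stand-ins, not part of this commit.

// Minimal pybind11 sketch mirroring the BindDataset idiom above.
// Build as a Python extension module named "counter_module" (requires pybind11).
#include <memory>

#include <pybind11/pybind11.h>

namespace py = pybind11;

class Counter {  // hypothetical stand-in for framework::Dataset
 public:
  void SetStep(int step) { step_ = step; }
  int Next() { return value_ += step_; }

 private:
  int value_ = 0;
  int step_ = 1;
};

PYBIND11_MODULE(counter_module, m) {
  py::class_<Counter>(m, "Counter")
      // Lambda-based init, same shape as the Dataset binding above.
      .def(py::init([]() { return std::unique_ptr<Counter>(new Counter()); }))
      .def("set_step", &Counter::SetStep)
      .def("next", &Counter::Next);
}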