diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 1e82c5fae3e0c12b3e12ea22df69926d78c0c88c..8c73de9cda2682388d110609e4c2695972f49a6d 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -198,27 +198,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph build_strategy fast_threaded_ssa_graph_executor variable_helper) -if(WITH_PSLIB) - cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc - executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc - trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc - downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc - data_set.cc - DEPS op_registry device_context scope framework_proto - trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer - feed_fetch_method graph_to_program_pass async_executor_proto - variable_helper pslib_brpc pslib timer fs shell) -else() - cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc - executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc - trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc - downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc - data_set.cc - DEPS op_registry device_context scope framework_proto - trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer - feed_fetch_method graph_to_program_pass async_executor_proto - variable_helper timer fs shell) -endif(WITH_PSLIB) +cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc + executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc + trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc + downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc + data_set.cc + DEPS op_registry device_context scope framework_proto + trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer + feed_fetch_method graph_to_program_pass data_feed_proto + variable_helper timer) cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 0233982f119b38b8c121cbf209152fc828ff700c..bf7ade95b2891a36aa7c55b58ead0b376c4cf4df 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -220,8 +220,8 @@ void InMemoryDataFeed::LocalShuffle() { std::random_shuffle(memory_data_.begin(), memory_data_.end()); } - template class InMemoryDataFeed>; + // todo global shuffle /* template diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 46a8ad4d88797880b6f33cd3c6035507f5d7fc63..bbf59b95c65a176a451c4a1c1704ffa1c14651f5 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -50,6 +50,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/pybind/async_executor_py.h" #include "paddle/fluid/pybind/const_value.h" +#include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/fleet_wrapper_py.h" #include "paddle/fluid/pybind/imperative.h" @@ -61,7 +62,6 @@ limitations under the License. */ #include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/string/to_string.h" -#include "paddle/fluid/pybind/data_set_py.h" #ifdef PADDLE_WITH_CUDA #ifndef _WIN32 @@ -923,6 +923,7 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) .def("close", &Executor::Close) + .def("run_from_dataset", &Executor::RunFromDataset) .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, int block_id, bool create_local_scope, bool create_vars, const std::vector &fetch_vars) {