From f6c9232a3d91c88c25d52b56193c38e1506bee11 Mon Sep 17 00:00:00 2001
From: dongdaxiang
Date: Mon, 18 Mar 2019 15:10:40 +0800
Subject: [PATCH] fix dataset float32 type problem

---
 paddle/fluid/framework/CMakeLists.txt        |  1 -
 paddle/fluid/framework/async_executor.cc     | 28 ++++++++---------
 paddle/fluid/framework/async_executor.h      |  8 +++--
 paddle/fluid/framework/data_set.cc           | 32 +++++++++++---------
 paddle/fluid/framework/device_worker_test.cc | 24 +++++++++++++++
 paddle/fluid/framework/trainer_test.cc       | 27 +++++++++++++++++
 python/paddle/fluid/dataset.py               |  2 +-
 7 files changed, 88 insertions(+), 34 deletions(-)
 create mode 100644 paddle/fluid/framework/device_worker_test.cc
 create mode 100644 paddle/fluid/framework/trainer_test.cc

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index d13009480..f1c8af2ef 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -24,7 +24,6 @@ endfunction()
 add_subdirectory(ir)
 add_subdirectory(details)
 add_subdirectory(fleet)
-add_subdirectory(common)
 add_subdirectory(io)
 #ddim lib
 proto_library(framework_proto SRCS framework.proto)
diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
index 078bd3961..b13eefba2 100644
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -60,10 +60,10 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
 }
 
 // todo InitModel
-void AsyncExecutor::InitModel() { }
+void AsyncExecutor::InitModel() {}
 
 // todo SaveModel
-void AsyncExecutor::SaveModel(const std::string& path) { }
+void AsyncExecutor::SaveModel(const std::string& path) {}
 
 void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                 const std::string& data_feed_desc_str,
@@ -88,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                 &data_feed_desc);
 
-  actual_thread_num = thread_num;
+  actual_thread_num_ = thread_num;
   int file_cnt = filelist.size();
   PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
-  if (actual_thread_num > file_cnt) {
+  if (actual_thread_num_ > file_cnt) {
     VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
             << ". Changing thread_num = " << file_cnt;
-    actual_thread_num = file_cnt;
+    actual_thread_num_ = file_cnt;
   }
 
   /*
@@ -111,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   */
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed>> readers;
-  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
+  /*
+  PrepareReaders(readers, actual_thread_num_, data_feed_desc, filelist);
 #ifdef PADDLE_WITH_PSLIB
   PrepareDenseThread(mode);
 #endif
+  */
   std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
-  workers.resize(actual_thread_num);
+  workers.resize(actual_thread_num_);
   for (auto& worker : workers) {
 #ifdef PADDLE_WITH_PSLIB
     if (mode == "mpi") {
@@ -130,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   }
 
   // prepare thread resource here
-  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
+  /*
+  for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
     CreateThreads(workers[thidx].get(), main_program, readers[thidx],
                   fetch_var_names, root_scope_, thidx, debug);
   }
+  */
 
   // start executing ops in multiple threads
-  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
+  for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
     if (debug) {
       threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
                                     workers[thidx].get()));
@@ -160,11 +164,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   return;
 }
 
-// todo RunFromDataset
-void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
-                                   Dataset* data_set,
-                                   const std::string& trainer_desc_str,
-                                   const bool debug) { }
-
 }  // end namespace framework
 }  // end namespace paddle
diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h
index e54a17333..7b59e1b11 100644
--- a/paddle/fluid/framework/async_executor.h
+++ b/paddle/fluid/framework/async_executor.h
@@ -25,12 +25,12 @@ limitations under the License. */
 #include <typeinfo>
 #include <vector>
 #include "paddle/fluid/framework/data_feed.pb.h"
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/executor_thread_worker.h"
 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/data_set.h"
 
 namespace paddle {
 namespace framework {
@@ -64,7 +64,11 @@ class AsyncExecutor {
   AsyncExecutor(Scope* scope, const platform::Place& place);
   virtual ~AsyncExecutor() {}
   void RunFromFile(const ProgramDesc& main_program,
-                   const std::string& trainer_desc_str, const bool debug);
+                   const std::string& data_feed_desc_str,
+                   const std::vector<std::string>& filelist,
+                   const int thread_num,
+                   const std::vector<std::string>& fetch_var_names,
+                   const std::string& mode, const bool debug);
 
   // TODO(guru4elephant): make init server decoupled from executor
   void InitServer(const std::string& dist_desc, int index);
diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index 1d2a018be..e7128869d 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -12,8 +12,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License. */
 
-#include <random>
 #include "paddle/fluid/framework/data_set.h"
+#include <random>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/message.h"
 #include "google/protobuf/text_format.h"
@@ -23,7 +23,9 @@ namespace paddle {
 namespace framework {
 
 template <typename T>
-DatasetImpl<T>::DatasetImpl() { thread_num_ = 1; }
+DatasetImpl<T>::DatasetImpl() {
+  thread_num_ = 1;
+}
 
 template <typename T>
 void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
@@ -66,7 +68,7 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
 
 template <typename T>
 std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
-    DatasetImpl<T>::GetReaders() {
+DatasetImpl<T>::GetReaders() {
   return readers_;
 }
 
@@ -112,22 +114,21 @@ template <typename T>
 void DatasetImpl<T>::GlobalShuffle() {
   VLOG(3) << "DatasetImpl::GlobalShuffle() begin";
   if (readers_.size() == 0) {
-      CreateReaders();
+    CreateReaders();
   }
   // if it is not InMemory, memory_data_ is empty
   std::random_shuffle(memory_data_.begin(), memory_data_.end());
   auto fleet_ptr = FleetWrapper::GetInstance();
   VLOG(3) << "RegisterClientToClientMsgHandler";
-  fleet_ptr->RegisterClientToClientMsgHandler(0,
-      [this](int msg_type, int client_id, const std::string& msg) -> int {
-    return this->ReceiveFromClient(msg_type, client_id, msg);
-  });
+  fleet_ptr->RegisterClientToClientMsgHandler(
+      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
+        return this->ReceiveFromClient(msg_type, client_id, msg);
+      });
   VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
-    global_shuffle_threads.push_back(
-        std::thread(&paddle::framework::DataFeed::GlobalShuffle,
-                    readers_[i].get()));
+    global_shuffle_threads.push_back(std::thread(
+        &paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
   }
   for (std::thread& t : global_shuffle_threads) {
     t.join();
@@ -169,19 +170,20 @@ void DatasetImpl<T>::DestroyReaders() {
   }
   std::vector<std::thread> fill_threads;
   for (int i = 0; i < thread_num_; ++i) {
-    fill_threads.push_back(std::thread(
-        &paddle::framework::DataFeed::FillChannelToMemoryData,
-        readers_[i].get()));
+    fill_threads.push_back(
+        std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
+                    readers_[i].get()));
   }
   for (std::thread& t : fill_threads) {
     t.join();
  }
   std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
+  LOG(WARNING) << "readers size: " << readers_.size();
 }
 
 template <typename T>
 int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
-                                   const std::string& msg) {
+                                      const std::string& msg) {
   // todo random
   // int64_t index = paddle::ps::local_random_engine()() % thread_num_;
   int64_t index = 0;
diff --git a/paddle/fluid/framework/device_worker_test.cc b/paddle/fluid/framework/device_worker_test.cc
new file mode 100644
index 000000000..faa648ab3
--- /dev/null
+++ b/paddle/fluid/framework/device_worker_test.cc
@@ -0,0 +1,24 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/trainer.h"
+
+namespace paddle {
+namespace framework {
+TEST() {
+  // create hogwild device worker
+}
+}
+}
diff --git a/paddle/fluid/framework/trainer_test.cc b/paddle/fluid/framework/trainer_test.cc
new file mode 100644
index 000000000..f689679d4
--- /dev/null
+++ b/paddle/fluid/framework/trainer_test.cc
@@ -0,0 +1,27 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/trainer.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace framework {
+TEST() {
+  // create multi trainer
+  // create hogwild device worker
+  // create dataset
+  // train for a while
+}
+}
+}
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 6d239260c..6ae1d3cf1 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -78,7 +78,7 @@ class DatasetBase(object):
             if var.lod_level == 0:
                 slot_var.is_dense = True
             if var.dtype == core.VarDesc.VarType.FP32:
-                slot_var.type = "float32"
+                slot_var.type = "float"
            elif var.dtype == core.VarDesc.VarType.INT64:
                 slot_var.type = "uint64"
             else:
--
GitLab
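Note on the user-visible part of this patch: the python/paddle/fluid/dataset.py hunk is the actual "float32 type problem" fix. A dense FP32 variable registered on a dataset now produces a slot of type "float" in the generated DataFeedDesc proto, which is the type string the C++ DataFeed side expects, rather than "float32". A minimal usage sketch follows, assuming the paddle.fluid 1.x Dataset API (DatasetFactory, set_use_var, desc); the variable names, shapes, and the "InMemoryDataset" choice are illustrative, not taken from the patch:

    import paddle.fluid as fluid

    # Two slots: a dense float32 feature and an int64 label.
    data = fluid.layers.data(
        name="dense_slot", shape=[4], dtype="float32", lod_level=0)
    label = fluid.layers.data(
        name="label", shape=[1], dtype="int64", lod_level=1)

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_batch_size(32)
    # With this patch, the FP32 var is emitted into the DataFeedDesc proto as
    # slot type "float" (not "float32"), matching what the C++ reader parses.
    dataset.set_use_var([data, label])
    print(dataset.desc())  # inspect the generated DataFeedDesc text proto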