Commit f6c9232a authored by dongdaxiang

fix dataset float32 type problem

Parent 73b1f396
@@ -24,7 +24,6 @@ endfunction()
 add_subdirectory(ir)
 add_subdirectory(details)
 add_subdirectory(fleet)
-add_subdirectory(common)
 add_subdirectory(io)
 #ddim lib
 proto_library(framework_proto SRCS framework.proto)
......
@@ -60,10 +60,10 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
 }
 // todo InitModel
-void AsyncExecutor::InitModel() { }
+void AsyncExecutor::InitModel() {}
 // todo SaveModel
-void AsyncExecutor::SaveModel(const std::string& path) { }
+void AsyncExecutor::SaveModel(const std::string& path) {}
 void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                 const std::string& data_feed_desc_str,
@@ -88,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                 &data_feed_desc);
-  actual_thread_num = thread_num;
+  actual_thread_num_ = thread_num;
   int file_cnt = filelist.size();
   PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
-  if (actual_thread_num > file_cnt) {
+  if (actual_thread_num_ > file_cnt) {
     VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
             << ". Changing thread_num = " << file_cnt;
-    actual_thread_num = file_cnt;
+    actual_thread_num_ = file_cnt;
   }
   /*
@@ -111,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   */
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed>> readers;
-  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
+  /*
+  PrepareReaders(readers, actual_thread_num_, data_feed_desc, filelist);
 #ifdef PADDLE_WITH_PSLIB
   PrepareDenseThread(mode);
 #endif
+  */
   std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
-  workers.resize(actual_thread_num);
+  workers.resize(actual_thread_num_);
   for (auto& worker : workers) {
 #ifdef PADDLE_WITH_PSLIB
     if (mode == "mpi") {
@@ -130,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   }
   // prepare thread resource here
-  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
+  /*
+  for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
     CreateThreads(workers[thidx].get(), main_program, readers[thidx],
                   fetch_var_names, root_scope_, thidx, debug);
   }
+  */
   // start executing ops in multiple threads
-  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
+  for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
     if (debug) {
       threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
                                     workers[thidx].get()));
@@ -160,11 +164,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   return;
 }
-// todo RunFromDataset
-void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
-                                   Dataset* data_set,
-                                   const std::string& trainer_desc_str,
-                                   const bool debug) { }
 }  // end namespace framework
 }  // end namespace paddle
@@ -25,12 +25,12 @@ limitations under the License. */
 #include <typeinfo>
 #include <vector>
 #include "paddle/fluid/framework/data_feed.pb.h"
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/executor_thread_worker.h"
 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/data_set.h"
 namespace paddle {
 namespace framework {
@@ -64,7 +64,11 @@ class AsyncExecutor {
   AsyncExecutor(Scope* scope, const platform::Place& place);
   virtual ~AsyncExecutor() {}
   void RunFromFile(const ProgramDesc& main_program,
-                   const std::string& trainer_desc_str, const bool debug);
+                   const std::string& data_feed_desc_str,
+                   const std::vector<std::string>& filelist,
+                   const int thread_num,
+                   const std::vector<std::string>& fetch_var_names,
+                   const std::string& mode, const bool debug);
   // TODO(guru4elephant): make init server decoupled from executor
   void InitServer(const std::string& dist_desc, int index);
......
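For orientation, a hedged sketch of how the widened RunFromFile signature is reached from the Python side. The fluid.AsyncExecutor wrapper and its exact argument names are assumptions based on the API of this era, not part of this commit:

# Hedged sketch: mapping Python-side arguments onto the new RunFromFile
# parameters. The wrapper signature here is an assumption, not this diff.
import paddle.fluid as fluid

data_feed = fluid.DataFeedDesc('data_feed.proto')      # -> data_feed_desc_str
filelist = ['train_data/part-0', 'train_data/part-1']  # -> filelist

executor = fluid.AsyncExecutor(fluid.CPUPlace())
executor.run(fluid.default_main_program(),  # -> main_program
             data_feed,
             filelist,
             thread_num=2,        # capped at len(filelist), see actual_thread_num_
             fetch=['loss'],      # -> fetch_var_names
             mode='',             # 'mpi' selects the PSLIB dense-thread path
             debug=False)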
@@ -12,8 +12,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
-#include <random>
 #include "paddle/fluid/framework/data_set.h"
+#include <random>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/message.h"
 #include "google/protobuf/text_format.h"
@@ -23,7 +23,9 @@ namespace paddle {
 namespace framework {
 template <typename T>
-DatasetImpl<T>::DatasetImpl() { thread_num_ = 1; }
+DatasetImpl<T>::DatasetImpl() {
+  thread_num_ = 1;
+}
 template <typename T>
 void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
@@ -66,7 +68,7 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
 template <typename T>
 std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
 DatasetImpl<T>::GetReaders() {
   return readers_;
 }
@@ -118,16 +120,15 @@ void DatasetImpl<T>::GlobalShuffle() {
   std::random_shuffle(memory_data_.begin(), memory_data_.end());
   auto fleet_ptr = FleetWrapper::GetInstance();
   VLOG(3) << "RegisterClientToClientMsgHandler";
-  fleet_ptr->RegisterClientToClientMsgHandler(0,
-      [this](int msg_type, int client_id, const std::string& msg) -> int {
+  fleet_ptr->RegisterClientToClientMsgHandler(
+      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
         return this->ReceiveFromClient(msg_type, client_id, msg);
       });
   VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
-    global_shuffle_threads.push_back(
-        std::thread(&paddle::framework::DataFeed::GlobalShuffle,
-                    readers_[i].get()));
+    global_shuffle_threads.push_back(std::thread(
+        &paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
   }
   for (std::thread& t : global_shuffle_threads) {
     t.join();
@@ -169,14 +170,15 @@ void DatasetImpl<T>::DestroyReaders() {
   }
   std::vector<std::thread> fill_threads;
   for (int i = 0; i < thread_num_; ++i) {
-    fill_threads.push_back(std::thread(
-        &paddle::framework::DataFeed::FillChannelToMemoryData,
-        readers_[i].get()));
+    fill_threads.push_back(
+        std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
+                    readers_[i].get()));
   }
   for (std::thread& t : fill_threads) {
     t.join();
   }
   std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
+  LOG(WARNING) << "readers size: " << readers_.size();
 }
 template <typename T>
......
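For context, a rough sketch of driving the GlobalShuffle/DestroyReaders machinery above from Python. The DatasetFactory/InMemoryDataset wrapper names are assumptions from the surrounding patch series, not shown in this diff:

# Hedged sketch: exercising DatasetImpl<T> through the Python dataset wrapper.
# The wrapper class and method names are assumed, not part of this commit.
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset('InMemoryDataset')
dataset.set_thread(4)                       # -> thread_num_ (constructor default is 1)
dataset.set_filelist(['part-0', 'part-1'])  # -> SetFileList
dataset.load_into_memory()                  # readers fill memory_data_
dataset.global_shuffle()                    # registers the client-to-client msg
                                            # handler, one shuffle thread per reader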
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
// NOTE: gtest's TEST macro requires both a suite name and a test name; the
// names below are placeholders, since the bare TEST() stub would not compile.
TEST(DeviceWorkerTest, CreateHogwildWorker) {
  // create hogwild device worker
}
}  // namespace framework
}  // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/trainer.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
// NOTE: placeholder suite/test names added, as above; a bare TEST() is invalid.
TEST(TrainerTest, MultiTrainer) {
  // create multi trainer
  // create hogwild device worker
  // create dataset
  // train for a while
}
}  // namespace framework
}  // namespace paddle
@@ -78,7 +78,7 @@ class DatasetBase(object):
         if var.lod_level == 0:
             slot_var.is_dense = True
         if var.dtype == core.VarDesc.VarType.FP32:
-            slot_var.type = "float32"
+            slot_var.type = "float"
         elif var.dtype == core.VarDesc.VarType.INT64:
             slot_var.type = "uint64"
         else:
......
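This last hunk is the fix named in the commit title: FP32 slots were being tagged "float32", while the C++ DataFeed side evidently expects the literal "float" (with INT64 mapping to "uint64"). A minimal sketch of the corrected mapping; the helper function is hypothetical, for illustration only:

# Hedged sketch of the dtype -> slot-type mapping after this fix.
# map_dtype_to_slot_type is a hypothetical helper, not in the patch.
from paddle.fluid import core

def map_dtype_to_slot_type(dtype):
    if dtype == core.VarDesc.VarType.FP32:
        return 'float'    # previously 'float32', which broke dataset parsing
    elif dtype == core.VarDesc.VarType.INT64:
        return 'uint64'
    raise ValueError('dataset slots support only FP32 and INT64 dtypes')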