Commit f6c9232a authored by dongdaxiang

fix dataset float32 type problem

Parent 73b1f396
@@ -24,7 +24,6 @@ endfunction()
 add_subdirectory(ir)
 add_subdirectory(details)
 add_subdirectory(fleet)
-add_subdirectory(common)
 add_subdirectory(io)
 #ddim lib
 proto_library(framework_proto SRCS framework.proto)
......
@@ -60,10 +60,10 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
 }
 
 // todo InitModel
-void AsyncExecutor::InitModel() { }
+void AsyncExecutor::InitModel() {}
 
 // todo SaveModel
-void AsyncExecutor::SaveModel(const std::string& path) { }
+void AsyncExecutor::SaveModel(const std::string& path) {}
 
 void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                 const std::string& data_feed_desc_str,
@@ -88,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                 &data_feed_desc);
-  actual_thread_num = thread_num;
+  actual_thread_num_ = thread_num;
   int file_cnt = filelist.size();
   PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
-  if (actual_thread_num > file_cnt) {
+  if (actual_thread_num_ > file_cnt) {
     VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
             << ". Changing thread_num = " << file_cnt;
-    actual_thread_num = file_cnt;
+    actual_thread_num_ = file_cnt;
   }
   /*
@@ -111,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   */
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed>> readers;
-  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
+  /*
+  PrepareReaders(readers, actual_thread_num_, data_feed_desc, filelist);
 #ifdef PADDLE_WITH_PSLIB
   PrepareDenseThread(mode);
 #endif
+  */
   std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
-  workers.resize(actual_thread_num);
+  workers.resize(actual_thread_num_);
   for (auto& worker : workers) {
 #ifdef PADDLE_WITH_PSLIB
     if (mode == "mpi") {
@@ -130,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   }
 
   // prepare thread resource here
-  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
+  /*
+  for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
     CreateThreads(workers[thidx].get(), main_program, readers[thidx],
                   fetch_var_names, root_scope_, thidx, debug);
   }
+  */
 
   // start executing ops in multiple threads
-  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
+  for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
     if (debug) {
       threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
                                     workers[thidx].get()));
@@ -160,11 +164,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   return;
 }
 
-// todo RunFromDataset
-void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
-                                   Dataset* data_set,
-                                   const std::string& trainer_desc_str,
-                                   const bool debug) { }
-
 }  // end namespace framework
 }  // end namespace paddle
@@ -25,12 +25,12 @@ limitations under the License. */
 #include <typeinfo>
 #include <vector>
 #include "paddle/fluid/framework/data_feed.pb.h"
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/executor_thread_worker.h"
 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/data_set.h"
 
 namespace paddle {
 namespace framework {
@@ -64,7 +64,11 @@ class AsyncExecutor {
   AsyncExecutor(Scope* scope, const platform::Place& place);
   virtual ~AsyncExecutor() {}
   void RunFromFile(const ProgramDesc& main_program,
-                   const std::string& trainer_desc_str, const bool debug);
+                   const std::string& data_feed_desc_str,
+                   const std::vector<std::string>& filelist,
+                   const int thread_num,
+                   const std::vector<std::string>& fetch_var_names,
+                   const std::string& mode, const bool debug);
 
   // TODO(guru4elephant): make init server decoupled from executor
   void InitServer(const std::string& dist_desc, int index);
......
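For orientation, here is a minimal caller-side sketch of the `RunFromFile` interface restored above. Only the parameter list comes from this diff; the function wrapper and every value below are hypothetical placeholders for illustration.

```cpp
// Hypothetical caller; only the RunFromFile parameter list is from this diff.
#include <string>
#include <vector>
#include "paddle/fluid/framework/async_executor.h"

void RunExample(paddle::framework::AsyncExecutor* exe,
                const paddle::framework::ProgramDesc& main_program,
                const std::string& data_feed_desc_str) {
  std::vector<std::string> filelist = {"part-000", "part-001"};  // sample files
  std::vector<std::string> fetch_var_names = {"mean_cost"};  // vars to fetch
  int thread_num = 2;     // clamped to filelist.size() inside RunFromFile
  std::string mode = "";  // "mpi" selects the PSLIB worker path
  bool debug = false;     // true runs TrainFilesWithTimer instead of TrainFiles
  exe->RunFromFile(main_program, data_feed_desc_str, filelist, thread_num,
                   fetch_var_names, mode, debug);
}
```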
@@ -12,8 +12,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
 
-#include <random>
 #include "paddle/fluid/framework/data_set.h"
+#include <random>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/message.h"
 #include "google/protobuf/text_format.h"
@@ -23,7 +23,9 @@ namespace paddle {
 namespace framework {
 
 template <typename T>
-DatasetImpl<T>::DatasetImpl() { thread_num_ = 1; }
+DatasetImpl<T>::DatasetImpl() {
+  thread_num_ = 1;
+}
 
 template <typename T>
 void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
@@ -66,7 +68,7 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
 template <typename T>
 std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
-    DatasetImpl<T>::GetReaders() {
+DatasetImpl<T>::GetReaders() {
   return readers_;
 }
@@ -118,16 +120,15 @@ void DatasetImpl<T>::GlobalShuffle() {
   std::random_shuffle(memory_data_.begin(), memory_data_.end());
   auto fleet_ptr = FleetWrapper::GetInstance();
   VLOG(3) << "RegisterClientToClientMsgHandler";
-  fleet_ptr->RegisterClientToClientMsgHandler(0,
-      [this](int msg_type, int client_id, const std::string& msg) -> int {
+  fleet_ptr->RegisterClientToClientMsgHandler(
+      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
        return this->ReceiveFromClient(msg_type, client_id, msg);
       });
   VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
-    global_shuffle_threads.push_back(
-        std::thread(&paddle::framework::DataFeed::GlobalShuffle,
-                    readers_[i].get()));
+    global_shuffle_threads.push_back(std::thread(
+        &paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
   }
   for (std::thread& t : global_shuffle_threads) {
     t.join();
@@ -169,14 +170,15 @@ void DatasetImpl<T>::DestroyReaders() {
   }
   std::vector<std::thread> fill_threads;
   for (int i = 0; i < thread_num_; ++i) {
-    fill_threads.push_back(std::thread(
-        &paddle::framework::DataFeed::FillChannelToMemoryData,
+    fill_threads.push_back(
+        std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
                     readers_[i].get()));
   }
   for (std::thread& t : fill_threads) {
     t.join();
   }
+  std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
   LOG(WARNING) << "readers size: " << readers_.size();
 }
 
 template <typename T>
......
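The line added in `DestroyReaders` uses the classic swap-to-clear idiom: swapping `readers_` with a default-constructed temporary destroys the held `shared_ptr` elements and releases the heap capacity, which `clear()` alone does not guarantee. A standalone sketch of the idiom:

```cpp
#include <iostream>
#include <memory>
#include <vector>

int main() {
  std::vector<std::shared_ptr<int>> readers;
  for (int i = 0; i < 4; ++i) readers.push_back(std::make_shared<int>(i));

  // clear() destroys the elements but may keep the allocated buffer.
  // Swapping with a temporary empty vector frees the buffer as well: the
  // temporary takes ownership of the old storage and is destroyed at the
  // end of the full expression.
  std::vector<std::shared_ptr<int>>().swap(readers);

  std::cout << "size: " << readers.size()            // 0
            << ", capacity: " << readers.capacity()  // 0 after the swap
            << std::endl;
  return 0;
}
```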
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
TEST(DeviceWorkerTest, CreateHogwildDeviceWorker) {
  // create hogwild device worker
}
}  // namespace framework
}  // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/trainer.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
TEST(TrainerTest, MultiTrainer) {
  // create multi trainer
  // create hogwild device worker
  // create dataset
  // train for a while
}
}  // namespace framework
}  // namespace paddle
@@ -78,7 +78,7 @@ class DatasetBase(object):
             if var.lod_level == 0:
                 slot_var.is_dense = True
             if var.dtype == core.VarDesc.VarType.FP32:
-                slot_var.type = "float32"
+                slot_var.type = "float"
             elif var.dtype == core.VarDesc.VarType.INT64:
                 slot_var.type = "uint64"
             else:
......
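The one-line `dataset.py` change above is the titular fix: `DataFeedDesc` carries slot types as strings, and the C++ reader side recognizes "float" and "uint64", so emitting "float32" left FP32 slots unparseable. A hedged sketch of the kind of string dispatch the value must satisfy; the function below is illustrative only, not Paddle's actual data_feed.cc code:

```cpp
#include <stdexcept>
#include <string>

// Illustrative only: models a reader dispatching on the slot type string
// from DataFeedDesc; the real parsing lives in Paddle's data_feed.cc.
enum class SlotKind { kFloat, kUint64 };

SlotKind ParseSlotType(const std::string& type) {
  if (type == "float") return SlotKind::kFloat;    // FP32 vars must send "float"
  if (type == "uint64") return SlotKind::kUint64;  // INT64 vars send "uint64"
  // "float32" fell through to here, which is the bug this commit fixes.
  throw std::runtime_error("unknown slot type: " + type);
}
```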