Unverified commit cedc0477, authored by xujiaqi01, committed by GitHub

support change shuffle and train thread num (#19841)

* support changing the shuffle thread num
* support changing the train thread num
* fix receiving shuffle data for each channel
* data norm: make the input gradient optional (stop gradient)
* check thread_tensor type against root_tensor type when merging metrics
* remove the sleep in shuffle and make it configurable
* add config for pslib client-to-client communication
* fix xbox donefile str (base vs. patch id)
* add a data norm op test case
* add a flush in trainer finalize
Parent 14625ffe
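Taken together, the user-visible knobs added here surface on the Python `InMemoryDataset` API. A rough usage sketch follows (a minimal illustration, not taken verbatim from the commit; the file names and slot variables are placeholders, and `global_shuffle` normally runs under a PSLib fleet):

```python
import paddle.fluid as fluid

# Placeholder slots; real jobs declare their own feature variables.
x = fluid.layers.data(name="x", shape=[1], dtype="int64", lod_level=1)
y = fluid.layers.data(name="y", shape=[1], dtype="int64", lod_level=1)

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(2)                    # dataset/train thread num
dataset.set_use_var([x, y])
dataset.set_pipe_command("cat")
dataset.set_filelist(["train_a.txt", "train_b.txt"])  # placeholder files
dataset.load_into_memory()

# New in this commit: the shuffle thread num is decoupled from the train
# thread num, and the per-batch sleep during fleet send is configurable.
dataset.set_fleet_send_sleep_seconds(0)
dataset.global_shuffle(fleet=None, thread_num=12)
```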
@@ -42,8 +42,8 @@ DatasetImpl<T>::DatasetImpl() {
   channel_num_ = 1;
   file_idx_ = 0;
   cur_channel_ = 0;
-  fleet_send_batch_size_ = 80000;
-  fleet_send_sleep_seconds_ = 2;
+  fleet_send_batch_size_ = 1024;
+  fleet_send_sleep_seconds_ = 0;
   merge_by_insid_ = false;
   erase_duplicate_feas_ = true;
   keep_unmerged_ins_ = true;
@@ -51,6 +51,7 @@ DatasetImpl<T>::DatasetImpl() {
   parse_ins_id_ = false;
   parse_content_ = false;
   preload_thread_num_ = 0;
+  global_index_ = 0;
 }

 // set filelist, file_idx_ will reset to zero.
@@ -291,7 +292,7 @@ void DatasetImpl<T>::LocalShuffle() {
 }

 template <typename T>
-void DatasetImpl<T>::GlobalShuffle() {
+void DatasetImpl<T>::GlobalShuffle(int thread_num) {
   VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
   platform::Timer timeline;
   timeline.Start();
@@ -358,13 +359,21 @@ void DatasetImpl<T>::GlobalShuffle() {
       ars.shrink_to_fit();
       data.clear();
       data.shrink_to_fit();
-      sleep(this->fleet_send_sleep_seconds_);
+      // currently we find bottleneck is server not able to handle large data
+      // in time, so we can remove this sleep and set fleet_send_batch_size to
+      // 1024, and set server thread to 24.
+      if (fleet_send_sleep_seconds_ != 0) {
+        sleep(this->fleet_send_sleep_seconds_);
+      }
     }
   };
-  VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
-  for (int i = 0; i < thread_num_; ++i) {
+  if (thread_num == -1) {
+    thread_num = thread_num_;
+  }
+  VLOG(3) << "start global shuffle threads, num = " << thread_num;
+  for (int i = 0; i < thread_num; ++i) {
     global_shuffle_threads.push_back(std::thread(global_shuffle_func));
   }
   for (std::thread& t : global_shuffle_threads) {
@@ -378,6 +387,101 @@ void DatasetImpl<T>::GlobalShuffle() {
           << timeline.ElapsedSec() << " seconds";
 }
+
+template <typename T>
+void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num) {
+  if (channel_num_ == channel_num) {
+    VLOG(3) << "DatasetImpl<T>::DynamicAdjustChannelNum channel_num_="
+            << channel_num_ << ", channel_num_=channel_num, no need to adjust";
+    return;
+  }
+  VLOG(3) << "adjust channel num from " << channel_num_ << " to "
+          << channel_num;
+  channel_num_ = channel_num;
+  std::vector<paddle::framework::Channel<T>>* origin_channels = nullptr;
+  std::vector<paddle::framework::Channel<T>>* other_channels = nullptr;
+  // find out which channel (output or consume) has data
+  int cur_channel = 0;
+  uint64_t output_channels_data_size = 0;
+  uint64_t consume_channels_data_size = 0;
+  CHECK(multi_output_channel_.size() == multi_consume_channel_.size());
+  for (int i = 0; i < multi_output_channel_.size(); ++i) {
+    output_channels_data_size += multi_output_channel_[i]->Size();
+    consume_channels_data_size += multi_consume_channel_[i]->Size();
+  }
+  if (output_channels_data_size != 0) {
+    CHECK(consume_channels_data_size == 0);  // NOLINT
+    cur_channel = 0;
+  } else {
+    CHECK(output_channels_data_size == 0);  // NOLINT
+    cur_channel = 1;
+  }
+  if (cur_channel == 0) {
+    origin_channels = &multi_output_channel_;
+    other_channels = &multi_consume_channel_;
+  } else {
+    origin_channels = &multi_consume_channel_;
+    other_channels = &multi_output_channel_;
+  }
+  CHECK(origin_channels != nullptr);  // NOLINT
+  CHECK(other_channels != nullptr);   // NOLINT
+
+  paddle::framework::Channel<T> total_data_channel =
+      paddle::framework::MakeChannel<T>();
+  std::vector<paddle::framework::Channel<T>> new_channels;
+  std::vector<paddle::framework::Channel<T>> new_other_channels;
+  std::vector<T> local_vec;
+  for (int i = 0; i < origin_channels->size(); ++i) {
+    local_vec.clear();
+    (*origin_channels)[i]->Close();
+    (*origin_channels)[i]->ReadAll(local_vec);
+    total_data_channel->Write(std::move(local_vec));
+  }
+  total_data_channel->Close();
+  total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num +
+                                   1);
+  for (int i = 0; i < channel_num; ++i) {
+    local_vec.clear();
+    total_data_channel->Read(local_vec);
+    new_other_channels.push_back(paddle::framework::MakeChannel<T>());
+    new_channels.push_back(paddle::framework::MakeChannel<T>());
+    new_channels[i]->Write(std::move(local_vec));
+  }
+  total_data_channel->Clear();
+  origin_channels->clear();
+  other_channels->clear();
+  *origin_channels = new_channels;
+  *other_channels = new_other_channels;
+  new_channels.clear();
+  new_other_channels.clear();
+  std::vector<paddle::framework::Channel<T>>().swap(new_channels);
+  std::vector<paddle::framework::Channel<T>>().swap(new_other_channels);
+  local_vec.clear();
+  std::vector<T>().swap(local_vec);
+  VLOG(3) << "adjust channel num done";
+}
+
+template <typename T>
+void DatasetImpl<T>::DynamicAdjustReadersNum(int thread_num) {
+  if (thread_num_ == thread_num) {
+    VLOG(3) << "DatasetImpl<T>::DynamicAdjustReadersNum thread_num_="
+            << thread_num_ << ", thread_num_=thread_num, no need to adjust";
+    return;
+  }
+  VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num;
+  thread_num_ = thread_num;
+  std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
+  CreateReaders();
+  VLOG(3) << "adjust readers num done";
+}
+
+template <typename T>
+void DatasetImpl<T>::SetFleetSendSleepSeconds(int seconds) {
+  fleet_send_sleep_seconds_ = seconds;
+}

 template <typename T>
 void DatasetImpl<T>::CreateReaders() {
   VLOG(3) << "Calling CreateReaders()";
@@ -509,7 +613,16 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
   CHECK(ar.Cursor() == ar.Finish());

   auto fleet_ptr = FleetWrapper::GetInstance();
-  int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_;
+  // not use random because it doesn't perform well here.
+  // to make sure each channel get data equally, we just put data to
+  // channel one by one.
+  // int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_;
+  int64_t index = 0;
+  {
+    std::unique_lock<std::mutex> lk(global_index_mutex_);
+    index = global_index_++;
+  }
+  index = index % channel_num_;
   VLOG(3) << "ramdom index=" << index;
   multi_output_channel_[index]->Write(std::move(data));
......
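The `ReceiveFromClient` change above replaces the random channel pick with a mutex-guarded round-robin counter, so every output channel receives a near-equal share of shuffled instances (the "fix receive shuffle data of each channel" item in the commit message). A standalone Python sketch of the same dispatch idea (channels here are plain lists, purely illustrative):

```python
import threading

class RoundRobinDispatcher(object):
    """Illustrates DatasetImpl's global_index_-based channel selection."""

    def __init__(self, channel_num):
        self.channels = [[] for _ in range(channel_num)]
        self.global_index = 0
        self.lock = threading.Lock()

    def receive(self, data):
        # Equivalent of: lock global_index_mutex_, index = global_index_++
        with self.lock:
            index = self.global_index
            self.global_index += 1
        self.channels[index % len(self.channels)].append(data)

dispatcher = RoundRobinDispatcher(channel_num=4)
for record in range(10):
    dispatcher.receive(record)
print([len(c) for c in dispatcher.channels])  # [3, 3, 2, 2]: near-equal split
```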
@@ -99,7 +99,7 @@ class Dataset {
   // local shuffle data
   virtual void LocalShuffle() = 0;
   // global shuffle data
-  virtual void GlobalShuffle() = 0;
+  virtual void GlobalShuffle(int thread_num = -1) = 0;
   // for slots shuffle
   virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
   virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
@@ -120,6 +120,11 @@ class Dataset {
   virtual void DestroyPreLoadReaders() = 0;
   // set preload thread num
   virtual void SetPreLoadThreadNum(int thread_num) = 0;
+  // seperate train thread and dataset thread
+  virtual void DynamicAdjustChannelNum(int channel_num) = 0;
+  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
+  // set fleet send sleep seconds
+  virtual void SetFleetSendSleepSeconds(int seconds) = 0;

  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
@@ -169,7 +174,7 @@ class DatasetImpl : public Dataset {
   virtual void WaitPreLoadDone();
   virtual void ReleaseMemory();
   virtual void LocalShuffle();
-  virtual void GlobalShuffle();
+  virtual void GlobalShuffle(int thread_num = -1);
   virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
   virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
                              std::vector<Record>* result) {}
@@ -181,6 +186,9 @@ class DatasetImpl : public Dataset {
   virtual void CreatePreLoadReaders();
   virtual void DestroyPreLoadReaders();
   virtual void SetPreLoadThreadNum(int thread_num);
+  virtual void DynamicAdjustChannelNum(int channel_num);
+  virtual void DynamicAdjustReadersNum(int thread_num);
+  virtual void SetFleetSendSleepSeconds(int seconds);

  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
@@ -217,6 +225,8 @@ class DatasetImpl : public Dataset {
   std::vector<std::string> merge_slots_list_;
   bool slots_shuffle_fea_eval_ = false;
   int preload_thread_num_;
+  std::mutex global_index_mutex_;
+  int64_t global_index_ = 0;
 };

 // use std::vector<MultiSlotType> or Record as data type
......
@@ -148,11 +148,18 @@ void DistMultiTrainer::Finalize() {
       if (root_tensor->numel() != thread_tensor->numel()) {
         continue;
       }
 #define MergeCallback(cpp_type, proto_type)                                    \
   do {                                                                         \
     if (root_tensor->type() == proto_type) {                                   \
-      MergeToRootScope<cpp_type>(root_tensor, thread_tensor);                  \
-    }                                                                          \
+      if (thread_tensor->type() != proto_type) {                               \
+        VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \
+                << "] " << need_merge_var_names_[i]                            \
+                << ", root tensor type=" << root_tensor->type()                \
+                << ", thread tensor type=" << thread_tensor->type();           \
+        exit(-1);                                                              \
+      }                                                                        \
+      MergeToRootScope<cpp_type>(root_tensor, thread_tensor);                  \
+    }                                                                          \
   } while (0)
       _ForEachDataType_(MergeCallback);
     }
@@ -163,6 +170,10 @@ void DistMultiTrainer::Finalize() {
   }
   pull_dense_worker_->Stop();
   root_scope_->DropKids();
+
+  // flush local client push queue
+  auto fleet_ptr_ = FleetWrapper::GetInstance();
+  fleet_ptr_->ClientFlush();
 }

 template <typename T>
......
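The `Finalize` change drains the asynchronous push queue before the trainer exits; the same flush is reachable from Python via the pre-existing `client_flush` binding. A one-line sketch (manual use is rarely needed now, since Finalize does it automatically, and it assumes an initialized fleet):

```python
import paddle.fluid as fluid

fleet_ptr = fluid.core.Fleet()
# Drain pending async push requests, as DistMultiTrainer::Finalize
# now does before dropping the root scope's kids.
fleet_ptr.client_flush()
```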
@@ -66,6 +66,14 @@ paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
 std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
 #endif

+void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
+                                          int connect_timeout_ms,
+                                          int max_retry) {
+  client2client_request_timeout_ms_ = request_timeout_ms;
+  client2client_connect_timeout_ms_ = connect_timeout_ms;
+  client2client_max_retry_ = max_retry;
+}
+
 void FleetWrapper::InitServer(const std::string& dist_desc, int index) {
 #ifdef PADDLE_WITH_PSLIB
   if (!is_initialized_) {
@@ -142,7 +150,9 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
 void FleetWrapper::CreateClient2ClientConnection() {
 #ifdef PADDLE_WITH_PSLIB
   VLOG(3) << "Going to create client2client connection";
-  pslib_ptr_->create_client2client_connection();
+  pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_,
+                                              client2client_connect_timeout_ms_,
+                                              client2client_max_retry_);
 #endif
 }
@@ -344,7 +354,9 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
       slot = boost::lexical_cast<int>(sparse_key_names[i]);
     }
     Variable* g_var = scope.FindVar(sparse_grad_names[i]);
-    CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
+    if (g_var == nullptr) {
+      continue;
+    }
     LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
     if (g_tensor == nullptr) {
       LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null";
......
@@ -59,7 +59,17 @@ class FleetWrapper {
     scale_sparse_gradient_with_batch_size_ = true;
     // trainer sleep some time for pslib core dump
     sleep_seconds_before_fail_exit_ = 300;
+    // pslib request server timeout ms
+    client2client_request_timeout_ms_ = 500000;
+    // pslib connect server timeout_ms
+    client2client_connect_timeout_ms_ = 10000;
+    // pslib request max retry
+    client2client_max_retry_ = 3;
   }

+  void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
+                              int max_retry);
+
   // Pull sparse variables from server in Sync mode
   // Param<in>: scope, table_id, var_names, fea_keys
   // Param<out>: fea_values
@@ -200,6 +210,9 @@ class FleetWrapper {
   static bool is_initialized_;
   bool scale_sparse_gradient_with_batch_size_;
   int32_t sleep_seconds_before_fail_exit_;
+  int client2client_request_timeout_ms_;
+  int client2client_connect_timeout_ms_;
+  int client2client_max_retry_;

   DISABLE_COPY_AND_ASSIGN(FleetWrapper);
 };
......
@@ -124,6 +124,9 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
           "'epsilon' should be between 0.0 and 0.001.");
         });
     AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddInput("X", "The input tensor");
     AddInput("BatchSize",
              "BatchSize is a 1-dimensional tensor of size C "
@@ -224,7 +227,6 @@ class DataNormGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("Scales"), "");

     // check output
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), "");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), "");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSquareSum")),
@@ -237,7 +239,9 @@ class DataNormGradOp : public framework::OperatorWithKernel {
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);

-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    }
     ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C});
     ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C});
     ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C});
@@ -304,7 +308,10 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
                                           : x_dims[x_dims.size() - 1]);

     // init output
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    Tensor *d_x = nullptr;
+    if (ctx.HasOutput(framework::GradVarName("X"))) {
+      d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    }
     auto *d_batch_size =
         ctx.Output<Tensor>(framework::GradVarName("BatchSize"));
     auto *d_batch_sum = ctx.Output<Tensor>(framework::GradVarName("BatchSum"));
@@ -331,10 +338,12 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
     ConstEigenVectorArrayMap<T> means_arr(means->data<T>(), C);
     ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N);
     ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N);
-    EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C, N);
-    d_x_arr.setZero();
-    for (int nc = 0; nc < N; ++nc) {
-      d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
+    if (d_x != nullptr) {
+      EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C, N);
+      d_x_arr.setZero();
+      for (int nc = 0; nc < N; ++nc) {
+        d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
+      }
     }

     // calculate data sum and squre sum
......
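With `X@GRAD` now optional, data_norm no longer forces a gradient back onto its input, so it can sit directly on a non-trainable input (the "data norm stop gradient" item). A rough fluid-side illustration, assuming the 1.x `fluid.layers.data_norm` API (shapes are arbitrary):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[3], dtype="float32")
x.stop_gradient = True  # no X@GRAD is requested from data_norm's grad op
y = fluid.layers.data_norm(input=x)
loss = fluid.layers.reduce_sum(y)
# Backward still updates the batch-statistics gradients, but skips d(X).
fluid.optimizer.SGD(learning_rate=0.1).minimize(loss)
```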
@@ -257,6 +257,15 @@ void BindDataset(py::module *m) {
            py::call_guard<py::gil_scoped_release>())
       .def("destroy_preload_readers",
            &framework::Dataset::DestroyPreLoadReaders,
+           py::call_guard<py::gil_scoped_release>())
+      .def("dynamic_adjust_channel_num",
+           &framework::Dataset::DynamicAdjustChannelNum,
+           py::call_guard<py::gil_scoped_release>())
+      .def("dynamic_adjust_readers_num",
+           &framework::Dataset::DynamicAdjustReadersNum,
+           py::call_guard<py::gil_scoped_release>())
+      .def("set_fleet_send_sleep_seconds",
+           &framework::Dataset::SetFleetSendSleepSeconds,
            py::call_guard<py::gil_scoped_release>());

   py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
@@ -65,7 +65,9 @@ void BindFleetWrapper(py::module* m) {
       .def("client_flush", &framework::FleetWrapper::ClientFlush)
       .def("load_from_paddle_model",
            &framework::FleetWrapper::LoadFromPaddleModel)
-      .def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable);
+      .def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable)
+      .def("set_client2client_config",
+           &framework::FleetWrapper::SetClient2ClientConfig);
 }  // end FleetWrapper
 }  // end namespace pybind
 }  // end namespace paddle
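The new binding is callable from Python on `fluid.core.Fleet`, mirroring the unit-test usage near the end of this diff; a small sketch:

```python
import paddle.fluid as fluid

fleet_ptr = fluid.core.Fleet()
# (request_timeout_ms, connect_timeout_ms, max_retry); FleetWrapper's
# constructor defaults are 500000, 10000 and 3.
fleet_ptr.set_client2client_config(500000, 10000, 3)
```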
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""This is defination of dataset class, which is high performance IO."""

 from paddle.fluid.proto import data_feed_pb2
 from google.protobuf import text_format
@@ -70,7 +71,7 @@ class DatasetBase(object):
         self.proto_desc = data_feed_pb2.DataFeedDesc()
         self.proto_desc.pipe_command = "cat"
         self.dataset = core.Dataset("MultiSlotDataset")
-        self.thread_num = 0
+        self.thread_num = 1
         self.filelist = []

     def set_pipe_command(self, pipe_command):
@@ -265,6 +266,12 @@ class DatasetBase(object):
         """
         return text_format.MessageToString(self.proto_desc)

+    def _dynamic_adjust_before_train(self, thread_num):
+        pass
+
+    def _dynamic_adjust_after_train(self):
+        pass
+
 class InMemoryDataset(DatasetBase):
     """
@@ -281,19 +288,19 @@ class InMemoryDataset(DatasetBase):
         super(InMemoryDataset, self).__init__()
         self.proto_desc.name = "MultiSlotInMemoryDataFeed"
         self.fleet_send_batch_size = None
+        self.is_user_set_queue_num = False
         self.queue_num = None
         self.parse_ins_id = False
         self.parse_content = False
         self.merge_by_lineid = False
+        self.fleet_send_sleep_seconds = None

     def _prepare_to_run(self):
         """
         Set data_feed_desc before load or shuffle,
         user no need to call this function.
         """
-        if self.thread_num > len(self.filelist):
-            self.thread_num = len(self.filelist)
-        if self.thread_num == 0:
+        if self.thread_num <= 0:
             self.thread_num = 1
         self.dataset.set_thread_num(self.thread_num)
         if self.queue_num is None:
@@ -305,6 +312,16 @@ class InMemoryDataset(DatasetBase):
         self.dataset.create_channel()
         self.dataset.create_readers()

+    def _dynamic_adjust_before_train(self, thread_num):
+        if not self.is_user_set_queue_num:
+            self.dataset.dynamic_adjust_channel_num(thread_num)
+        self.dataset.dynamic_adjust_readers_num(thread_num)
+
+    def _dynamic_adjust_after_train(self):
+        if not self.is_user_set_queue_num:
+            self.dataset.dynamic_adjust_channel_num(self.thread_num)
+        self.dataset.dynamic_adjust_readers_num(self.thread_num)
+
     def set_queue_num(self, queue_num):
         """
         Set Dataset output queue num, training threads get data from queues
@@ -320,6 +337,7 @@ class InMemoryDataset(DatasetBase):
               dataset.set_queue_num(12)

         """
+        self.is_user_set_queue_num = True
         self.queue_num = queue_num

     def set_parse_ins_id(self, parse_ins_id):
@@ -356,9 +374,9 @@ class InMemoryDataset(DatasetBase):
         """
         self.parse_content = parse_content

-    def set_fleet_send_batch_size(self, fleet_send_batch_size):
+    def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
         """
-        Set fleet send batch size, default is 80000
+        Set fleet send batch size, default is 1024

         Args:
             fleet_send_batch_size(int): fleet send batch size
@@ -373,6 +391,23 @@ class InMemoryDataset(DatasetBase):
         """
         self.fleet_send_batch_size = fleet_send_batch_size

+    def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
+        """
+        Set fleet send sleep time, default is 0
+
+        Args:
+            fleet_send_sleep_seconds(int): fleet send sleep time
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              dataset.set_fleet_send_sleep_seconds(2)
+
+        """
+        self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
+
     def set_merge_by_lineid(self,
                             var_list,
                             erase_duplicate_feas=True,
@@ -480,7 +515,7 @@ class InMemoryDataset(DatasetBase):
         """
         self.dataset.local_shuffle()

-    def global_shuffle(self, fleet=None):
+    def global_shuffle(self, fleet=None, thread_num=12):
         """
         Global shuffle.
         Global shuffle can be used only in distributed mode. i.e. multiple
@@ -500,6 +535,7 @@ class InMemoryDataset(DatasetBase):
         Args:
             fleet(Fleet): fleet singleton. Default None.
+            thread_num(int): shuffle thread num. Default is 12.

         """
         trainer_num = 1
@@ -507,13 +543,16 @@ class InMemoryDataset(DatasetBase):
             fleet._role_maker._barrier_worker()
             trainer_num = fleet.worker_num()
         if self.fleet_send_batch_size is None:
-            self.fleet_send_batch_size = 800 * trainer_num
+            self.fleet_send_batch_size = 1024
+        if self.fleet_send_sleep_seconds is None:
+            self.fleet_send_sleep_seconds = 0
         self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
         self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
+        self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
         if fleet is not None:
             fleet._role_maker._barrier_worker()
-        self.dataset.global_shuffle()
+        self.dataset.global_shuffle(thread_num)
         if fleet is not None:
             fleet._role_maker._barrier_worker()
         if self.merge_by_lineid:
@@ -666,6 +705,9 @@ class QueueDataset(DatasetBase):
               dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
               dataset.local_shuffle()

+        Raises:
+            NotImplementedError: QueueDataset does not support local shuffle
+
         """
         raise NotImplementedError(
             "QueueDataset does not support local shuffle, "
@@ -689,6 +731,9 @@ class QueueDataset(DatasetBase):
               dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
               dataset.global_shuffle(fleet)

+        Raises:
+            NotImplementedError: QueueDataset does not support global shuffle
+
         """
         raise NotImplementedError(
             "QueueDataset does not support global shuffle, "
@@ -708,14 +753,16 @@ class FileInstantDataset(DatasetBase):
     def __init__(self):
         """
-        Init
+        Initialize FileInstantDataset
+        This class should be created by DatasetFactory
         """
         super(FileInstantDataset, self).__init__()
         self.proto_desc.name = "MultiSlotFileInstantDataFeed"

     def local_shuffle(self):
         """
-        Local shuffle, FileInstantDataset does not support local shuffle
+        Local shuffle
+        FileInstantDataset does not support local shuffle
         """
         raise NotImplementedError(
             "FileInstantDataset does not support local shuffle, "
@@ -724,6 +771,7 @@ class FileInstantDataset(DatasetBase):
     def global_shuffle(self, fleet=None):
         """
         Global shuffle
+        FileInstantDataset does not support global shuffle
         """
         raise NotImplementedError(
             "FileInstantDataset does not support global shuffle, "
@@ -743,26 +791,30 @@ class BoxPSDataset(InMemoryDataset):
     def __init__(self):
         """
-        Init
+        Initialize BoxPSDataset
+        This class should be created by DatasetFactory
         """
         super(BoxPSDataset, self).__init__()
         self.boxps = core.BoxPS(self.dataset)

     def begin_pass(self):
         """
-        Notify BoxPS to begin next pass
+        Begin Pass
+        Notify BoxPS to begin next pass
         """
         self.boxps.begin_pass()

     def end_pass(self):
         """
-        Notify BoxPS to end current pass
+        End Pass
+        Notify BoxPS to end current pass
         """
         self.boxps.end_pass()

     def wait_preload_done(self):
         """
         Wait async proload done
+        Wait Until Feed Pass Done
         """
         self.boxps.wait_feed_pass_done()
......
@@ -803,7 +803,6 @@ class Executor(object):
                 program.program._fleet_opt)
             trainer._set_program(program.program)

-        # The following thread_num-determined logic will be deprecated
         if thread <= 0:
             if dataset.thread_num <= 0:
                 raise RuntimeError(
@@ -889,9 +888,11 @@ class Executor(object):
         trainer._set_infer(True)
         trainer._gen_trainer_desc()
         self._dump_debug_info(program=program, trainer=trainer)
+        dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
         self._default_executor.run_from_dataset(program.desc, scope,
                                                 dataset.dataset,
                                                 trainer._desc())
+        dataset._dynamic_adjust_after_train()
         dataset._finish_to_run()
         return None

@@ -973,8 +974,10 @@ class Executor(object):
             print_period=print_period)
         trainer._gen_trainer_desc()
         self._dump_debug_info(program=program, trainer=trainer)
+        dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
         self._default_executor.run_from_dataset(program.desc, scope,
                                                 dataset.dataset,
                                                 trainer._desc())
+        dataset._dynamic_adjust_after_train()
         dataset._finish_to_run()
         return None
@@ -32,11 +32,20 @@ class PSLib(Fleet):
         self._fleet_ptr = None
         self._main_programs = []
         self._scopes = []
+        self._client2client_request_timeout_ms = 500000
+        self._client2client_connect_timeout_ms = 10000
+        self._client2client_max_retry = 3

     def init(self, role_maker=None):
         super(PSLib, self).init(MPISymetricRoleMaker())
         self._fleet_ptr = fluid.core.Fleet()

+    def _set_client_communication_config(self, request_timeout_ms,
+                                         connect_timeout_ms, max_retry):
+        self._client2client_request_timeout_ms = request_timeout_ms
+        self._client2client_connect_timeout_ms = connect_timeout_ms
+        self._client2client_max_retry = max_retry
+
     def init_worker(self):
         """
         init_worker(): will be called by user. When a user knows current process is_server(), he/she
@@ -72,6 +81,10 @@ class PSLib(Fleet):
                 info = self._fleet_ptr.get_clients_info()
                 all_info = self._role_maker._worker_gather(info[0])
                 self._fleet_ptr.gather_clients(all_info)
+                self._fleet_ptr.set_client2client_config(
+                    self._client2client_request_timeout_ms,
+                    self._client2client_connect_timeout_ms,
+                    self._client2client_max_retry)
                 self._fleet_ptr.create_client2client_connection()
                 # barrier for init model
                 self._role_maker._barrier_worker()
......
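In a PSLib trainer script, these timeouts are tuned through the fleet singleton before `init_worker()` creates the connection. A sketch (the import path is the usual one for PSLib-based fleet scripts):

```python
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

# Must happen before init_worker(), which forwards these values to
# create_client2client_connection() via set_client2client_config.
fleet._set_client_communication_config(
    request_timeout_ms=500000, connect_timeout_ms=10000, max_retry=3)
```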
@@ -311,14 +311,23 @@ class FleetUtil(object):
                       xbox_base_key,
                       data_path,
                       hadoop_fs_name,
-                      monitor_data={}):
+                      monitor_data={},
+                      mode="patch"):
         xbox_dict = collections.OrderedDict()
-        xbox_dict["id"] = str(int(time.time()))
+        if mode == "base":
+            xbox_dict["id"] = str(xbox_base_key)
+        elif mode == "patch":
+            xbox_dict["id"] = str(int(time.time()))
+        else:
+            print("warning: unknown mode %s, set it to patch" % mode)
+            mode = "patch"
+            xbox_dict["id"] = str(int(time.time()))
         xbox_dict["key"] = str(xbox_base_key)
         if model_path.startswith("hdfs:") or model_path.startswith("afs:"):
             model_path = model_path[model_path.find(":") + 1:]
         xbox_dict["input"] = hadoop_fs_name + model_path.rstrip("/") + "/000"
         xbox_dict["record_count"] = "111111"
+        xbox_dict["partition_type"] = "2"
         xbox_dict["job_name"] = "default_job_name"
         xbox_dict["ins_tag"] = "feasign"
         xbox_dict["ins_path"] = data_path
@@ -477,13 +486,16 @@ class FleetUtil(object):
         day = str(day)
         pass_id = str(pass_id)
         xbox_base_key = int(xbox_base_key)
+        mode = None
         if pass_id != "-1":
+            mode = "patch"
             suffix_name = "/%s/delta-%s/" % (day, pass_id)
             model_path = output_path.rstrip("/") + suffix_name
             if donefile_name is None:
                 donefile_name = "xbox_patch_done.txt"
         else:
+            mode = "base"
             suffix_name = "/%s/base/" % day
             model_path = output_path.rstrip("/") + suffix_name
             if donefile_name is None:
@@ -495,7 +507,8 @@ class FleetUtil(object):
         if fleet.worker_index() == 0:
             donefile_path = output_path + "/" + donefile_name
             xbox_str = self._get_xbox_str(output_path, day, model_path, \
-                    xbox_base_key, data_path, hadoop_fs_name, monitor_data={})
+                    xbox_base_key, data_path, hadoop_fs_name, monitor_data={},
+                    mode=mode)
             configs = {
                 "fs.default.name": hadoop_fs_name,
                 "hadoop.job.ugi": hadoop_fs_ugi
......
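The net effect of the new `mode` argument on the donefile id: a base dump reuses the pass key, while a patch dump stamps the current time. In brief (a standalone restatement of the branch above):

```python
import time

def xbox_id(mode, xbox_base_key):
    # Mirrors the new branch in _get_xbox_str.
    if mode == "base":
        return str(xbox_base_key)
    return str(int(time.time()))  # "patch", and unknown modes fall back here
```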
@@ -24,10 +24,20 @@ import copy
 import errno

 import logging
-from paddle.fluid.log_helper import get_logger

 __all__ = ["HDFSClient"]

+
+def get_logger(name, level, fmt):
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    handler = logging.FileHandler('hdfs.log', mode='w')
+    formatter = logging.Formatter(fmt=fmt)
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    return logger
+
+
 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
@@ -461,7 +471,7 @@ class HDFSClient(object):
         procs = []
         for i in range(multi_processes):
-            process_datas = HDFSClient.split_flies(all_files, i,
+            process_datas = HDFSClient.split_files(all_files, i,
                                                    multi_processes)
             p = multiprocessing.Process(
                 target=__subprocess_download,
@@ -551,7 +561,7 @@ class HDFSClient(object):
         procs = []
         for i in range(multi_processes):
-            process_datas = HDFSClient.split_flies(all_files, i,
+            process_datas = HDFSClient.split_files(all_files, i,
                                                    multi_processes)
             p = multiprocessing.Process(
                 target=__subprocess_upload, args=(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is unit test of Test data_norm Op."""
from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from op_test import OpTest
from paddle.fluid.framework import grad_var_name


def _reference_testing(x, batch_size, batch_sum, batch_square_sum):
    x_shape = x.shape
    means_arr = batch_sum / batch_size
    scales_arr = np.sqrt(batch_size / batch_square_sum)
    for i in range(x_shape[0]):
        x[i] -= means_arr
        x[i] *= scales_arr
    y = np.array(x)
    return y


def create_or_get_tensor(scope, var_name, var, place):
    tensor = scope.var(var_name).get_tensor()
    if var is not None:
        assert isinstance(var, np.ndarray)
        tensor.set_recursive_sequence_lengths([])
        tensor.set(var, place)
    return tensor


class TestDataNormOpInference(unittest.TestCase):
    """
    test class for data norm op
    test forward
    """

    def setUp(self):
        """
        init members of this class
        """
        self.dtype = np.float32
        self.use_mkldnn = False

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def check_with_place(self, place, data_layout, dtype, shape):
        """
        do forward and check

        Args:
            place(Place): CPUPlace
            data_layout(str): NCHW or NWHC
            dtype(dtype): np.float32
            shape(list): input shape

        """
        epsilon = 0.00001
        if len(shape) == 2:
            x_shape = shape
            c = x_shape[1]
        else:
            ValueError("len(shape) should be equal to 2")
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(dtype)
        x_val = x_val - 0.5
        batch_size = np.ones(scale_shape).astype(np.float32)
        batch_size *= 1e4
        batch_sum = np.zeros(scale_shape).astype(np.float32)
        batch_square_sum = np.ones(scale_shape).astype(np.float32)
        batch_square_sum *= 1e4

        y_out = _reference_testing(x_val, batch_size, batch_sum,
                                   batch_square_sum).astype(dtype)

        scope = core.Scope()

        # create input
        x_tensor = create_or_get_tensor(scope, "x_val",
                                        OpTest.np_dtype_to_fluid_dtype(x_val),
                                        place)
        batch_size_tensor = create_or_get_tensor(
            scope, "batch_size",
            OpTest.np_dtype_to_fluid_dtype(batch_size), place)
        batch_sum_tensor = create_or_get_tensor(
            scope, "batch_sum",
            OpTest.np_dtype_to_fluid_dtype(batch_sum), place)
        batch_square_sum_tensor = create_or_get_tensor(
            scope, "batch_square_sum",
            OpTest.np_dtype_to_fluid_dtype(batch_square_sum), place)

        # create output
        y_tensor = create_or_get_tensor(scope, "y_out", None, place)
        mean_tensor = create_or_get_tensor(scope, "mean", None, place)
        scales_tensor = create_or_get_tensor(scope, "scales", None, place)

        data_norm_op = Operator(
            "data_norm",
            # inputs
            X="x_val",
            BatchSize="batch_size",
            BatchSum="batch_sum",
            BatchSquareSum="batch_square_sum",
            # outputs
            Y="y_out",
            Means="mean",
            Scales="scales",
            # attrs
            epsilon=epsilon,
            use_mkldnn=self.use_mkldnn)

        data_norm_op.run(scope, place)

        # check inference result
        self.__assert_close(
            y_tensor,
            y_out,
            "inference output are different at " + str(place) + ", " +
            data_layout + ", " + str(np.dtype(dtype)) +
            str(np.array(y_tensor)) + str(y_out),
            atol=1e-3)

    def test_check_output(self):
        """
        test check forward, check output
        """
        places = [core.CPUPlace()]
        for place in places:
            for data_format in ["NCHW", "NHWC"]:
                self.check_with_place(place, data_format, self.dtype, [2, 3])


class TestDataNormOp(OpTest):
    """
    test class for data norm op
    test forward and backward
    """

    def setUp(self):
        """
        init data norm op test env
        """
        self.op_type = 'data_norm'
        self.use_mkldnn = False
        epsilon = 0.00001
        x_shape = [2, 3]
        scale_shape = [3]
        tp = np.float32

        x_val = np.array([[-0.35702616, -0.42756206, -0.08306625],
                          [0.41199666, -0.21719968, -0.10180971]]).astype(tp)
        batch_size = np.ones(scale_shape).astype(tp)
        batch_size *= 1e4
        batch_sum = np.zeros(scale_shape).astype(tp)
        batch_square_sum = np.ones(scale_shape).astype(tp)
        batch_square_sum *= 1e4

        y = np.array(x_val)

        mean = np.array([[0, 0, 0], [0, 0, 0]]).astype(tp)
        scale = np.array([[1, 1, 1], [1, 1, 1]]).astype(tp)

        self.inputs = {
            "X": x_val,
            "BatchSize": batch_size,
            "BatchSum": batch_sum,
            "BatchSquareSum": batch_square_sum
        }
        self.outputs = {"Y": y, "Means": mean, "Scales": scale}
        self.attrs = {"epsilon": epsilon, "use_mkldnn": self.use_mkldnn}

    def test_check_output(self):
        """
        test check forward, check output
        """
        self.check_output()

    def test_check_grad(self):
        """
        test check backward, check grad
        """
        self.check_grad(['X'], 'Y', no_grad_set=set([]))


if __name__ == '__main__':
    unittest.main()
@@ -237,6 +237,25 @@ class TestDataset(unittest.TestCase):
         exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
         ) else fluid.CUDAPlace(0))
         exe.run(fluid.default_startup_program())
+        for i in range(2):
+            try:
+                exe.train_from_dataset(fluid.default_main_program(), dataset)
+                exe.train_from_dataset(
+                    fluid.default_main_program(), dataset, thread=1)
+                exe.train_from_dataset(
+                    fluid.default_main_program(), dataset, thread=2)
+                exe.train_from_dataset(
+                    fluid.default_main_program(), dataset, thread=2)
+                exe.train_from_dataset(
+                    fluid.default_main_program(), dataset, thread=3)
+                exe.train_from_dataset(
+                    fluid.default_main_program(), dataset, thread=4)
+            except ImportError as e:
+                pass
+            except Exception as e:
+                self.assertTrue(False)
+
         if self.use_data_loader:
             data_loader = fluid.io.DataLoader.from_dataset(dataset,
                                                            fluid.cpu_places(),
@@ -253,12 +272,14 @@ class TestDataset(unittest.TestCase):
                 self.assertTrue(False)

         dataset.set_merge_by_lineid(slots_vars)
+        dataset.set_fleet_send_sleep_seconds(2)
         dataset.preload_into_memory()
         dataset.wait_preload_done()
         dataset.release_memory()
         dataset.preload_into_memory(1)
         dataset.wait_preload_done()
         fleet_ptr = fluid.core.Fleet()
+        fleet_ptr.set_client2client_config(1, 1, 1)

         os.remove("./test_in_memory_dataset_run_a.txt")
         os.remove("./test_in_memory_dataset_run_b.txt")
@@ -311,6 +332,19 @@ class TestDataset(unittest.TestCase):
         except Exception as e:
             self.assertTrue(False)

+        dataset2 = fluid.DatasetFactory().create_dataset("QueueDataset")
+        dataset2.set_use_var(slots_vars)
+        dataset2.set_batch_size(32)
+        dataset2.set_thread(3)
+        dataset2.set_pipe_command("cat")
+        dataset.set_filelist([])
+        try:
+            exe.train_from_dataset(fluid.default_main_program(), dataset2)
+        except ImportError as e:
+            print("warning: we skip trainer_desc_pb2 import problem in windows")
+        except Exception as e:
+            self.assertTrue(False)
+
         os.remove("./test_queue_dataset_run_a.txt")
         os.remove("./test_queue_dataset_run_b.txt")
......