Unverified · Commit 5822862d authored by hutuxian, committed by GitHub

Monitor Framework (#24079)

* Add a StatValue class in the backend to represent a stat.
* Add a singleton StatRegistry to maintain the collection of stats.
* To keep the code simple, only int and float stats are supported, which covers most scenarios.
Parent 21138c05
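A minimal sketch of how these pieces are intended to be used (the stat name STAT_my_counter is hypothetical; the macros come from the new paddle/fluid/platform/monitor.h introduced below):

#include "paddle/fluid/platform/monitor.h"

// In exactly one .cc file: define the global stat and its Touch function.
DEFINE_INT_STATUS(STAT_my_counter)

// In any translation unit that updates or reads the stat:
USE_INT_STAT(STAT_my_counter);

void OnBatch(int64_t n) {
  STAT_ADD(STAT_my_counter, n);  // thread-safe increment
  VLOG(3) << "counter = " << STAT_GET(STAT_my_counter);
}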
@@ -190,7 +190,7 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer)
graph_to_program_pass variable_helper data_feed_proto timer monitor)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
@@ -200,7 +200,7 @@ else()
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer)
graph_to_program_pass variable_helper timer monitor)
# TODO: Fix these unittest failed on Windows
if(NOT WIN32)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
......
@@ -35,8 +35,10 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h"
USE_INT_STAT(STAT_total_feasign_num_in_mem);
namespace paddle {
namespace framework {
@@ -391,6 +393,12 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
writer << std::move(instance);
instance = T();
}
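// Flush this reader's local feasign count: add it to the process-wide stat,
// then fold it into the dataset-level total under the shared mutex.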
STAT_ADD(STAT_total_feasign_num_in_mem, fea_num_);
{
std::lock_guard<std::mutex> flock(*mutex_for_fea_num_);
*total_fea_num_ += fea_num_;
fea_num_ = 0;
}
writer.Flush();
timeline.Pause();
VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
@@ -935,6 +943,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
}
instance->float_feasigns_.shrink_to_fit();
instance->uint64_feasigns_.shrink_to_fit();
fea_num_ += instance->uint64_feasigns_.size();
return true;
}
#else
......
@@ -110,6 +110,8 @@ class DataFeed {
DataFeed() {
mutex_for_pick_file_ = nullptr;
file_idx_ = nullptr;
mutex_for_fea_num_ = nullptr;
total_fea_num_ = nullptr;
}
virtual ~DataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
@@ -166,7 +168,9 @@ class DataFeed {
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFeaNumMutex(std::mutex* mutex) { mutex_for_fea_num_ = mutex; }
virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
virtual void SetFeaNum(uint64_t* fea_num) { total_fea_num_ = fea_num; }
virtual const std::vector<std::string>& GetInsIdVec() const {
return ins_id_vec_;
}
@@ -199,6 +203,9 @@ class DataFeed {
std::vector<std::string> filelist_;
size_t* file_idx_;
std::mutex* mutex_for_pick_file_;
std::mutex* mutex_for_fea_num_ = nullptr;
uint64_t* total_fea_num_ = nullptr;
uint64_t fea_num_ = 0;
// the alias of used slots, and its order is determined by
// data_feed_desc(proto object)
......
@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h"
#include "xxhash.h" // NOLINT
@@ -31,6 +32,7 @@
#define _LINUX
#endif
USE_INT_STAT(STAT_total_feasign_num_in_mem);
namespace paddle {
namespace framework {
@@ -42,6 +44,7 @@ DatasetImpl<T>::DatasetImpl() {
trainer_num_ = 1;
channel_num_ = 1;
file_idx_ = 0;
total_fea_num_ = 0;
cur_channel_ = 0;
fleet_send_batch_size_ = 1024;
fleet_send_sleep_seconds_ = 0;
@@ -330,6 +333,11 @@ void DatasetImpl<T>::ReleaseMemory() {
std::vector<T>().swap(input_records_);
std::vector<T>().swap(slots_shuffle_original_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
VLOG(3) << "total_feasign_num_(" << STAT_GET(STAT_total_feasign_num_in_mem)
<< ") - current_fea_num_(" << total_fea_num_ << ") = ("
<< STAT_GET(STAT_total_feasign_num_in_mem) - total_fea_num_
<< ")"; // For Debug
STAT_SUB(STAT_total_feasign_num_in_mem, total_fea_num_);
}
// do local shuffle
@@ -618,6 +626,8 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetThreadNum(thread_num_);
readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
readers_[i]->SetFeaNum(&total_fea_num_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseContent(parse_content_);
@@ -687,6 +697,8 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
preload_readers_[i]->SetFeaNum(&total_fea_num_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
......
@@ -255,7 +255,9 @@ class DatasetImpl : public Dataset {
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
uint64_t total_fea_num_;
std::mutex mutex_for_pick_file_;
std::mutex mutex_for_fea_num_;
std::string fs_name_;
std::string fs_ugi_;
int64_t fleet_send_batch_size_;
......
@@ -35,6 +35,7 @@ if(WITH_GPU)
set(enforce_deps ${enforce_deps} cuda_error_proto)
endif()
cc_library(enforce INTERFACE SRCS enforce.cc DEPS ${enforce_deps})
cc_library(monitor SRCS monitor.cc)
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece enforce)
set(CPU_INFO_DEPS gflags glog enforce)
@@ -44,7 +45,7 @@ ENDIF()
cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS})
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor)
cc_library(place SRCS place.cc DEPS enforce boost)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
......
@@ -22,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/string/split.h"
DECLARE_double(fraction_of_gpu_memory_to_use);
@@ -33,6 +34,7 @@ DECLARE_uint64(gpu_memory_limit_mb);
constexpr static float fraction_reserve_gpu_memory = 0.05f;
USE_GPU_MEM_STAT;
namespace paddle {
namespace platform {
@@ -364,6 +366,7 @@ class RecordedCudaMallocHelper {
if (NeedRecord()) {
cur_size_ += size;
}
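// Account the successful allocation in the per-device memory stat.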
STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
return cudaSuccess;
} else {
RaiseNonOutOfMemoryError(&result);
@@ -392,6 +395,7 @@
std::lock_guard<std::mutex> guard(*mtx_);
cur_size_ -= size;
}
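// Remove the freed size from the per-device memory stat.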
STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
} else {
cudaGetLastError(); // clear the error flag when cudaErrorCudartUnloading
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/monitor.h"
#include <utility>
namespace paddle {
namespace platform {} // namespace platform
} // namespace paddle
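// Global stat definitions: the in-memory feasign counter plus one
// memory-size stat for each of up to 16 GPU devices.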
DEFINE_INT_STATUS(STAT_total_feasign_num_in_mem)
DEFINE_INT_STATUS(STAT_gpu0_mem_size)
DEFINE_INT_STATUS(STAT_gpu1_mem_size)
DEFINE_INT_STATUS(STAT_gpu2_mem_size)
DEFINE_INT_STATUS(STAT_gpu3_mem_size)
DEFINE_INT_STATUS(STAT_gpu4_mem_size)
DEFINE_INT_STATUS(STAT_gpu5_mem_size)
DEFINE_INT_STATUS(STAT_gpu6_mem_size)
DEFINE_INT_STATUS(STAT_gpu7_mem_size)
DEFINE_INT_STATUS(STAT_gpu8_mem_size)
DEFINE_INT_STATUS(STAT_gpu9_mem_size)
DEFINE_INT_STATUS(STAT_gpu10_mem_size)
DEFINE_INT_STATUS(STAT_gpu11_mem_size)
DEFINE_INT_STATUS(STAT_gpu12_mem_size)
DEFINE_INT_STATUS(STAT_gpu13_mem_size)
DEFINE_INT_STATUS(STAT_gpu14_mem_size)
DEFINE_INT_STATUS(STAT_gpu15_mem_size)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <atomic>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "glog/logging.h"
namespace paddle {
namespace platform {
template <typename T>
class StatRegistry;

class MonitorRegistrar {
 public:
  // The design follows OperatorRegistrar: to avoid the global stat variables
  // being removed by the linker, we add a Touch method to every StatValue and
  // have the USE_STAT macros call it. As long as the client code invokes
  // USE_STAT, the global registrar variable won't be stripped by the linker.
  void Touch() {}
};

template <typename T>
class StatValue : public MonitorRegistrar {
  T v_{0};
  std::mutex mu_;
  // We use a mutex rather than std::atomic so that arbitrary value types T
  // are supported.

 public:
  explicit StatValue(const std::string& n) {
    StatRegistry<T>::Instance().add(n, this);
  }
  T increase(T inc) {
    std::lock_guard<std::mutex> lock(mu_);
    return v_ += inc;
  }
  T decrease(T inc) {
    std::lock_guard<std::mutex> lock(mu_);
    return v_ -= inc;
  }
  T reset(T value = 0) {
    std::lock_guard<std::mutex> lock(mu_);
    return v_ = value;
  }
  T get() {
    std::lock_guard<std::mutex> lock(mu_);
    return v_;
  }
};

template <typename T>
struct ExportedStatValue {
  std::string key;
  T value;
};

template <typename T>
class StatRegistry {
 public:
  ~StatRegistry() {}

  static StatRegistry<T>& Instance() {
    static StatRegistry<T> r;
    return r;
  }
  StatValue<T>* get(const std::string& name) {
    std::lock_guard<std::mutex> lg(mutex_);
    auto it = stats_.find(name);
    if (it != stats_.end()) {
      return it->second;
    } else {
      return nullptr;
    }
  }
  int add(const std::string& name, StatValue<T>* stat) {
    std::lock_guard<std::mutex> lg(mutex_);
    auto it = stats_.find(name);
    if (it != stats_.end()) {
      return -1;
    }
    stats_.insert(std::make_pair(name, stat));
    return 0;
  }
  void publish(std::vector<ExportedStatValue<T>>& exported,  // NOLINT
               bool reset = false) {
    std::lock_guard<std::mutex> lg(mutex_);
    exported.resize(stats_.size());
    int i = 0;
    for (const auto& kv : stats_) {
      auto& out = exported.at(i++);
      out.key = kv.first;
      out.value = reset ? kv.second->reset() : kv.second->get();
    }
  }
  std::vector<ExportedStatValue<T>> publish(bool reset = false) {
    std::vector<ExportedStatValue<T>> stats;
    publish(stats, reset);
    return stats;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, StatValue<T>*> stats_;
};
}  // namespace platform
}  // namespace paddle
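// Convenience macros. DEFINE_*_STATUS instantiates a global StatValue named
// _<item> together with a Touch function; USE_*_STAT declares them in another
// translation unit; STAT_ADD/STAT_SUB/STAT_RESET/STAT_GET operate on stats
// visible in the current file, while STAT_INT_*/STAT_FLOAT_* look stats up
// by string name at runtime.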
#define STAT_ADD(item, t) _##item.increase(t)
#define STAT_SUB(item, t) _##item.decrease(t)
// Support add stat value by string
#define STAT_INT_ADD(item, t) \
paddle::platform::StatRegistry<int64_t>::Instance().get(item)->increase(t)
#define STAT_INT_SUB(item, t) \
paddle::platform::StatRegistry<int64_t>::Instance().get(item)->decrease(t)
#define STAT_FLOAT_ADD(item, t) \
paddle::platform::StatRegistry<float>::Instance().get(item)->increase(t)
#define STAT_FLOAT_SUB(item, t) \
paddle::platform::StatRegistry<float>::Instance().get(item)->decrease(t)
#define STAT_RESET(item, t) _##item.reset(t)
#define STAT_GET(item) _##item.get()
#define DEFINE_FLOAT_STATUS(item) \
paddle::platform::StatValue<float> _##item(#item); \
int TouchStatRegistrar_##item() { \
_##item.Touch(); \
return 0; \
}
#define DEFINE_INT_STATUS(item) \
paddle::platform::StatValue<int64_t> _##item(#item); \
int TouchStatRegistrar_##item() { \
_##item.Touch(); \
return 0; \
}
#define USE_STAT(item) \
extern int TouchStatRegistrar_##item(); \
UNUSED static int use_stat_##item = TouchStatRegistrar_##item()
#define USE_INT_STAT(item) \
extern paddle::platform::StatValue<int64_t> _##item; \
USE_STAT(item)
#define USE_FLOAT_STAT(item) \
extern paddle::platform::StatValue<float> _##item; \
USE_STAT(item)
#define USE_GPU_MEM_STAT \
USE_INT_STAT(STAT_gpu0_mem_size); \
USE_INT_STAT(STAT_gpu1_mem_size); \
USE_INT_STAT(STAT_gpu2_mem_size); \
USE_INT_STAT(STAT_gpu3_mem_size); \
USE_INT_STAT(STAT_gpu4_mem_size); \
USE_INT_STAT(STAT_gpu5_mem_size); \
USE_INT_STAT(STAT_gpu6_mem_size); \
USE_INT_STAT(STAT_gpu7_mem_size); \
USE_INT_STAT(STAT_gpu8_mem_size); \
USE_INT_STAT(STAT_gpu9_mem_size); \
USE_INT_STAT(STAT_gpu10_mem_size); \
USE_INT_STAT(STAT_gpu11_mem_size); \
USE_INT_STAT(STAT_gpu12_mem_size); \
USE_INT_STAT(STAT_gpu13_mem_size); \
USE_INT_STAT(STAT_gpu14_mem_size); \
USE_INT_STAT(STAT_gpu15_mem_size)
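For stats addressed by runtime-built names, such as the per-GPU counters in gpu_info.cc above, callers go through StatRegistry directly. A small sketch under the same assumptions (the function name is illustrative); note that StatRegistry::get() returns nullptr for names that were never defined, so the string-keyed STAT_INT_ADD/STAT_INT_SUB macros silently assume the stat is registered:

#include "paddle/fluid/platform/monitor.h"

void DumpIntStats() {
  // Snapshot every registered int64 stat; passing true would also reset them.
  auto snapshot =
      paddle::platform::StatRegistry<int64_t>::Instance().publish(false);
  for (const auto& item : snapshot) {
    LOG(INFO) << item.key << " = " << item.value;
  }
}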
@@ -56,6 +56,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h"
@@ -1536,6 +1537,25 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
m.def("get_float_stats", []() {
std::vector<paddle::platform::ExportedStatValue<float>> float_stats;
paddle::platform::StatRegistry<float>::Instance().publish(float_stats);
std::unordered_map<std::string, float> stats_map;
for (const auto &stat : float_stats) {
stats_map[stat.key] = stat.value;
}
return stats_map;
});
m.def("get_int_stats", []() {
std::vector<paddle::platform::ExportedStatValue<int64_t>> int_stats;
paddle::platform::StatRegistry<int64_t>::Instance().publish(int_stats);
std::unordered_map<std::string, int64_t> stats_map;
for (const auto &stat : int_stats) {
stats_map[stat.key] = stat.value;
}
return stats_map;
});
m.def("run_cmd",
[](const std::string &cmd, int time_out = -1,
int sleep_inter = -1) -> const std::string {
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TestCases for Monitor
"""
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import os
import unittest
class TestDatasetWithStat(unittest.TestCase):
    """ TestCases for Dataset with monitor stats. """

    def setUp(self):
        self.use_data_loader = False
        self.epoch_num = 10
        self.drop_last = False

    def test_dataset_run_with_stat(self):
        with open("test_in_memory_dataset_run_a.txt", "w") as f:
            data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
            data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
            data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
            f.write(data)
        with open("test_in_memory_dataset_run_b.txt", "w") as f:
            data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
            data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
            data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
            data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
            f.write(data)

        slots = ["slot1", "slot2", "slot3", "slot4"]
        slots_vars = []
        for slot in slots:
            var = fluid.layers.data(
                name=slot, shape=[1], dtype="int64", lod_level=1)
            slots_vars.append(var)

        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_batch_size(32)
        dataset.set_thread(3)
        dataset.set_filelist([
            "test_in_memory_dataset_run_a.txt",
            "test_in_memory_dataset_run_b.txt"
        ])
        dataset.set_pipe_command("cat")
        dataset.set_use_var(slots_vars)
        dataset.load_into_memory()
        dataset.set_fea_eval(1, True)
        dataset.slots_shuffle(["slot1"])

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())
        if self.use_data_loader:
            data_loader = fluid.io.DataLoader.from_dataset(dataset,
                                                           fluid.cpu_places(),
                                                           self.drop_last)
            for i in range(self.epoch_num):
                for data in data_loader():
                    exe.run(fluid.default_main_program(), feed=data)
        else:
            for i in range(self.epoch_num):
                try:
                    exe.train_from_dataset(fluid.default_main_program(),
                                           dataset)
                except Exception as e:
                    self.fail(str(e))

        int_stat = core.get_int_stats()
        # the 7 instance lines above carry 8 feasigns each, 56 keys in total
        print(int_stat["STAT_total_feasign_num_in_mem"])
        os.remove("./test_in_memory_dataset_run_a.txt")
        os.remove("./test_in_memory_dataset_run_b.txt")


if __name__ == '__main__':
    unittest.main()