提交 b384af58 编写于 作者: Y Yu Yang

Fix bugs in CustomStackTrace.

* Make layer stack trace shows ThreadId, Forward or Backward.

Change-Id: Iba1477adb8c9115c3a67ff2959bb5c878ca706c7
上级 3304de7a
...@@ -40,3 +40,4 @@ HPPL_ERROR_LOG ...@@ -40,3 +40,4 @@ HPPL_ERROR_LOG
unittest.list unittest.list
proto proto
dist dist
setup.py
...@@ -277,6 +277,7 @@ void NeuralNetwork::getState(MachineState& machineState) { ...@@ -277,6 +277,7 @@ void NeuralNetwork::getState(MachineState& machineState) {
} }
void NeuralNetwork::backward(const UpdateCallback& callback) { void NeuralNetwork::backward(const UpdateCallback& callback) {
gLayerStackTrace.pop(""); // tell layer trace is during backward.
FOR_EACH_R(layer, layers_) { FOR_EACH_R(layer, layers_) {
REGISTER_TIMER_INFO("BackwardTimer", (*layer)->getName().c_str()); REGISTER_TIMER_INFO("BackwardTimer", (*layer)->getName().c_str());
if ((*layer)->needGradient()) { if ((*layer)->needGradient()) {
......
...@@ -14,9 +14,44 @@ limitations under the License. */ ...@@ -14,9 +14,44 @@ limitations under the License. */
#include "CustomStackTrace.h" #include "CustomStackTrace.h"
#include "CommandLineParser.h"
#include <iostream>
P_DEFINE_bool(layer_stack_error_only_current_thread,
true,
"Dump current thread or whole process layer stack when signal error "
"occurred. true means only dump current thread layer stack");
namespace paddle { namespace paddle {
CustomStackTrace<std::string> gLayerStackTrace; CustomStackTrace<std::string> gLayerStackTrace;
static std::mutex gLayerStackTraceMtx;
void installLayerStackTracer() {
logging::installFailureWriter([](const char* data, int sz) {
std::lock_guard<std::mutex> guard(gLayerStackTraceMtx);
if (!gLayerStackTrace.empty()) {
size_t curTid = -1UL;
std::hash<std::thread::id> hasher;
gLayerStackTrace.dump([&curTid, &hasher](std::thread::id tid,
bool* isForwarding,
const std::string& layerName) {
if (curTid != hasher(tid)) {
if (curTid != -1UL) {
std::cerr << std::endl;
}
curTid = hasher(tid);
std::cerr << "Thread [" << tid << "] ";
if (isForwarding) {
std::cerr << (*isForwarding ? "Forwarding ": "Backwarding ");
}
}
std::cerr << layerName << ", ";
}, FLAGS_layer_stack_error_only_current_thread);
std::cerr << std::endl;
}
std::cerr.write(data, sz);
});
}
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,9 @@ limitations under the License. */ ...@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <stack> #include <stack>
#include <thread>
#include <unordered_map>
#include <functional>
#include "ThreadLocal.h" #include "ThreadLocal.h"
...@@ -29,25 +32,18 @@ namespace paddle { ...@@ -29,25 +32,18 @@ namespace paddle {
* @code{.cpp} * @code{.cpp}
* *
* paddle::CustomStackTrace<std::string> stack; * paddle::CustomStackTrace<std::string> stack;
* PASS_TEST=0;
* for (auto& layer : layers){ * for (auto& layer : layers){
* stack.push(layer->getName()); * stack.push(layer->getName());
* layer->forward(passType); * layer->forward();
* } * }
* for (auto& layer : layers){ *
* stack.pop(""); // mark under pop stage.
*
* for (auto it = layers.rbegin(); it != layers.rend(); ++it){
* auto& layer = *it;
* layer->backward(passType); * layer->backward(passType);
* stack.pop(layer->getName()); * stack.pop(layer->getName());
* } * }
*
* if(passType == PASS_TEST) {
* stack.clear();
* }
* else {
* stack.dump([](const std::string& layername){
* LOG(INFO) << "LayerName: " << layername;
* })
* }
*
* *
* @endcode * @endcode
*/ */
...@@ -55,45 +51,141 @@ template <typename T> ...@@ -55,45 +51,141 @@ template <typename T>
class CustomStackTrace{ class CustomStackTrace{
public: public:
/** /**
* @brief Pop out an item from the top of the stack. For safety the item * @brief Pop out an item from the top of the stack if item == top.
* will be poped should equal to ip. * Else, just set status to popping.
*/ */
void pop(const T& ip) { void pop(const T& item) {
auto& p = *logstack_; pushing() = false;
CHECK_EQ(ip, p.top()); auto& s = this->stack();
p.pop(); if (item == s.top()) {
s.pop();
}
} }
/** /**
* @brief Empty the stack by sequence from top to button. * @brief clear current thread stack.
* @param[in] callback A function deal with each item while dumping.
* It must have and only have a in parameter which is the stack item.
*/ */
template <typename Callback> void clear() {
void dump(Callback callback) { auto& s = stack();
auto& p = *logstack_; while (!s.empty()) {
while (!p.empty()) { s.pop();
callback(p.top());
p.pop();
} }
} }
/** /**
* @brief Only empty the stack. * @brief return true if all thread's stack is empty.
* @return true if empty
*/ */
void clear() { bool empty() const {
dump([](const T& ip){}); std::lock_guard<std::mutex> g(this->mtx_);
for (auto p : this->stackBuffers_) {
std::stack<T>& s = *p.second;
if (!s.empty()) {
return false;
}
}
return true;
}
/**
* @brief DumpCallback Type. It will be invoked many times by dump method.
*
* The first parameter is stack thread id.
* The second parameter is the last action of stack is push or not.
* The third parameter is the item in stack.
*/
typedef std::function<void(const std::thread::id& /*threadId*/,
bool* /*isPushing*/,
const T& /*item*/)> DumpCallback;
/**
* Dump all thread stack, and all stack will be cleared.
*/
void dump(const DumpCallback& callback, bool onlyCurrentThread = false) {
std::lock_guard<std::mutex> g(this->mtx_);
for (auto p : this->stackBuffers_) {
std::thread::id tid = p.first;
if (onlyCurrentThread && tid != std::this_thread::get_id()) {
continue;
}
std::stack<T>& s = *p.second;
bool* isPush = nullptr;
auto it = this->pushingBuffers_.find(tid);
if (it != this->pushingBuffers_.end()) {
isPush = it->second;
}
while (!s.empty()) {
callback(tid, isPush, s.top());
s.pop();
}
}
} }
/** /**
* @brief Push item ip to the top of the stack. * @brief Push item to current thread stack.
*/ */
void push(const T& ip) { void push(const T& item) {
auto& p = *logstack_; pushing() = true;
p.push(ip); auto& p = this->stack();
p.push(item);
} }
private: private:
ThreadLocalD<std::stack<T> > logstack_; /**
* Get thread local attribute, and save them into a map (threadId => TYPE*)
*
* @tparam TYPE thread local attribute type.
* @param threadLocal Thread Local object.
* @param buffers a map from threadId to TYPE*
*/
template <typename TYPE>
inline TYPE& getThreadLocal(
ThreadLocal<TYPE>& threadLocal,
std::unordered_map<std::thread::id, TYPE*>& buffers) {
TYPE* retv = threadLocal.get(false);
if (retv) {
return *retv;
} else {
std::lock_guard<std::mutex> guard(this->mtx_);
retv = threadLocal.get();
auto id = std::this_thread::get_id();
buffers.insert({id, retv});
return *retv;
}
}
/**
* @brief Get thread local stack reference.
*/
std::stack<T>& stack() {
return this->getThreadLocal(this->logStack_,
this->stackBuffers_);
}
/**
* @brief Get thread local pushing flag.
*/
bool& pushing() {
return this->getThreadLocal(this->isPushing_,
this->pushingBuffers_);
}
private:
mutable std::mutex mtx_;
std::unordered_map<std::thread::id, std::stack<T>* > stackBuffers_;
std::unordered_map<std::thread::id, bool* > pushingBuffers_;
ThreadLocal<bool> isPushing_;
ThreadLocal<std::stack<T> > logStack_;
}; };
extern CustomStackTrace<std::string> gLayerStackTrace; extern CustomStackTrace<std::string> gLayerStackTrace;
/**
* @brief Install a failure handler to print layer stack when error.
*/
extern void installLayerStackTracer();
} // namespace paddle } // namespace paddle
...@@ -129,13 +129,7 @@ void runInitFunctions() { ...@@ -129,13 +129,7 @@ void runInitFunctions() {
void initMain(int argc, char** argv) { void initMain(int argc, char** argv) {
initializeLogging(argc, argv); initializeLogging(argc, argv);
logging::installFailureWriter([](const char* data, int sz) { installLayerStackTracer();
std::cerr << "Current Layer forward/backward stack is " << std::endl;
gLayerStackTrace.dump([](const std::string& layername){
std::cerr << "LayerName: " << layername << std::endl;
});
std::cerr.write(data, sz);
});
std::string line; std::string line;
for (int i = 0; i < argc; ++i) { for (int i = 0; i < argc; ++i) {
line += argv[i]; line += argv[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册