未验证 提交 9d2bd0ac 编写于 作者: 1 123malin 提交者: GitHub

downpour_worker增加try_catch机制,打印program所有参数 (#24700)

* test=develop, add try_catch for debug
上级 b2ba830e
...@@ -25,7 +25,7 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) { ...@@ -25,7 +25,7 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
} }
template <typename T> template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) { std::string PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel(); auto count = tensor->numel();
if (start < 0 || end > count) { if (start < 0 || end > count) {
VLOG(3) << "access violation"; VLOG(3) << "access violation";
...@@ -38,8 +38,7 @@ std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) { ...@@ -38,8 +38,7 @@ std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
return os.str(); return os.str();
} }
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end) {
int64_t end) {
auto count = tensor->numel(); auto count = tensor->numel();
if (start < 0 || end > count) { if (start < 0 || end > count) {
VLOG(3) << "access violation"; VLOG(3) << "access violation";
...@@ -52,7 +51,7 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, ...@@ -52,7 +51,7 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
return os.str(); return os.str();
} }
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) { std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end) {
std::string out_val; std::string out_val;
if (tensor->type() == proto::VarType::FP32) { if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end); out_val = PrintLodTensorType<float>(tensor, start, end);
......
...@@ -45,7 +45,7 @@ limitations under the License. */ ...@@ -45,7 +45,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end); std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index); std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size); bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
...@@ -171,6 +171,7 @@ class DeviceWorker { ...@@ -171,6 +171,7 @@ class DeviceWorker {
bool need_dump_field_; bool need_dump_field_;
const std::vector<std::string>* dump_param_; const std::vector<std::string>* dump_param_;
const std::vector<std::string>* dump_fields_; const std::vector<std::string>* dump_fields_;
std::vector<std::string> all_param_;
int dump_mode_ = 0; int dump_mode_ = 0;
int dump_interval_ = 10000; int dump_interval_ = 10000;
......
...@@ -771,7 +771,50 @@ void DownpourWorker::TrainFiles() { ...@@ -771,7 +771,50 @@ void DownpourWorker::TrainFiles() {
} }
} }
if (!need_skip) { if (!need_skip) {
#ifdef PADDLE_WITH_PSLIB
try {
op->Run(*thread_scope_, place_);
} catch (std::exception& e) {
fprintf(stderr, "error message: %s\n", e.what());
auto& ins_id_vec = device_reader_->GetInsIdVec();
size_t batch_size = device_reader_->GetCurBatchSize();
std::string s = "";
for (auto& ins_id : ins_id_vec) {
if (s != "") s += ",";
s += ins_id;
}
fprintf(stderr, "batch_size: %zu, ins_ids_vec: %s\n", batch_size,
s.c_str());
s = "";
for (auto& param : all_param_) {
Variable* var = thread_scope_->FindVar(param);
if (var == nullptr) {
continue;
}
Tensor* tensor = nullptr;
int64_t len = 0;
if (var->IsType<framework::LoDTensor>()) {
tensor = var->GetMutable<LoDTensor>();
len = tensor->numel();
} else if (var->IsType<SelectedRows>()) {
auto selected_rows = var->GetMutable<SelectedRows>();
tensor = selected_rows->mutable_value();
len = tensor->numel();
}
if (!tensor->IsInitialized()) {
continue;
}
s += param + ":" + std::to_string(len) + ":";
s += PrintLodTensor(tensor, 0, len);
fprintf(stderr, "%s\n", s.c_str());
fflush(stderr);
s = "";
}
throw e;
}
#else
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
#endif
} }
} }
......
...@@ -58,6 +58,7 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) { ...@@ -58,6 +58,7 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
for (auto &var : block.AllVars()) { for (auto &var : block.AllVars()) {
all_param_.push_back(var->Name());
if (var->Persistable()) { if (var->Persistable()) {
auto *ptr = root_scope_->Var(var->Name()); auto *ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册