提交 7f6d8ace 编写于 作者: P peizhilin

cherry-pick the #12759

test=develop
上级 05d1121b
...@@ -82,6 +82,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -82,6 +82,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.") AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.")
.SetDefault(""); .SetDefault("");
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
Validate(); Validate();
} }
......
...@@ -47,6 +47,7 @@ class OpProtoAndCheckerMaker { ...@@ -47,6 +47,7 @@ class OpProtoAndCheckerMaker {
static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; } static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpNamescopeAttrName() { return "op_namescope"; } static const char *OpNamescopeAttrName() { return "op_namescope"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
...@@ -16,10 +16,15 @@ limitations under the License. */ ...@@ -16,10 +16,15 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
...@@ -157,7 +162,10 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, ...@@ -157,7 +162,10 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
} }
void OperatorBase::Run(const Scope& scope, const platform::Place& place) { void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
try {
if (VLOG_IS_ON(4)) {
VLOG(4) << place << " " << DebugStringEx(&scope); VLOG(4) << place << " " << DebugStringEx(&scope);
}
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place); PADDLE_THROW("Cannot run operator on place %s", place);
...@@ -167,17 +175,46 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -167,17 +175,46 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
#endif #endif
} }
// The profile has a process-wide mutex, results in serious performance issue // The profile has a process-wide mutex, results in serious performance
// issue
// in concurrency scenerio. Here use an `if` to fix this issue. // in concurrency scenerio. Here use an `if` to fix this issue.
// Please not remove the `if`, ask @Superjomn if there are any concern. // Please not remove the `if`, ask @Superjomn if there are any concern.
if (platform::IsProfileEnabled()) { if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place)); platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place); RunImpl(scope, place);
} else { } else {
RunImpl(scope, place); RunImpl(scope, place);
} }
if (VLOG_IS_ON(3)) {
VLOG(3) << place << " " << DebugStringEx(&scope); VLOG(3) << place << " " << DebugStringEx(&scope);
}
} catch (platform::EnforceNotMet exception) {
if (Attrs().count("sub_block") != 0) {
throw exception;
}
auto& callstack = Attr<std::vector<std::string>>(
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (callstack.empty()) {
throw exception;
}
std::ostringstream sout;
sout << "Invoke operator " << Type() << " error.\n";
sout << "Python Callstacks: \n";
for (auto& line : callstack) {
sout << line;
}
sout << "C++ Callstacks: \n";
sout << exception.err_str_;
exception.err_str_ = sout.str();
throw exception;
} catch (...) {
std::rethrow_exception(std::current_exception());
}
} }
bool OperatorBase::HasInputs(const std::string& name) const { bool OperatorBase::HasInputs(const std::string& name) const {
......
...@@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel {
"Output(Indices) of TopkOp should not be null."); "Output(Indices) of TopkOp should not be null.");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(input_dims.size(), 2,
"Rank of TopK op's input must be 2.");
const int k = static_cast<int>(ctx->Attrs().Get<int>("k")); const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
......
...@@ -49,6 +49,9 @@ void BindConstValue(pybind11::module* m) { ...@@ -49,6 +49,9 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpNameScopeAttrName", "kOpNameScopeAttrName",
framework::OpProtoAndCheckerMaker::OpNamescopeAttrName); framework::OpProtoAndCheckerMaker::OpNamescopeAttrName);
op_proto_and_checker_maker.def(
"kOpCreationCallstackAttrName",
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
} }
} // namespace pybind } // namespace pybind
......
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import re import re
import six import six
import sys import sys
import traceback
import numpy as np import numpy as np
...@@ -604,6 +605,10 @@ class Operator(object): ...@@ -604,6 +605,10 @@ class Operator(object):
if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
del op_attrs[role_var_name] del op_attrs[role_var_name]
callstack_var_name = op_maker.kOpCreationCallstackAttrName()
op_attrs[callstack_var_name] = list(
reversed(traceback.format_stack()))[1:]
if len(self.desc.type()) != 0: if len(self.desc.type()) != 0:
return return
if type is None: if type is None:
......
...@@ -69,7 +69,7 @@ class TestOperator(unittest.TestCase): ...@@ -69,7 +69,7 @@ class TestOperator(unittest.TestCase):
set(mul_op.attr_names), set(mul_op.attr_names),
set([ set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var", "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"op_namescope" "op_namescope", "op_callstack"
])) ]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册