提交 7f6d8ace 编写于 作者: P peizhilin

cherry-pick the #12759

test=develop
上级 05d1121b
......@@ -82,6 +82,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.")
.SetDefault("");
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
Validate();
}
......
......@@ -47,6 +47,7 @@ class OpProtoAndCheckerMaker {
static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpNamescopeAttrName() { return "op_namescope"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
......@@ -16,10 +16,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
......@@ -157,27 +162,59 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
try {
if (VLOG_IS_ON(4)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
}
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
PADDLE_THROW("Cannot run operator on place %s", place);
#else
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
#endif
}
}
// The profile has a process-wide mutex, results in serious performance issue
// in concurrency scenerio. Here use an `if` to fix this issue.
// Please not remove the `if`, ask @Superjomn if there are any concern.
if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
} else {
RunImpl(scope, place);
// The profile has a process-wide mutex, results in serious performance
// issue
// in concurrency scenerio. Here use an `if` to fix this issue.
// Please not remove the `if`, ask @Superjomn if there are any concern.
if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
} else {
RunImpl(scope, place);
}
if (VLOG_IS_ON(3)) {
VLOG(3) << place << " " << DebugStringEx(&scope);
}
} catch (platform::EnforceNotMet exception) {
if (Attrs().count("sub_block") != 0) {
throw exception;
}
auto& callstack = Attr<std::vector<std::string>>(
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (callstack.empty()) {
throw exception;
}
std::ostringstream sout;
sout << "Invoke operator " << Type() << " error.\n";
sout << "Python Callstacks: \n";
for (auto& line : callstack) {
sout << line;
}
sout << "C++ Callstacks: \n";
sout << exception.err_str_;
exception.err_str_ = sout.str();
throw exception;
} catch (...) {
std::rethrow_exception(std::current_exception());
}
VLOG(3) << place << " " << DebugStringEx(&scope);
}
bool OperatorBase::HasInputs(const std::string& name) const {
......
......@@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel {
"Output(Indices) of TopkOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(input_dims.size(), 2,
"Rank of TopK op's input must be 2.");
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
......
......@@ -49,6 +49,9 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def(
"kOpNameScopeAttrName",
framework::OpProtoAndCheckerMaker::OpNamescopeAttrName);
op_proto_and_checker_maker.def(
"kOpCreationCallstackAttrName",
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
}
} // namespace pybind
......
......@@ -20,6 +20,7 @@ import os
import re
import six
import sys
import traceback
import numpy as np
......@@ -604,6 +605,10 @@ class Operator(object):
if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
del op_attrs[role_var_name]
callstack_var_name = op_maker.kOpCreationCallstackAttrName()
op_attrs[callstack_var_name] = list(
reversed(traceback.format_stack()))[1:]
if len(self.desc.type()) != 0:
return
if type is None:
......
......@@ -69,7 +69,7 @@ class TestOperator(unittest.TestCase):
set(mul_op.attr_names),
set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"op_namescope"
"op_namescope", "op_callstack"
]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册