未验证 提交 b59426b5 编写于 作者: L Leo Chen 提交者: GitHub

Enhance error msg of imperative code (#23572)

* fix init_gflags with 'python -c', test=develop

* enhance error msg related Tracer, test=develop

* refine err msg, test=develop

* follow comments, test=develop
上级 1f830691
...@@ -41,7 +41,9 @@ void ThreadSafeNameSet::Insert(const std::string& name) { ...@@ -41,7 +41,9 @@ void ThreadSafeNameSet::Insert(const std::string& name) {
void ThreadSafeNameSet::Remove(const std::string& name) { void ThreadSafeNameSet::Remove(const std::string& name) {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
auto iter = set_.find(name); auto iter = set_.find(name);
PADDLE_ENFORCE_EQ(iter != set_.end(), true, "%s does not exist", name); PADDLE_ENFORCE_EQ(
iter != set_.end(), true,
platform::errors::NotFound("Variable name %s does not exist", name));
set_.erase(iter); set_.erase(iter);
} }
...@@ -54,48 +56,6 @@ ThreadSafeNameSet VarBase::name_set_; ...@@ -54,48 +56,6 @@ ThreadSafeNameSet VarBase::name_set_;
std::vector<std::string> VarBase::AliveVarNames() { return name_set_.Names(); } std::vector<std::string> VarBase::AliveVarNames() { return name_set_.Names(); }
static framework::VariableNameMap CreateVarNameMap(
const framework::OpInfo& op_info, const std::string& op_type,
const NameVarBaseMap& varbase_map, bool is_input) {
if (op_info.proto_ == nullptr) {
framework::VariableNameMap result;
for (auto& it : varbase_map) {
auto& var_vector = it.second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (auto& var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[it.first] = std::move(args);
}
return result;
}
framework::VariableNameMap result;
for (auto& var :
is_input ? op_info.Proto().inputs() : op_info.Proto().outputs()) {
auto it = varbase_map.find(var.name());
if (it == varbase_map.end()) {
PADDLE_ENFORCE_EQ(
var.dispensable(), true,
"Var: %s not dispensable and there are no such var in inputs",
var.name());
result[var.name()] = {};
} else {
auto& var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (auto& var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[var.name()] = std::move(args);
}
}
return result;
}
static framework::RuntimeContext PrepareRuntimeContext( static framework::RuntimeContext PrepareRuntimeContext(
const NameVarBaseMap& ins, const NameVarBaseMap& outs) { const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
framework::VariableValueMap inputs, outputs; framework::VariableValueMap inputs, outputs;
...@@ -323,7 +283,9 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -323,7 +283,9 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const platform::Place& place) { const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op); auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel"); PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied(
"Only support operator with kernel in Dygraph mode."));
auto& info = op.Info(); auto& info = op.Info();
if (info.infer_var_type_) { if (info.infer_var_type_) {
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs); RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs);
......
...@@ -95,11 +95,12 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -95,11 +95,12 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
// check if op[type] has kernel registered. // check if op[type] has kernel registered.
auto& all_op_kernels = op.AllOpKernels(); auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type()); auto kernels_iter = all_op_kernels.find(op.Type());
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW( PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.", "There are no kernels which are registered in the %s operator.",
op.Type()); op.Type()));
}
auto& kernels = kernels_iter->second; auto& kernels = kernels_iter->second;
...@@ -111,10 +112,10 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -111,10 +112,10 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
// TODO(jiabin): Add operator.cc's line 1000 part back when we need that case // TODO(jiabin): Add operator.cc's line 1000 part back when we need that case
if (kernel_iter == kernels.end()) { PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
PADDLE_THROW("op %s does not have kernel for %s", op.Type(), platform::errors::NotFound(
KernelTypeToString(expected_kernel_key)); "Operator %s does not have kernel for %s.", op.Type(),
} KernelTypeToString(expected_kernel_key)));
if (!(expected_kernel_key.place_ == place)) { if (!(expected_kernel_key.place_ == place)) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
......
...@@ -68,8 +68,9 @@ static framework::VariableNameMap CreateVarNameMap( ...@@ -68,8 +68,9 @@ static framework::VariableNameMap CreateVarNameMap(
if (it == varbase_map.end()) { if (it == varbase_map.end()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var.dispensable(), true, var.dispensable(), true,
"Var: %s not dispensable and there are no such var in inputs", platform::errors::NotFound("Variable %s is not dispensable and "
var.name()); "there are no such var in inputs",
var.name()));
result[var.name()] = {}; result[var.name()] = {};
} else { } else {
auto& var_vector = it->second; auto& var_vector = it->second;
......
...@@ -101,7 +101,8 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self, ...@@ -101,7 +101,8 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
const py::kwargs &kwargs) { const py::kwargs &kwargs) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
kwargs.contains("value"), true, kwargs.contains("value"), true,
platform::errors::InvalidArgument("Missing argument: value")); platform::errors::NotFound(
"The kwargs used to create Varbase misses argument: value"));
auto persistable = kwargs.contains("persistable") auto persistable = kwargs.contains("persistable")
? kwargs["persistable"].cast<bool>() ? kwargs["persistable"].cast<bool>()
...@@ -158,7 +159,8 @@ static T PyObjectCast(PyObject *obj) { ...@@ -158,7 +159,8 @@ static T PyObjectCast(PyObject *obj) {
try { try {
return py::cast<T>(py::handle(obj)); return py::cast<T>(py::handle(obj));
} catch (py::cast_error &) { } catch (py::cast_error &) {
PADDLE_THROW("Python object is not type of %s", typeid(T).name()); PADDLE_THROW(platform::errors::InvalidArgument(
"Python object is not type of %s", typeid(T).name()));
} }
} }
...@@ -212,8 +214,9 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( ...@@ -212,8 +214,9 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
} }
} }
PADDLE_ENFORCE_EQ(PyErr_Occurred() == nullptr, true, PADDLE_ENFORCE_EQ(
py::str(py::handle(PyErr_Occurred()))); PyErr_Occurred(), nullptr,
platform::errors::InvalidArgument(py::str(py::handle(PyErr_Occurred()))));
return result; return result;
} }
...@@ -503,7 +506,7 @@ void BindImperative(py::module *m_ptr) { ...@@ -503,7 +506,7 @@ void BindImperative(py::module *m_ptr) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor.IsInitialized(), true, tensor.IsInitialized(), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"%s is Empty, Please check if it has no data in", "Tensor of %s is Empty, please check if it has no data.",
self.Name())); self.Name()));
return TensorToPyArray(tensor, true); return TensorToPyArray(tensor, true);
}, },
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册