提交 f362503d 编写于 作者: W Wei Luning

support throw attribute error from c++

上级 c5eae497
......@@ -40,6 +40,7 @@
#include "debug/trace.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "pybind_api/pybind_patch.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/common.h"
......@@ -536,6 +537,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
} catch (const py::index_error &ex) {
ReleaseResource(phase);
throw py::index_error(ex);
} catch (const py::attribute_error &ex) {
ReleaseResource(phase);
throw py::attribute_error(ex);
} catch (const std::exception &ex) {
ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it
......
......@@ -761,8 +761,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
ValuePtr method = cls->GetMethod(item_name);
if (method->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString();
MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString();
}
// Infer class method
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PYBIND_API_PYBIND_PATCH_H_
#define PYBIND_API_PYBIND_PATCH_H_
namespace pybind11 {
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
}
#endif // PYBIND_API_PYBIND_PATCH_H_
......@@ -145,10 +145,11 @@ static std::string ExceptionTypeToString(ExceptionType type) {
_TO_STRING(IndexError),
_TO_STRING(ValueError),
_TO_STRING(TypeError),
_TO_STRING(AttributeError),
};
// clang-format on
#undef _TO_STRING
if (type < UnknownError || type > TypeError) {
if (type < UnknownError || type > AttributeError) {
type = UnknownError;
}
return std::string(type_names[type]);
......@@ -212,7 +213,7 @@ void LogWriter::operator^(const LogStream &stream) const {
std::ostringstream oss;
oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] ";
if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError &&
exception_type_ != ValueError) {
exception_type_ != ValueError && exception_type_ != AttributeError) {
oss << ExceptionTypeToString(exception_type_) << " ";
}
oss << msg.str();
......
......@@ -58,6 +58,7 @@ enum ExceptionType {
IndexError,
ValueError,
TypeError,
AttributeError,
};
struct LocationInfo {
......
......@@ -18,6 +18,7 @@
#include <string>
#include "pybind11/pybind11.h"
#include "pybind_api/pybind_patch.h"
namespace py = pybind11;
namespace mindspore {
......@@ -38,6 +39,9 @@ class PyExceptionInitializer {
if (exception_type == TypeError) {
throw py::type_error(str);
}
if (exception_type == AttributeError) {
throw py::attribute_error(str);
}
py::pybind11_fail(str);
}
};
......
......@@ -304,6 +304,29 @@ def test_access():
""" test_access """
invoke_dataclass(1, 2)
@dataclass
class Access2:
a: int
b: int
def max(self):
if self.a > self.b:
return self.c
return self.b
@ms_function
def invoke_dataclass2(x, y):
""" invoke_dataclass """
acs = Access2(x, y)
return acs.max()
def test_access_attr_error():
""" test_access """
with pytest.raises(AttributeError):
invoke_dataclass2(1, 2)
def myfunc(x):
""" myfunc """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册