提交 f362503d 编写于 作者: W Wei Luning

support throw attribute error from c++

上级 c5eae497
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include "debug/trace.h" #include "debug/trace.h"
#include "pipeline/pynative/pynative_execute.h" #include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h" #include "frontend/optimizer/py_pass_manager.h"
#include "pybind_api/pybind_patch.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES) #if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/common.h" #include "frontend/parallel/ps/common.h"
...@@ -536,6 +537,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: ...@@ -536,6 +537,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
} catch (const py::index_error &ex) { } catch (const py::index_error &ex) {
ReleaseResource(phase); ReleaseResource(phase);
throw py::index_error(ex); throw py::index_error(ex);
} catch (const py::attribute_error &ex) {
ReleaseResource(phase);
throw py::attribute_error(ex);
} catch (const std::exception &ex) { } catch (const std::exception &ex) {
ReleaseResource(phase); ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it // re-throw this exception to Python interpreter to handle it
......
...@@ -761,8 +761,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng ...@@ -761,8 +761,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
ValuePtr method = cls->GetMethod(item_name); ValuePtr method = cls->GetMethod(item_name);
if (method->isa<AnyValue>()) { if (method->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString(); << ", item value: " << item_v->ToString();
} }
// Infer class method // Infer class method
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PYBIND_API_PYBIND_PATCH_H_
#define PYBIND_API_PYBIND_PATCH_H_
namespace pybind11 {
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
}
#endif // PYBIND_API_PYBIND_PATCH_H_
...@@ -145,10 +145,11 @@ static std::string ExceptionTypeToString(ExceptionType type) { ...@@ -145,10 +145,11 @@ static std::string ExceptionTypeToString(ExceptionType type) {
_TO_STRING(IndexError), _TO_STRING(IndexError),
_TO_STRING(ValueError), _TO_STRING(ValueError),
_TO_STRING(TypeError), _TO_STRING(TypeError),
_TO_STRING(AttributeError),
}; };
// clang-format on // clang-format on
#undef _TO_STRING #undef _TO_STRING
if (type < UnknownError || type > TypeError) { if (type < UnknownError || type > AttributeError) {
type = UnknownError; type = UnknownError;
} }
return std::string(type_names[type]); return std::string(type_names[type]);
...@@ -212,7 +213,7 @@ void LogWriter::operator^(const LogStream &stream) const { ...@@ -212,7 +213,7 @@ void LogWriter::operator^(const LogStream &stream) const {
std::ostringstream oss; std::ostringstream oss;
oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] ";
if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError && if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError &&
exception_type_ != ValueError) { exception_type_ != ValueError && exception_type_ != AttributeError) {
oss << ExceptionTypeToString(exception_type_) << " "; oss << ExceptionTypeToString(exception_type_) << " ";
} }
oss << msg.str(); oss << msg.str();
......
...@@ -58,6 +58,7 @@ enum ExceptionType { ...@@ -58,6 +58,7 @@ enum ExceptionType {
IndexError, IndexError,
ValueError, ValueError,
TypeError, TypeError,
AttributeError,
}; };
struct LocationInfo { struct LocationInfo {
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind_api/pybind_patch.h"
namespace py = pybind11; namespace py = pybind11;
namespace mindspore { namespace mindspore {
...@@ -38,6 +39,9 @@ class PyExceptionInitializer { ...@@ -38,6 +39,9 @@ class PyExceptionInitializer {
if (exception_type == TypeError) { if (exception_type == TypeError) {
throw py::type_error(str); throw py::type_error(str);
} }
if (exception_type == AttributeError) {
throw py::attribute_error(str);
}
py::pybind11_fail(str); py::pybind11_fail(str);
} }
}; };
......
...@@ -304,6 +304,29 @@ def test_access(): ...@@ -304,6 +304,29 @@ def test_access():
""" test_access """ """ test_access """
invoke_dataclass(1, 2) invoke_dataclass(1, 2)
@dataclass
class Access2:
a: int
b: int
def max(self):
if self.a > self.b:
return self.c
return self.b
@ms_function
def invoke_dataclass2(x, y):
""" invoke_dataclass """
acs = Access2(x, y)
return acs.max()
def test_access_attr_error():
""" test_access """
with pytest.raises(AttributeError):
invoke_dataclass2(1, 2)
def myfunc(x): def myfunc(x):
""" myfunc """ """ myfunc """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册