From f362503d5740ce0d390a7186aa48d4f64e44672c Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Sat, 18 Jul 2020 17:53:13 +0800 Subject: [PATCH] support throw attribute error from c++ --- mindspore/ccsrc/pipeline/jit/pipeline.cc | 4 ++++ .../pipeline/jit/static_analysis/prim.cc | 4 ++-- mindspore/ccsrc/pybind_api/pybind_patch.h | 24 +++++++++++++++++++ mindspore/ccsrc/utils/log_adapter.cc | 5 ++-- mindspore/ccsrc/utils/log_adapter.h | 1 + mindspore/ccsrc/utils/log_adapter_py.cc | 4 ++++ .../python/pynative_mode/test_parse_method.py | 23 ++++++++++++++++++ 7 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 mindspore/ccsrc/pybind_api/pybind_patch.h diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 21d20c893..36be63387 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -40,6 +40,7 @@ #include "debug/trace.h" #include "pipeline/pynative/pynative_execute.h" #include "frontend/optimizer/py_pass_manager.h" +#include "pybind_api/pybind_patch.h" #if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES) #include "frontend/parallel/ps/common.h" @@ -536,6 +537,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: } catch (const py::index_error &ex) { ReleaseResource(phase); throw py::index_error(ex); + } catch (const py::attribute_error &ex) { + ReleaseResource(phase); + throw py::attribute_error(ex); } catch (const std::exception &ex) { ReleaseResource(phase); // re-throw this exception to Python interpreter to handle it diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 90d4aaa12..7ab51bb22 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -761,8 +761,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng ValuePtr method = cls->GetMethod(item_name); if (method->isa()) { - MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() - << ", item value: " << item_v->ToString(); + MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() + << ", item value: " << item_v->ToString(); } // Infer class method diff --git a/mindspore/ccsrc/pybind_api/pybind_patch.h b/mindspore/ccsrc/pybind_api/pybind_patch.h new file mode 100644 index 000000000..a71774b26 --- /dev/null +++ b/mindspore/ccsrc/pybind_api/pybind_patch.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PYBIND_API_PYBIND_PATCH_H_ +#define PYBIND_API_PYBIND_PATCH_H_ + +namespace pybind11 { +PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) +} + +#endif // PYBIND_API_PYBIND_PATCH_H_ diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index 702deefcb..175e790c3 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -145,10 +145,11 @@ static std::string ExceptionTypeToString(ExceptionType type) { _TO_STRING(IndexError), _TO_STRING(ValueError), _TO_STRING(TypeError), + _TO_STRING(AttributeError), }; // clang-format on #undef _TO_STRING - if (type < UnknownError || type > TypeError) { + if (type < UnknownError || type > AttributeError) { type = UnknownError; } return std::string(type_names[type]); @@ -212,7 +213,7 @@ void LogWriter::operator^(const LogStream &stream) const { std::ostringstream oss; oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError && - exception_type_ != ValueError) { + exception_type_ != ValueError && exception_type_ != AttributeError) { oss << ExceptionTypeToString(exception_type_) << " "; } oss << msg.str(); diff --git a/mindspore/ccsrc/utils/log_adapter.h b/mindspore/ccsrc/utils/log_adapter.h index a0e9bfc6d..53c94a634 100644 --- a/mindspore/ccsrc/utils/log_adapter.h +++ b/mindspore/ccsrc/utils/log_adapter.h @@ -58,6 +58,7 @@ enum ExceptionType { IndexError, ValueError, TypeError, + AttributeError, }; struct LocationInfo { diff --git a/mindspore/ccsrc/utils/log_adapter_py.cc b/mindspore/ccsrc/utils/log_adapter_py.cc index c4793b960..db086f37a 100644 --- a/mindspore/ccsrc/utils/log_adapter_py.cc +++ b/mindspore/ccsrc/utils/log_adapter_py.cc @@ -18,6 +18,7 @@ #include #include "pybind11/pybind11.h" +#include "pybind_api/pybind_patch.h" namespace py = pybind11; namespace mindspore { @@ -38,6 +39,9 @@ class PyExceptionInitializer { if (exception_type == TypeError) { throw py::type_error(str); } + if (exception_type == AttributeError) { + throw py::attribute_error(str); + } py::pybind11_fail(str); } }; diff --git a/tests/ut/python/pynative_mode/test_parse_method.py b/tests/ut/python/pynative_mode/test_parse_method.py index f189b825e..0a8c1767d 100644 --- a/tests/ut/python/pynative_mode/test_parse_method.py +++ b/tests/ut/python/pynative_mode/test_parse_method.py @@ -304,6 +304,29 @@ def test_access(): """ test_access """ invoke_dataclass(1, 2) +@dataclass +class Access2: + a: int + b: int + + def max(self): + if self.a > self.b: + return self.c + return self.b + + +@ms_function +def invoke_dataclass2(x, y): + """ invoke_dataclass """ + acs = Access2(x, y) + return acs.max() + + +def test_access_attr_error(): + """ test_access """ + with pytest.raises(AttributeError): + invoke_dataclass2(1, 2) + def myfunc(x): """ myfunc """ -- GitLab