提交 761cdbfa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2959 Decouple ir.Signature class from python

Merge pull request !2959 from hewei/decouple_signature
...@@ -30,17 +30,21 @@ ...@@ -30,17 +30,21 @@
#include "pybind_api/export_flags.h" #include "pybind_api/export_flags.h"
namespace mindspore { namespace mindspore {
static ValuePtr PyArgToValue(const py::object &arg) {
if (py::isinstance<SignatureEnumKind>(arg) &&
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
return nullptr;
}
return parse::data_converter::PyDataToValue(arg);
}
void PrimitivePy::set_signatures( void PrimitivePy::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear(); signatures_.clear();
for (auto &signature : signatures) { for (auto &signature : signatures) {
std::string name; auto [name, rw, kind, arg_default, dtype] = signature;
SignatureEnumRW rw; auto default_value = PyArgToValue(arg_default);
SignatureEnumKind kind; signatures_.emplace_back(name, rw, kind, default_value, dtype);
py::object default_value;
SignatureEnumDType dtype;
std::tie(name, rw, kind, default_value, dtype) = signature;
signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype));
} }
set_has_signature(true); set_has_signature(true);
} }
......
...@@ -16,14 +16,11 @@ ...@@ -16,14 +16,11 @@
#ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_ #ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_
#define MINDSPORE_CCSRC_IR_SIGNATURE_H_ #define MINDSPORE_CCSRC_IR_SIGNATURE_H_
#include <string> #include <string>
#include <vector> #include <vector>
#include "pybind11/operators.h"
#include "ir/value.h" #include "ir/value.h"
namespace py = pybind11;
namespace mindspore { namespace mindspore {
// Input signature, support type // Input signature, support type
enum SignatureEnumRW { enum SignatureEnumRW {
...@@ -62,8 +59,10 @@ struct Signature { ...@@ -62,8 +59,10 @@ struct Signature {
ValuePtr default_value; // nullptr for no default value ValuePtr default_value; // nullptr for no default value
SignatureEnumDType dtype; SignatureEnumDType dtype;
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype); const ValuePtr &arg_default, const SignatureEnumDType &arg_dtype)
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); : name(arg_name), rw(rw_tag), kind(arg_kind), default_value(arg_default), dtype(arg_dtype) {}
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: Signature(arg_name, rw_tag, arg_kind, nullptr, SignatureEnumDType::kDTypeEmptyDefaultValue) {}
}; };
} // namespace mindspore } // namespace mindspore
......
...@@ -15,30 +15,14 @@ ...@@ -15,30 +15,14 @@
*/ */
#include "ir/signature.h" #include "ir/signature.h"
#include "pybind11/operators.h" #include "pybind11/operators.h"
#include "pybind_api/api_register.h" #include "pybind_api/api_register.h"
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
namespace mindspore { namespace py = pybind11;
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
default_value = nullptr;
} else {
default_value = parse::data_converter::PyDataToValue(arg_default);
}
}
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name),
rw(rw_tag),
kind(arg_kind),
default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
namespace mindspore {
// Bind SignatureEnumRW as a python class.
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_READ", SignatureEnumRW::kRWRead)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册