提交 761cdbfa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2959 Decouple ir.Signature class from python

Merge pull request !2959 from hewei/decouple_signature
......@@ -30,17 +30,21 @@
#include "pybind_api/export_flags.h"
namespace mindspore {
static ValuePtr PyArgToValue(const py::object &arg) {
if (py::isinstance<SignatureEnumKind>(arg) &&
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
return nullptr;
}
return parse::data_converter::PyDataToValue(arg);
}
void PrimitivePy::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear();
for (auto &signature : signatures) {
std::string name;
SignatureEnumRW rw;
SignatureEnumKind kind;
py::object default_value;
SignatureEnumDType dtype;
std::tie(name, rw, kind, default_value, dtype) = signature;
signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype));
auto [name, rw, kind, arg_default, dtype] = signature;
auto default_value = PyArgToValue(arg_default);
signatures_.emplace_back(name, rw, kind, default_value, dtype);
}
set_has_signature(true);
}
......
......@@ -16,14 +16,11 @@
#ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_
#define MINDSPORE_CCSRC_IR_SIGNATURE_H_
#include <string>
#include <vector>
#include "pybind11/operators.h"
#include "ir/value.h"
namespace py = pybind11;
namespace mindspore {
// Input signature, support type
enum SignatureEnumRW {
......@@ -62,8 +59,10 @@ struct Signature {
ValuePtr default_value; // nullptr for no default value
SignatureEnumDType dtype;
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype);
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind);
const ValuePtr &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), default_value(arg_default), dtype(arg_dtype) {}
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: Signature(arg_name, rw_tag, arg_kind, nullptr, SignatureEnumDType::kDTypeEmptyDefaultValue) {}
};
} // namespace mindspore
......
......@@ -15,30 +15,14 @@
*/
#include "ir/signature.h"
#include "pybind11/operators.h"
#include "pybind_api/api_register.h"
#include "pipeline/parse/data_converter.h"
namespace mindspore {
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
default_value = nullptr;
} else {
default_value = parse::data_converter::PyDataToValue(arg_default);
}
}
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name),
rw(rw_tag),
kind(arg_kind),
default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
namespace py = pybind11;
namespace mindspore {
// Bind SignatureEnumRW as a python class.
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册