/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND
#include <Python.h>

#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#include "paddle/fluid/platform/ascend_npu_info.h"
#include "paddle/fluid/platform/enforce.h"

using namespace ge;  // NOLINT
namespace py = pybind11;

namespace paddle {
namespace pybind {

void BindAscendWrapper(py::module *m) {
  py::class_<framework::AscendInstance,
             std::shared_ptr<framework::AscendInstance>>(*m, "AscendInstance")
      .def(py::init([]() { return framework::AscendInstance::GetInstance(); }))
      .def("init_global_resources",
           &framework::AscendInstance::InitGlobalResouces,
           py::call_guard<py::gil_scoped_release>())
      .def("add_ascend_subgraph",
           &framework::AscendInstance::AddAscendSubgraph,
           py::call_guard<py::gil_scoped_release>());
}  // end AscendWrapper

// Convert an std::string option map into the ge::AscendString map expected by
// the GE C++ API.
std::map<ge::AscendString, ge::AscendString> convert_map(
    const std::map<std::string, std::string> &options) {
  std::map<ge::AscendString, ge::AscendString> rets;
  for (auto &option : options) {
    ge::AscendString key = option.first.c_str();
    ge::AscendString val = option.second.c_str();
    rets[key] = val;
  }
  return rets;
}

ge::Status ge_initialize(
    std::map<std::string, std::string> &options) {  // NOLINT
  py::gil_scoped_release release;
  auto init_options = convert_map(options);
  ge::Status res = ge::GEInitialize(init_options);
  PADDLE_ENFORCE_EQ(res, ge::SUCCESS,
                    platform::errors::Fatal("ge init error:%d", res));
  py::gil_scoped_acquire acquire;
  return res;
}

enum AttrType {
  AT_INT64 = 0,
  AT_INT32,
  AT_UINT32,
  AT_LIST_INT64,
  AT_LIST_INT32,
  AT_LIST_UINT32,
  AT_FLOAT,
  AT_LIST_FLOAT,
  AT_ATTR_VALUE,
  AT_STRING,
  AT_LIST_STRING,
  AT_BOOL,
  AT_LIST_BOOL,
  AT_TENSOR,
  AT_LIST_TENSOR,
  AT_LIST_UINT8,
  AT_LIST_LIST_INT64,
  AT_LIST_DT,
  AT_DT,
  AT_LIST_NAMEATTR,
  AT_NAMEATTR
};

void BindAscendDevice(py::module *m) {
  py::class_<platform::ascend::NPUDevice>(*m, "NPUDevice")
      .def_static("get_device_count",
                  static_cast<int (*)()>(
                      &platform::ascend::NPUDevice::GetDeviceCount));
}

void BindAscendGraph(py::module *m) {
  m->def("ge_initialize", &ge_initialize, "GEInitialize");
  m->def("ge_finalize", &GEFinalize, "GEFinalize");

  // enum wrappers
  py::enum_<GraphRunMode>(*m, "GEGraphRunMode")
      .value("PREDICTION", GraphRunMode::PREDICTION)
      .value("TRAIN", GraphRunMode::TRAIN)
      .export_values();

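  // Tensor element types understood by GE; the Python-side names mirror
  // ge::DataType one-to-one.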
  py::enum_<DataType>(*m, "GEDataType")
      .value("DT_FLOAT", DataType::DT_FLOAT)
      .value("DT_FLOAT16", DataType::DT_FLOAT16)
      .value("DT_INT8", DataType::DT_INT8)
      .value("DT_INT16", DataType::DT_INT16)
      .value("DT_UINT16", DataType::DT_UINT16)
      .value("DT_UINT8", DataType::DT_UINT8)
      .value("DT_INT32", DataType::DT_INT32)
      .value("DT_INT64", DataType::DT_INT64)
      .value("DT_UINT32", DataType::DT_UINT32)
      .value("DT_UINT64", DataType::DT_UINT64)
      .value("DT_BOOL", DataType::DT_BOOL)
      .value("DT_DOUBLE", DataType::DT_DOUBLE)
      .value("DT_STRING", DataType::DT_STRING)
      .value("DT_DUAL_SUB_INT8", DataType::DT_DUAL_SUB_INT8)
      .value("DT_DUAL_SUB_UINT8", DataType::DT_DUAL_SUB_UINT8)
      .value("DT_COMPLEX64", DataType::DT_COMPLEX64)
      .value("DT_COMPLEX128", DataType::DT_COMPLEX128)
      .value("DT_QINT8", DataType::DT_QINT8)
      .value("DT_QINT16", DataType::DT_QINT16)
      .value("DT_QINT32", DataType::DT_QINT32)
      .value("DT_QUINT8", DataType::DT_QUINT8)
      .value("DT_QUINT16", DataType::DT_QUINT16)
      .value("DT_RESOURCE", DataType::DT_RESOURCE)
      .value("DT_STRING_REF", DataType::DT_STRING_REF)
      .value("DT_DUAL", DataType::DT_DUAL)
      .value("DT_UNDEFINED", DataType::DT_UNDEFINED)
      .export_values();

  py::enum_<Format>(*m, "GEFormat")
      .value("FORMAT_NCHW", Format::FORMAT_NCHW)
      .value("FORMAT_NHWC", Format::FORMAT_NHWC)
      .value("FORMAT_ND", Format::FORMAT_ND)
      .value("FORMAT_NC1HWC0", Format::FORMAT_NC1HWC0)
      .value("FORMAT_FRACTAL_Z", Format::FORMAT_FRACTAL_Z)
      .value("FORMAT_NC1C0HWPAD", Format::FORMAT_NC1C0HWPAD)
      .value("FORMAT_NHWC1C0", Format::FORMAT_NHWC1C0)
      .value("FORMAT_FSR_NCHW", Format::FORMAT_FSR_NCHW)
      .value("FORMAT_FRACTAL_DECONV", Format::FORMAT_FRACTAL_DECONV)
      .value("FORMAT_C1HWNC0", Format::FORMAT_C1HWNC0)
      .value("FORMAT_FRACTAL_DECONV_TRANSPOSE",
             Format::FORMAT_FRACTAL_DECONV_TRANSPOSE)
      .value("FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS",
             Format::FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS)
      .value("FORMAT_NC1HWC0_C04", Format::FORMAT_NC1HWC0_C04)
      .value("FORMAT_FRACTAL_Z_C04", Format::FORMAT_FRACTAL_Z_C04)
      .value("FORMAT_CHWN", Format::FORMAT_CHWN)
      .value("FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS",
             Format::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS)
      .value("FORMAT_HWCN", Format::FORMAT_HWCN)
      .value("FORMAT_NC1KHKWHWC0", Format::FORMAT_NC1KHKWHWC0)
      .value("FORMAT_BN_WEIGHT", Format::FORMAT_BN_WEIGHT)
      .value("FORMAT_FILTER_HWCK", Format::FORMAT_FILTER_HWCK)
      .value("FORMAT_HASHTABLE_LOOKUP_LOOKUPS",
             Format::FORMAT_HASHTABLE_LOOKUP_LOOKUPS)
      .value("FORMAT_HASHTABLE_LOOKUP_KEYS",
             Format::FORMAT_HASHTABLE_LOOKUP_KEYS)
      .value("FORMAT_HASHTABLE_LOOKUP_VALUE",
             Format::FORMAT_HASHTABLE_LOOKUP_VALUE)
      .value("FORMAT_HASHTABLE_LOOKUP_OUTPUT",
             Format::FORMAT_HASHTABLE_LOOKUP_OUTPUT)
      .value("FORMAT_HASHTABLE_LOOKUP_HITS",
             Format::FORMAT_HASHTABLE_LOOKUP_HITS)
      .value("FORMAT_C1HWNCoC0", Format::FORMAT_C1HWNCoC0)
      .value("FORMAT_MD", Format::FORMAT_MD)
      .value("FORMAT_NDHWC", Format::FORMAT_NDHWC)
      .value("FORMAT_FRACTAL_ZZ", Format::FORMAT_FRACTAL_ZZ)
      .value("FORMAT_FRACTAL_NZ", Format::FORMAT_FRACTAL_NZ)
      .value("FORMAT_NCDHW", Format::FORMAT_NCDHW)
      .value("FORMAT_DHWCN", Format::FORMAT_DHWCN)
      .value("FORMAT_NDC1HWC0", Format::FORMAT_NDC1HWC0)
      .value("FORMAT_FRACTAL_Z_3D", Format::FORMAT_FRACTAL_Z_3D)
      .value("FORMAT_CN", Format::FORMAT_CN)
      .value("FORMAT_NC", Format::FORMAT_NC)
      .value("FORMAT_DHWNC", Format::FORMAT_DHWNC)
      .value("FORMAT_FRACTAL_Z_3D_TRANSPOSE",
             Format::FORMAT_FRACTAL_Z_3D_TRANSPOSE)
      .value("FORMAT_FRACTAL_ZN_LSTM", Format::FORMAT_FRACTAL_ZN_LSTM)
      .value("FORMAT_FRACTAL_Z_G", Format::FORMAT_FRACTAL_Z_G)
      .value("FORMAT_RESERVED", Format::FORMAT_RESERVED)
      .value("FORMAT_ALL", Format::FORMAT_ALL)
      .value("FORMAT_NULL", Format::FORMAT_NULL)
      .export_values();

  py::enum_<UnknowShapeOpType>(*m, "GEUnknowShapeOpType")
      .value("DEPEND_IN_SHAPE", UnknowShapeOpType::DEPEND_IN_SHAPE)
      .value("DEPEND_CONST_VALUE", UnknowShapeOpType::DEPEND_CONST_VALUE)
      .value("DEPEND_SHAPE_RANGE", UnknowShapeOpType::DEPEND_SHAPE_RANGE)
      .value("DEPEND_COMPUTE", UnknowShapeOpType::DEPEND_COMPUTE)
      .export_values();

  py::enum_<DeviceType>(*m, "GEDeviceType")
      .value("NPU", DeviceType::NPU)
      .value("CPU", DeviceType::CPU)
      .export_values();

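  // GEAttrType mirrors the file-local AttrType tag enum defined above; the
  // get_attr binding uses it to pick the matching typed GetAttr overload.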
  py::enum_<AttrType>(*m, "GEAttrType")
      .value("AT_INT64", AttrType::AT_INT64)
      .value("AT_INT32", AttrType::AT_INT32)
      .value("AT_UINT32", AttrType::AT_UINT32)
      .value("AT_LIST_INT64", AttrType::AT_LIST_INT64)
      .value("AT_LIST_INT32", AttrType::AT_LIST_INT32)
      .value("AT_LIST_UINT32", AttrType::AT_LIST_UINT32)
      .value("AT_FLOAT", AttrType::AT_FLOAT)
      .value("AT_LIST_FLOAT", AttrType::AT_LIST_FLOAT)
      .value("AT_ATTR_VALUE", AttrType::AT_ATTR_VALUE)
      .value("AT_STRING", AttrType::AT_STRING)
      .value("AT_LIST_STRING", AttrType::AT_LIST_STRING)
      .value("AT_BOOL", AttrType::AT_BOOL)
      .value("AT_LIST_BOOL", AttrType::AT_LIST_BOOL)
      .value("AT_TENSOR", AttrType::AT_TENSOR)
      .value("AT_LIST_TENSOR", AttrType::AT_LIST_TENSOR)
      .value("AT_LIST_UINT8", AttrType::AT_LIST_UINT8)
      .value("AT_LIST_LIST_INT64", AttrType::AT_LIST_LIST_INT64)
      .value("AT_LIST_DT", AttrType::AT_LIST_DT)
      .value("AT_DT", AttrType::AT_DT)
      .value("AT_LIST_NAMEATTR", AttrType::AT_LIST_NAMEATTR)
      .value("AT_NAMEATTR", AttrType::AT_NAMEATTR)
      .export_values();

  // class wrappers
  py::class_<ge::Session>(*m, "GESession")
      .def(py::init([](const std::map<std::string, std::string> &options) {
        return std::unique_ptr<ge::Session>(
            new ge::Session(convert_map(options)));
      }))
      .def("add_graph",
           (ge::Status (Session::*)(uint32_t, const Graph &)) &
               Session::AddGraph)
      .def("add_graph",
           [](Session &ss, uint32_t index, const Graph &graph,
              const std::map<std::string, std::string> &options) {
             return ss.AddGraph(index, graph, convert_map(options));
           })
      .def("remove_graph", &Session::RemoveGraph)
      .def("run_graph",
           [](Session &ss, uint32_t graphId,
              const std::vector<Tensor> &inputs) -> py::tuple {
             std::vector<Tensor> outputs;
             ge::Status res = ss.RunGraph(graphId, inputs, outputs);
             return py::make_tuple(outputs, res);
           },
           py::call_guard<py::gil_scoped_release>())
      .def("build_graph", &Session::BuildGraph)
      .def("run_graph_async", &Session::RunGraphAsync)
      .def("register_call_back_func",
           static_cast<ge::Status (ge::Session::*)(
               const std::string &, const ge::pCallBackFunc &)>(
               &ge::Session::RegisterCallBackFunc))
      .def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild);

  py::class_<ge::Graph>(*m, "GEGraph")
      .def(py::init<>())
      .def(py::init<const char *>())
      .def("set_inputs", &Graph::SetInputs)
      .def("set_outputs",
           (Graph & (Graph::*)(const std::vector<Operator> &)) &
               Graph::SetOutputs)
      .def("set_outputs",
           (Graph & (Graph::*)(const std::vector<
                std::pair<Operator, std::vector<size_t>>> &)) &
               Graph::SetOutputs)
      .def("set_outputs",
           (Graph & (Graph::*)(const std::vector<
                std::pair<ge::Operator, ge::AscendString>> &)) &
               Graph::SetOutputs)
      .def("set_targets", &Graph::SetTargets)
      .def("is_valid", &Graph::IsValid)
      .def("add_op", &Graph::AddOp)
      .def("find_op_by_name",
           [](Graph &graph, const char *name) -> py::tuple {
             ge::Operator op;
             graphStatus status = graph.FindOpByName(name, op);
             return py::make_tuple(op, status);
           })
      .def("find_op_by_type",
           [](Graph &graph, const char *type) -> py::tuple {
             std::vector<ge::Operator> ops;
             graphStatus status = graph.FindOpByType(type, ops);
             return py::make_tuple(ops, status);
           })
      .def("get_all_op_name",
           [](Graph &graph) -> py::tuple {
             std::vector<ge::AscendString> op_name;
             graphStatus status = graph.GetAllOpName(op_name);
             return py::make_tuple(op_name, status);
           })
      .def("save_to_file",
           static_cast<graphStatus (Graph::*)(const char *) const>(
               &ge::Graph::SaveToFile))
      .def("load_from_file",
           static_cast<graphStatus (Graph::*)(const char *)>(
               &Graph::LoadFromFile))
      .def("get_name",
           static_cast<graphStatus (Graph::*)(AscendString &) const>(
               &Graph::GetName))
      .def("set_need_iteration", &Graph::SetNeedIteration);

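  // Operator wrapper: node construction, input wiring, tensor descriptors and
  // attribute access.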
  py::class_<ge::Operator>(*m, "GEOperator")
      .def(py::init<>())
      .def(py::init<const char *>())
      .def(py::init<const char *, const char *>())
      .def("is_empty", &Operator::IsEmpty)
      .def("get_name",
           static_cast<graphStatus (Operator::*)(AscendString &) const>(
               &Operator::GetName))
      .def("get_op_type",
           static_cast<graphStatus (Operator::*)(AscendString &) const>(
               &Operator::GetOpType))
      .def("set_input",
           (Operator & (Operator::*)(const char *, const Operator &)) &
               Operator::SetInput)
      .def("set_input",
           (Operator &
            (Operator::*)(const char *, const Operator &, const char *)) &
               Operator::SetInput)
      .def("set_input",
           (Operator &
            (Operator::*)(const char *, const Operator &, uint32_t)) &
               Operator::SetInput)
      .def("add_control_input", &Operator::AddControlInput)
      .def("get_input_const_data",
           [](Operator &op, const char *dst_name) -> py::tuple {
             Tensor data;
             graphStatus res = op.GetInputConstData(dst_name, data);
             return py::make_tuple(data, res);
           })
      .def("get_input_desc",
           (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc)
      .def("get_input_desc",
           [](Operator &op, const std::string &name) {
             return op.GetInputDescByName(name.c_str());
           })
      .def("get_dynamic_output_num",
           static_cast<int (Operator::*)(const char *) const>(
               &Operator::GetDynamicOutputNum))
      .def("get_dynamic_input_num",
           static_cast<int (Operator::*)(const char *) const>(
               &Operator::GetDynamicInputNum))
      .def("try_get_input_desc",
           [](Operator &op, const char *name) -> py::tuple {
             TensorDesc tensor_desc;
             graphStatus status = op.TryGetInputDesc(name, tensor_desc);
             return py::make_tuple(tensor_desc, status);
           })
      .def("update_input_desc",
           static_cast<graphStatus (Operator::*)(const char *,
                                                 const TensorDesc &)>(
               &Operator::UpdateInputDesc))
      .def("get_output_desc",
           [](Operator &op, const std::string &name) {
             return op.GetOutputDescByName(name.c_str());
           })
      .def("get_output_desc",
           (TensorDesc (Operator::*)(uint32_t) const) &
               Operator::GetOutputDesc)
      .def("update_output_desc",
           static_cast<graphStatus (Operator::*)(const char *,
                                                 const TensorDesc &)>(
               &Operator::UpdateOutputDesc))
      .def("get_dynamic_input_desc",
           static_cast<TensorDesc (Operator::*)(const char *, uint32_t) const>(
               &Operator::GetDynamicInputDesc))
      .def("update_dynamic_input_desc",
           static_cast<graphStatus (Operator::*)(const char *, uint32_t,
                                                 const TensorDesc &)>(
               &Operator::UpdateDynamicInputDesc))
      .def("get_dynamic_output_desc",
           static_cast<TensorDesc (Operator::*)(const char *, uint32_t) const>(
               &Operator::GetDynamicOutputDesc))
      .def("update_dynamic_output_desc",
           static_cast<graphStatus (Operator::*)(const char *, uint32_t,
                                                 const TensorDesc &)>(
               &Operator::UpdateDynamicOutputDesc))
      .def("infer_shape_and_type", &Operator::InferShapeAndType)
      .def("set_inference_context", &Operator::SetInferenceContext)
      .def("get_inference_context", &Operator::GetInferenceContext)
      .def("verify_all_attr", &Operator::VerifyAllAttr)
      .def("get_inputs_size", &Operator::GetInputsSize)
      .def("get_outputs_size", &Operator::GetOutputsSize)
      .def("get_all_attr_names_and_types",
           static_cast<graphStatus (Operator::*)(
               std::map<AscendString, AscendString> &) const>(
               &Operator::GetAllAttrNamesAndTypes))
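      // Attributes are set through per-type helpers (set_attr_int64,
      // set_attr_vec_float, ...) so the GE attribute type is always chosen
      // explicitly on the Python side rather than inferred from a dynamically
      // typed Python value.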
      .def("set_attr_int64",
           [](Operator &op, const char *name, int64_t value) -> Operator & {
             int64_t tar = (int64_t)value;
             return op.SetAttr(name, tar);
           })
      .def("set_attr_int32",
           [](Operator &op, const char *name, int32_t value) -> Operator & {
             int32_t tar = (int32_t)value;
             return op.SetAttr(name, tar);
           })
      .def("set_attr_uint32",
           [](Operator &op, const char *name, uint32_t value) -> Operator & {
             uint32_t tar = (uint32_t)value;
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_int64",
           [](Operator &op, const char *name,
              const std::vector<int64_t> &value) -> Operator & {
             int len = value.size();
             std::vector<int64_t> tar;
             int64_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (int64_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_int32",
           [](Operator &op, const char *name,
              const std::vector<int32_t> &value) -> Operator & {
             int len = value.size();
             std::vector<int32_t> tar;
             int32_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (int32_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_uint32",
           [](Operator &op, const char *name,
              const std::vector<uint32_t> &value) -> Operator & {
             int len = value.size();
             std::vector<uint32_t> tar;
             uint32_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (uint32_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_list_int64",
           [](Operator &op, const char *name,
              std::initializer_list<int64_t> &attrValue) -> Operator & {
             return op.SetAttr(name, std::move(attrValue));
           })
      .def("set_attr_attrvalue",
           [](Operator &op, const char *name, AttrValue &attrValue)
               -> Operator & {
             return op.SetAttr(name, std::move(attrValue));
           })
      .def("set_attr_float",
           [](Operator &op, const char *name, float value) -> Operator & {
             float tar = static_cast<float>(value);
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_float",
           [](Operator &op, const char *name,
              const std::vector<float> &value) -> Operator & {
             int len = value.size();
             std::vector<float> tar;
             float tmp;
             for (int i = 0; i < len; i++) {
               tmp = static_cast<float>(value[i]);
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_string",
           (Operator & (Operator::*)(const char *, const char *)) &
               Operator::SetAttr)
      .def("set_attr_vec_string",
           (Operator &
            (Operator::*)(const char *, const std::vector<AscendString> &)) &
               Operator::SetAttr)
      .def("set_attr_bool",
           [](Operator &op, const char *name, bool value) -> Operator & {
             if (value)
               return op.SetAttr(name, true);
             else
               return op.SetAttr(name, false);
           })
      .def("set_attr_vec_bool",
           [](Operator &op, const char *name,
              const std::vector<bool> &value) -> Operator & {
             int len = value.size();
             std::vector<bool> tar;
             for (int i = 0; i < len; i++) {
               if (value[i])
                 tar.push_back(true);
               else
                 tar.push_back(false);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_tensor",
           (Operator & (Operator::*)(const char *, const Tensor &)) &
               Operator::SetAttr)
      .def("set_attr_vec_tensor",
           (Operator &
            (Operator::*)(const char *, const std::vector<Tensor> &)) &
               Operator::SetAttr)
      .def("set_attr_vec_uint8",
           [](Operator &op, const char *name,
              const std::vector<uint8_t> &value) -> Operator & {
             int len = value.size();
             std::vector<uint8_t> tar;
             uint8_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (uint8_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_vec_int64",
           (Operator &
            (Operator::*)(const char *,
                          const std::vector<std::vector<int64_t>> &)) &
               Operator::SetAttr)
      .def("set_attr_vec_dtype",
           [](Operator &op, const char *name,
              const std::vector<DataType> &value) -> Operator & {
             int len = value.size();
             std::vector<ge::DataType> tar;
             ge::DataType tmp;
             for (int i = 0; i < len; i++) {
               tmp = (ge::DataType)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_dtype",
           [](Operator &op, const char *name,
              const DataType &value) -> Operator & {
             ge::DataType tar = (ge::DataType)value;
             return op.SetAttr(name, tar);
           })
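      // get_attr returns a (value, status) tuple; the AttrType tag selects
      // which typed Operator::GetAttr overload is invoked.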
      .def("get_attr",
           [](Operator &op, const char *name, AttrType type) -> py::tuple {
             graphStatus res = -1;
             switch (type) {
               case AT_INT64: {
                 int64_t i_64_av;
                 res = op.GetAttr(name, i_64_av);
                 return py::make_tuple(i_64_av, res);
               } break;
               case AT_INT32: {
                 int32_t i_32_av;
                 res = op.GetAttr(name, i_32_av);
                 return py::make_tuple(i_32_av, res);
               } break;
               case AT_UINT32: {
                 uint32_t ui_32_av;
                 res = op.GetAttr(name, ui_32_av);
                 return py::make_tuple(ui_32_av, res);
               } break;
               case AT_LIST_INT64: {
                 std::vector<int64_t> v_i_64_av;
                 res = op.GetAttr(name, v_i_64_av);
                 return py::make_tuple(v_i_64_av, res);
               } break;
               case AT_LIST_INT32: {
                 std::vector<int32_t> v_i_32_av;
                 res = op.GetAttr(name, v_i_32_av);
                 return py::make_tuple(v_i_32_av, res);
               } break;
               case AT_LIST_UINT32: {
                 std::vector<uint32_t> v_ui_32_av;
                 res = op.GetAttr(name, v_ui_32_av);
                 return py::make_tuple(v_ui_32_av, res);
               } break;
               case AT_FLOAT: {
                 float f_av;
                 res = op.GetAttr(name, f_av);
                 return py::make_tuple(f_av, res);
               } break;
               case AT_LIST_FLOAT: {
                 std::vector<float> v_f_av;
                 res = op.GetAttr(name, v_f_av);
                 return py::make_tuple(v_f_av, res);
               } break;
               case AT_ATTR_VALUE: {
                 AttrValue o_av;
                 res = op.GetAttr(name, o_av);
                 return py::make_tuple(o_av, res);
               } break;
               case AT_STRING: {
                 ge::AscendString s_av;
                 res = op.GetAttr(name, s_av);
                 return py::make_tuple(s_av, res);
               } break;
               case AT_LIST_STRING: {
                 std::vector<ge::AscendString> v_s_av;
                 res = op.GetAttr(name, v_s_av);
                 return py::make_tuple(v_s_av, res);
               } break;
               case AT_BOOL: {
                 bool b_av;
                 res = op.GetAttr(name, b_av);
                 return py::make_tuple(b_av, res);
               } break;
               case AT_LIST_BOOL: {
                 std::vector<bool> v_b_av;
                 res = op.GetAttr(name, v_b_av);
                 return py::make_tuple(v_b_av, res);
               } break;
               case AT_TENSOR: {
                 Tensor t_av;
                 res = op.GetAttr(name, t_av);
                 return py::make_tuple(t_av, res);
               } break;
               case AT_LIST_TENSOR: {
                 std::vector<Tensor> v_t_av;
                 res = op.GetAttr(name, v_t_av);
                 return py::make_tuple(v_t_av, res);
               } break;
               case AT_LIST_UINT8: {
                 std::vector<uint8_t> v_ui_8_av;
                 res = op.GetAttr(name, v_ui_8_av);
                 return py::make_tuple(v_ui_8_av, res);
               } break;
               case AT_LIST_LIST_INT64: {
                 std::vector<std::vector<int64_t>> v_v_i_64_av;
                 res = op.GetAttr(name, v_v_i_64_av);
                 return py::make_tuple(v_v_i_64_av, res);
               } break;
               case AT_DT: {
                 ge::DataType dt_av;
                 res = op.GetAttr(name, dt_av);
                 return py::make_tuple(dt_av, res);
               } break;
               case AT_LIST_DT: {
                 std::vector<ge::DataType> v_dt_av;
                 res = op.GetAttr(name, v_dt_av);
                 return py::make_tuple(v_dt_av, res);
               } break;
               default:
                 return py::make_tuple(0, res);
                 break;
             }
           })
      .def("break_connect", &Operator::BreakConnect)
      .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount)
      .def("get_subgraph_names",
           static_cast<graphStatus (Operator::*)(std::vector<AscendString> &)
                           const>(&Operator::GetSubgraphNames))
      .def("get_subgraph_builder",
           static_cast<SubgraphBuilder (Operator::*)(const char *) const>(
               &Operator::GetSubgraphBuilder))
      .def("get_subgraph",
           static_cast<Graph (Operator::*)(const char *) const>(
               &Operator::GetSubgraph))
      .def("get_dynamic_subgraph_builder",
           static_cast<SubgraphBuilder (Operator::*)(const char *, uint32_t)
                           const>(&Operator::GetDynamicSubgraphBuilder))
      .def("get_dynamic_subgraph",
           static_cast<Graph (Operator::*)(const char *, uint32_t) const>(
               &Operator::GetDynamicSubgraph));

  py::class_<ge::Tensor>(*m, "GETensor")
      .def(py::init<>())
      .def(py::init<const TensorDesc &>())
      .def(py::init<const TensorDesc &, const std::vector<uint8_t> &>())
      .def(py::init<const TensorDesc &, const uint8_t *, size_t>())
      .def("set_tensor_desc", &Tensor::SetTensorDesc)
      .def("get_tensor_desc", &Tensor::GetTensorDesc)
      // .def("set_data", (graphStatus (Tensor::*)(std::vector<uint8_t> &&)) &
      //      Tensor::SetData)
      .def("set_data",
           (graphStatus (Tensor::*)(const std::vector<uint8_t> &)) &
               Tensor::SetData)
      .def("set_data",
           (graphStatus (Tensor::*)(const uint8_t *, size_t)) &
               Tensor::SetData)
      .def("set_data",
           (graphStatus (Tensor::*)(const char *)) & Tensor::SetData)
      .def("set_data",
           (graphStatus (Tensor::*)(const std::vector<AscendString> &)) &
               Tensor::SetData)
      .def("get_data",
           [](Tensor &ts) -> py::list {
             py::list v_data;
             uint8_t *data = ts.GetData();
             size_t size = ts.GetSize();
             for (size_t i = 0; i < size; ++i) {
               v_data.append(data[i]);
             }
             return v_data;
           })
      .def("get_size", &Tensor::GetSize)
      .def("is_valid", &Tensor::IsValid)
      .def("clone", &Tensor::Clone);

  py::class_<TensorDesc>(*m, "GETensorDesc")
      .def(py::init<>())
      .def(py::init<Shape, Format, DataType>(), py::arg("shape"),
           py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT)
      .def(py::init<const TensorDesc &>())
      .def("update",
           (void (TensorDesc::*)(const Shape &, Format, DataType)) &
               TensorDesc::Update,
           py::arg("shape"), py::arg("format") = FORMAT_ND,
           py::arg("dt") = DT_FLOAT)
      .def("set_shape", &TensorDesc::SetShape)
      .def("get_shape", &TensorDesc::GetShape)
      .def("set_unknown_dim_num_shape", &TensorDesc::SetUnknownDimNumShape)
      .def("set_shape_range", &TensorDesc::SetShapeRange)
      .def("get_shape_range",
           [](TensorDesc &tensorDesc) -> py::tuple {
             std::vector<std::pair<int64_t, int64_t>> range;
             graphStatus status = tensorDesc.GetShapeRange(range);
             return py::make_tuple(range, status);
           })
      .def("set_format", &TensorDesc::SetFormat)
      .def("get_format", &TensorDesc::GetFormat)
      .def("get_origin_shape", &TensorDesc::GetOriginShape)
      .def("set_origin_shape", &TensorDesc::SetOriginShape)
      .def("set_origin_format", &TensorDesc::SetOriginFormat)
      .def("get_origin_format", &TensorDesc::GetOriginFormat)
      .def("set_data_type", &TensorDesc::SetDataType)
      .def("get_data_type", &TensorDesc::GetDataType)
      .def("set_name",
           static_cast<void (TensorDesc::*)(const char *)>(
               &TensorDesc::SetName))
      .def("get_name",
           static_cast<graphStatus (TensorDesc::*)(AscendString &)>(
               &TensorDesc::GetName))
      .def("set_size", &TensorDesc::SetSize)
      .def("get_size", &TensorDesc::GetSize)
      .def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt)
      .def("get_real_dim_cnt", &TensorDesc::GetRealDimCnt);

  // Shape, AttrValue and OperatorFactory wrappers.
  py::class_<Shape>(*m, "GEShape")
      .def(py::init<>())
      .def(py::init<const std::vector<int64_t> &>())
      .def("get_dim_num", &Shape::GetDimNum)
      .def("set_dim", &Shape::SetDim)
      .def("get_dim", &Shape::GetDim)
      .def("get_dims", &Shape::GetDims)
      .def("get_shape_size", &Shape::GetShapeSize);

  py::class_<AttrValue>(*m, "GEAttrValue").def(py::init<>());

  py::class_<OperatorFactory>(*m, "GEOperatorFactory")
      .def_static("create_operator",
                  static_cast<ge::Operator (*)(const char *, const char *)>(
                      &ge::OperatorFactory::CreateOperator))
      .def("get_ops_type_list",
           []() -> py::tuple {
             std::vector<AscendString> all_ops;
             graphStatus status = OperatorFactory::GetOpsTypeList(all_ops);
             return py::make_tuple(all_ops, status);
           })
      .def_static("is_exist_op",
                  static_cast<bool (*)(const char *)>(
                      &OperatorFactory::IsExistOp));
}

}  // end namespace pybind
}  // end namespace paddle
#endif