From 6eabbc8076daba7fecc6434b61b33a4441b45b25 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 27 Jan 2021 15:28:31 +0800 Subject: [PATCH] fix compilation on ascend-20.1 (#30722) fix compilation on ascend-20.1 --- CMakeLists.txt | 5 + cmake/external/ascend.cmake | 4 + cmake/external/protobuf.cmake | 7 +- paddle/fluid/framework/fleet/ascend_wrapper.h | 46 +- paddle/fluid/pybind/ascend_wrapper_py.cc | 392 ++++++++++++------ 5 files changed, 314 insertions(+), 140 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 487aa200d7f..043a799b6a1 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,7 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF) option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF) option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF) option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF) +option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF) if (WITH_GPU AND WITH_XPU) message(FATAL_ERROR "Error when compile GPU and XPU at the same time") endif() @@ -61,6 +62,10 @@ if(WITH_MUSL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy") endif() +if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") +endif() + if(WIN32) option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) diff --git a/cmake/external/ascend.cmake b/cmake/external/ascend.cmake index 656007a5b96..a0b6f480f95 100644 --- a/cmake/external/ascend.cmake +++ b/cmake/external/ascend.cmake @@ -42,6 +42,10 @@ set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so) set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so) INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR}) +if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h) + add_definitions(-DPADDLE_WITH_ASCEND_STRING) +endif() + ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL) SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib}) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 905c17b9304..6bd188c4833 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -198,8 +198,13 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) "-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}") ENDIF() +if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11) + SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git) + SET(PROTOBUF_TAG v3.8.0) +else() SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546) +endif() cache_third_party(${TARGET_NAME} REPOSITORY ${PROTOBUF_REPOSITORY} @@ -234,7 +239,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ) ENDFUNCTION() -SET(PROTOBUF_VERSION 3.1.0) +SET(PROTOBUF_VERSION 3.8.0) IF(NOT PROTOBUF_FOUND) build_protobuf(extern_protobuf FALSE) diff --git a/paddle/fluid/framework/fleet/ascend_wrapper.h b/paddle/fluid/framework/fleet/ascend_wrapper.h index a44466ca105..912d1b1c040 100644 --- a/paddle/fluid/framework/fleet/ascend_wrapper.h +++ b/paddle/fluid/framework/fleet/ascend_wrapper.h @@ -39,36 +39,40 @@ namespace framework { typedef ge::Graph AscendGraphDesc; +#ifdef PADDLE_WITH_ASCEND_STRING +using AscendString = AscendString; +#else +using AscendString = std::string; +#endif + class AscendInstance { public: virtual ~AscendInstance() {} AscendInstance() {} - std::map GetDefaultInitOptions() { - std::map init_options; - init_options["ge.exec.deviceId"] = "0"; - init_options["ge.graphRunMode"] = "1"; - return init_options; + std::map GetDefaultInitOptions() { + std::map init_options; + init_options["ge.exec.deviceId"] = "0"; + init_options["ge.graphRunMode"] = "1"; + return init_options; } - std::map GetDefaultInitSessionOptions() { - std::map init_options; - init_options["a"] = "b"; - init_options["ge.trainFlag"] = "1"; - return init_options; + std::map GetDefaultInitSessionOptions() { + std::map init_options; + init_options["a"] = "b"; + init_options["ge.trainFlag"] = "1"; + return init_options; } - ge::Status InitGEForUT(){ - return ge::GEInitialize(GetDefaultInitOptions()); - } + ge::Status InitGEForUT() { return ge::GEInitialize(GetDefaultInitOptions()); } void InitGlobalResouces() { - LOG(INFO) << "Begin InitGlobalResouces"; - session_.reset(new ge::Session(GetDefaultInitSessionOptions())); - if (session_ == nullptr){ - LOG(FATAL) << "new session error:" << session_; - } - LOG(INFO) << "End InitGlobalResouces"; + LOG(INFO) << "Begin InitGlobalResouces"; + session_.reset(new ge::Session(GetDefaultInitSessionOptions())); + if (session_ == nullptr) { + LOG(FATAL) << "new session error:" << session_; + } + LOG(INFO) << "End InitGlobalResouces"; } static std::shared_ptr GetInstance() { @@ -191,6 +195,6 @@ class AscendInstance { private: static std::shared_ptr ascend_instance_; }; -} // end namespace framework -} // end namespace paddle +} // namespace framework +} // namespace paddle #endif diff --git a/paddle/fluid/pybind/ascend_wrapper_py.cc b/paddle/fluid/pybind/ascend_wrapper_py.cc index 492eb7fb5d3..11c90b8f90d 100644 --- a/paddle/fluid/pybind/ascend_wrapper_py.cc +++ b/paddle/fluid/pybind/ascend_wrapper_py.cc @@ -32,9 +32,9 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/fleet/ascend_wrapper.h" -#include "paddle/fluid/pybind/ascend_wrapper_py.h" #include "paddle/fluid/platform/ascend_npu_info.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/pybind/ascend_wrapper_py.h" using namespace ge; // NOLINT namespace py = pybind11; @@ -42,6 +42,12 @@ namespace py = pybind11; namespace paddle { namespace pybind { +#ifdef PADDLE_WITH_ASCEND_STRING +using AscendString = AscendString; +#else +using AscendString = std::string; +#endif + void BindAscendWrapper(py::module *m) { py::class_>(*m, "AscendInstance") @@ -51,24 +57,26 @@ void BindAscendWrapper(py::module *m) { py::call_guard()) .def("add_ascend_subgraph", &framework::AscendInstance::AddAscendSubgraph, py::call_guard()); -} // end AscendWrapper +} -std::map convert_map(const std::map& options){ - std::map rets; +std::map convert_map( + const std::map &options) { + std::map rets; for (auto &option : options) { - ge::AscendString key = option.first.c_str(); - ge::AscendString val = option.second.c_str(); + AscendString key = option.first.c_str(); + AscendString val = option.second.c_str(); rets[key] = val; } return rets; } -ge::Status ge_initialize(std::map &options) { // NOLINT +ge::Status ge_initialize( + std::map &options) { // NOLINT py::gil_scoped_release release; - auto init_options=convert_map(options); + auto init_options = convert_map(options); ge::Status res = ge::GEInitialize(init_options); - PADDLE_ENFORCE_EQ(res, - ge::SUCCESS, platform::errors::Fatal("ge init error:%d", res)); + PADDLE_ENFORCE_EQ(res, ge::SUCCESS, + platform::errors::Fatal("ge init error:%d", res)); py::gil_scoped_acquire acquire; return res; } @@ -97,17 +105,18 @@ enum AttrType { AT_NAMEATTR }; -void BindAscendDevice(py::module* m) { - py::class_(*m, "NPUDevice") - .def_static("get_device_count", - static_cast(&platform::ascend::NPUDevice::GetDeviceCount)); +void BindAscendDevice(py::module *m) { + py::class_(*m, "NPUDevice") + .def_static( + "get_device_count", + static_cast(&platform::ascend::NPUDevice::GetDeviceCount)); } void BindAscendGraph(py::module *m) { m->def("ge_initialize", &ge_initialize, "GEInitialize"); m->def("ge_finalize", &GEFinalize, "GEFinalize"); - //枚举封装 + // enum py::enum_(*m, "GEGraphRunMode") .value("PREDICTION", GraphRunMode::PREDICTION) .value("TRAIN", GraphRunMode::TRAIN) @@ -235,29 +244,42 @@ void BindAscendGraph(py::module *m) { // 类封装 py::class_(*m, "GESession") - .def(py::init([](const std::map & options) { - return std::unique_ptr(new ge::Session(convert_map(options))); - })) - .def("add_graph", - (ge::Status (Session::*)(uint32_t, const Graph &)) & Session::AddGraph) + .def(py::init([](const std::map &options) { + return std::unique_ptr( + new ge::Session(convert_map(options))); + })) + .def("add_graph", (ge::Status (Session::*)(uint32_t, const Graph &)) & + Session::AddGraph) .def("add_graph", - [](Session& ss, uint32_t index, const Graph & graph, - const std::map &options){ - return ss.AddGraph(index, graph, convert_map(options)); - }) + [](Session &ss, uint32_t index, const Graph &graph, + const std::map &options) { + return ss.AddGraph(index, graph, convert_map(options)); + }) .def("remove_graph", &Session::RemoveGraph) .def("run_graph", [](Session &ss, uint32_t graphId, const std::vector &inputs) -> py::tuple { std::vector outputs; - ge::Status res = ss.RunGraph(graphId, inputs, outputs); + ge::Status res = ss.RunGraph(graphId, inputs, outputs); return py::make_tuple(outputs, res); }, py::call_guard()) .def("build_graph", &Session::BuildGraph) .def("run_graph_async", &Session::RunGraphAsync) - .def("register_call_back_func", - static_cast(&ge::Session::RegisterCallBackFunc)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("register_call_back_func", + static_cast( + &ge::Session::RegisterCallBackFunc)) +#else + .def("register_call_back_func", + (Status (Session::*)( // NOLINT + const std::string &, + std::function ¶ms_list)>)) & + Session::RegisterCallBackFunc) +#endif .def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild); py::class_(*m, "GEGraph") @@ -272,121 +294,189 @@ void BindAscendGraph(py::module *m) { Graph::SetOutputs) .def("set_outputs", (Graph & - (Graph::*)(const std::vector> + (Graph::*)(const std::vector> &)) & Graph::SetOutputs) .def("set_targets", &Graph::SetTargets) .def("is_valid", &Graph::IsValid) .def("add_op", &Graph::AddOp) .def("find_op_by_name", - [](Graph &graph, const char* name) -> py::tuple { + [](Graph &graph, const char *name) -> py::tuple { ge::Operator op; graphStatus status = graph.FindOpByName(name, op); return py::make_tuple(op, status); }) .def("find_op_by_type", - [](Graph &graph, const char * type) -> py::tuple { + [](Graph &graph, const char *type) -> py::tuple { std::vector ops; graphStatus status = graph.FindOpByType(type, ops); return py::make_tuple(ops, status); }) .def("get_all_op_name", [](Graph &graph) -> py::tuple { - std::vector op_name; + std::vector op_name; graphStatus status = graph.GetAllOpName(op_name); return py::make_tuple(op_name, status); }) - .def("save_to_file", static_cast(&ge::Graph::SaveToFile)) - .def("load_from_file", static_cast(&Graph::LoadFromFile)) - .def("get_name", static_cast(&Graph::GetName)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("save_to_file", + static_cast( + &ge::Graph::SaveToFile)) + .def("load_from_file", + static_cast( + &Graph::LoadFromFile)) + .def("get_name", + static_cast( + &Graph::GetName)) +#else + .def("save_to_file", &Graph::SaveToFile) + .def("load_from_file", &Graph::LoadFromFile) + .def("get_name", &Graph::GetName) +#endif .def("set_need_iteration", &Graph::SetNeedIteration); py::class_(*m, "GEOperator") .def(py::init<>()) .def(py::init()) - .def(py::init()) + .def(py::init()) .def("is_empty", &Operator::IsEmpty) - .def("get_name", - static_cast(&Operator::GetName)) - .def("get_op_type", - static_cast(&Operator::GetOpType)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("get_name", + static_cast( + &Operator::GetName)) + .def("get_op_type", + static_cast( + &Operator::GetOpType)) .def("set_input", - (Operator & (Operator::*)(const char*, const Operator &)) & + (Operator & (Operator::*)(const char *, const Operator &)) & Operator::SetInput) .def("set_input", - (Operator & (Operator::*)(const char *, const Operator &, - const char *)) & + (Operator & + (Operator::*)(const char *, const Operator &, const char *)) & Operator::SetInput) .def("set_input", (Operator & (Operator::*)(const char *, const Operator &, uint32_t)) & Operator::SetInput) +#else + .def("get_name", &Operator::GetName) + .def("get_op_type", &Operator::GetOpType) + .def("set_input", + (Operator & (Operator::*)(const std::string &, const Operator &)) & + Operator::SetInput) + .def("set_input", + (Operator & (Operator::*)(const std::string &, const Operator &, + const std::string &)) & + Operator::SetInput) + .def("set_input", (Operator & (Operator::*)(const std::string &, + const Operator &, uint32_t)) & + Operator::SetInput) +#endif .def("add_control_input", &Operator::AddControlInput) .def("get_input_const_data", - [](Operator &op, const char* dst_name) -> py::tuple { + [](Operator &op, const char *dst_name) -> py::tuple { Tensor data; graphStatus res = op.GetInputConstData(dst_name, data); return py::make_tuple(data, res); }) +#ifdef PADDLE_WITH_ASCEND_STRING .def("get_input_desc", - (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc) + (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc) .def("get_input_desc", - [](Operator& op, const std::string& name){ - return op.GetInputDescByName(name.c_str()); + [](Operator &op, const std::string &name) { + return op.GetInputDescByName(name.c_str()); }) - .def("get_dynamic_output_num", static_cast(&Operator::GetDynamicOutputNum)) - .def("get_dynamic_input_num", static_cast(&Operator::GetDynamicInputNum)) + .def("get_dynamic_output_num", + static_cast( + &Operator::GetDynamicOutputNum)) + .def("get_dynamic_input_num", + static_cast( + &Operator::GetDynamicInputNum)) +#else + .def("get_input_desc", + (TensorDesc (Operator::*)(const std::string &) const) & + Operator::GetInputDesc) + .def("get_input_desc", + (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc) + .def("get_dynamic_output_num", &Operator::GetDynamicOutputNum) + .def("get_dynamic_input_num", &Operator::GetDynamicInputNum) +#endif .def("try_get_input_desc", - [](Operator &op, const char* name) -> py::tuple { + [](Operator &op, const char *name) -> py::tuple { TensorDesc tensor_desc; graphStatus status = op.TryGetInputDesc(name, tensor_desc); return py::make_tuple(tensor_desc, status); }) - .def("update_input_desc", - static_cast(&Operator::UpdateInputDesc)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("update_input_desc", + static_cast(&Operator::UpdateInputDesc)) .def("get_output_desc", - [](Operator& op, const std::string& name) { - return op.GetOutputDescByName(name.c_str()); + [](Operator &op, const std::string &name) { + return op.GetOutputDescByName(name.c_str()); }) .def("get_output_desc", (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc) - .def("update_output_desc", - static_cast(&Operator::UpdateOutputDesc)) - .def("get_dynamic_input_desc", - static_cast(&Operator::GetDynamicInputDesc)) - .def("update_dynamic_input_desc", - static_cast(&Operator::UpdateDynamicInputDesc)) - .def("get_dynamic_output_desc", - static_cast(&Operator::GetDynamicOutputDesc)) - .def("update_dynamic_output_desc", - static_cast(&Operator::UpdateDynamicOutputDesc)) + .def("update_output_desc", + static_cast(&Operator::UpdateOutputDesc)) + .def("get_dynamic_input_desc", + static_cast(&Operator::GetDynamicInputDesc)) + .def("update_dynamic_input_desc", + static_cast( + &Operator::UpdateDynamicInputDesc)) + .def("get_dynamic_output_desc", + static_cast(&Operator::GetDynamicOutputDesc)) + .def("update_dynamic_output_desc", + static_cast( + &Operator::UpdateDynamicOutputDesc)) +#else + .def("update_input_desc", &Operator::UpdateInputDesc) + .def("get_output_desc", + (TensorDesc (Operator::*)(const std::string &) const) & + Operator::GetOutputDesc) + .def("get_output_desc", + (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc) + .def("update_output_desc", &Operator::UpdateOutputDesc) + .def("get_dynamic_input_desc", &Operator::GetDynamicInputDesc) + .def("update_dynamic_input_desc", &Operator::UpdateDynamicInputDesc) + .def("get_dynamic_output_desc", &Operator::GetDynamicOutputDesc) + .def("update_dynamic_output_desc", &Operator::UpdateDynamicOutputDesc) +#endif .def("infer_shape_and_type", &Operator::InferShapeAndType) .def("set_inference_context", &Operator::SetInferenceContext) .def("get_inference_context", &Operator::GetInferenceContext) .def("verify_all_attr", &Operator::VerifyAllAttr) .def("get_inputs_size", &Operator::GetInputsSize) .def("get_outputs_size", &Operator::GetOutputsSize) - .def("get_all_attr_names_and_types", - static_cast&) const>(&Operator::GetAllAttrNamesAndTypes)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("get_all_attr_names_and_types", + static_cast &) const>( + &Operator::GetAllAttrNamesAndTypes)) +#else + .def("get_all_attr_names_and_types", &Operator::GetAllAttrNamesAndTypes) +#endif .def("set_attr_int64", - [](Operator &op, const char* name, - int64_t value) -> Operator & { + [](Operator &op, const char *name, int64_t value) -> Operator & { int64_t tar = (int64_t)value; return op.SetAttr(name, tar); }) .def("set_attr_int32", - [](Operator &op, const char* name, - int32_t value) -> Operator & { + [](Operator &op, const char *name, int32_t value) -> Operator & { int32_t tar = (int32_t)value; return op.SetAttr(name, tar); }) .def("set_attr_uint32", - [](Operator &op, const char* name, - uint32_t value) -> Operator & { + [](Operator &op, const char *name, uint32_t value) -> Operator & { uint32_t tar = (uint32_t)value; return op.SetAttr(name, tar); }) .def("set_attr_vec_int64", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -398,7 +488,7 @@ void BindAscendGraph(py::module *m) { return op.SetAttr(name, tar); }) .def("set_attr_vec_int32", - [](Operator &op, const char * name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -410,7 +500,7 @@ void BindAscendGraph(py::module *m) { return op.SetAttr(name, tar); }) .def("set_attr_vec_uint32", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -422,21 +512,20 @@ void BindAscendGraph(py::module *m) { return op.SetAttr(name, tar); }) .def("set_attr_list_int64", - [](Operator &op, const char* name, + [](Operator &op, const char *name, std::initializer_list &attrValue) -> Operator & { return op.SetAttr(name, std::move(attrValue)); }) .def("set_attr_attrvalue", - [](Operator &op, const char* name, AttrValue &attrValue) + [](Operator &op, const char *name, AttrValue &attrValue) -> Operator & { return op.SetAttr(name, std::move(attrValue)); }) - .def( - "set_attr_float", - [](Operator &op, const char* name, float value) -> Operator & { - float tar = static_cast(value); - return op.SetAttr(name, tar); - }) + .def("set_attr_float", + [](Operator &op, const char *name, float value) -> Operator & { + float tar = static_cast(value); + return op.SetAttr(name, tar); + }) .def("set_attr_vec_float", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -447,22 +536,32 @@ void BindAscendGraph(py::module *m) { } return op.SetAttr(name, tar); }) - .def("set_attr_string", (Operator & (Operator::*)(const char*, - const char*)) & +#ifdef PADDLE_WITH_ASCEND_STRING + .def("set_attr_string", + (Operator & (Operator::*)(const char *, const char *)) & + Operator::SetAttr) + .def("set_attr_vec_string", + (Operator & + (Operator::*)(const char *, const std::vector &)) & + Operator::SetAttr) +#else + .def("set_attr_string", (Operator & (Operator::*)(const std::string &, + const std::string &)) & Operator::SetAttr) .def("set_attr_vec_string", - (Operator & (Operator::*)(const char*, - const std::vector &)) & + (Operator & (Operator::*)(const std::string &, + const std::vector &)) & Operator::SetAttr) +#endif .def("set_attr_bool", - [](Operator &op, const char* name, bool value) -> Operator & { + [](Operator &op, const char *name, bool value) -> Operator & { if (value) return op.SetAttr(name, true); else return op.SetAttr(name, false); }) .def("set_attr_vec_bool", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -474,15 +573,25 @@ void BindAscendGraph(py::module *m) { } return op.SetAttr(name, tar); }) +#ifdef PADDLE_WITH_ASCEND_STRING .def("set_attr_tensor", - (Operator & (Operator::*)(const char* , const Tensor &)) & + (Operator & (Operator::*)(const char *, const Tensor &)) & Operator::SetAttr) .def("set_attr_vec_tensor", (Operator & (Operator::*)(const char *, const std::vector &)) & Operator::SetAttr) +#else + .def("set_attr_tensor", + (Operator & (Operator::*)(const std::string &, const Tensor &)) & + Operator::SetAttr) + .def("set_attr_vec_tensor", + (Operator & + (Operator::*)(const std::string &, const std::vector &)) & + Operator::SetAttr) +#endif .def("set_attr_vec_uint8", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -493,13 +602,21 @@ void BindAscendGraph(py::module *m) { } return op.SetAttr(name, tar); }) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("set_attr_vec_vec_int64", + (Operator & + (Operator::*)(const char *, + const std::vector> &)) & + Operator::SetAttr) +#else .def("set_attr_vec_vec_int64", (Operator & - (Operator::*)(const char*, + (Operator::*)(const std::string &, const std::vector> &)) & Operator::SetAttr) +#endif .def("set_attr_vec_dtype", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const std::vector &value) -> Operator & { int len = value.size(); std::vector tar; @@ -511,15 +628,13 @@ void BindAscendGraph(py::module *m) { return op.SetAttr(name, tar); }) .def("set_attr_dtype", - [](Operator &op, const char* name, + [](Operator &op, const char *name, const DataType &value) -> Operator & { ge::DataType tar = (ge::DataType)value; return op.SetAttr(name, tar); }) - .def("get_attr", - [](Operator &op, const char* name, - AttrType type) -> py::tuple { + [](Operator &op, const char *name, AttrType type) -> py::tuple { graphStatus res = -1; switch (type) { case AT_INT64: { @@ -568,12 +683,12 @@ void BindAscendGraph(py::module *m) { return py::make_tuple(o_av, res); } break; case AT_STRING: { - ge::AscendString s_av; + AscendString s_av; res = op.GetAttr(name, s_av); return py::make_tuple(s_av, res); } break; case AT_LIST_STRING: { - std::vector v_s_av; + std::vector v_s_av; res = op.GetAttr(name, v_s_av); return py::make_tuple(v_s_av, res); } break; @@ -624,11 +739,31 @@ void BindAscendGraph(py::module *m) { }) .def("break_connect", &Operator::BreakConnect) .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount) - .def("get_subgraph_names", static_cast &) const>(&Operator::GetSubgraphNames)) - .def("get_subgraph_builder", static_cast(&Operator::GetSubgraphBuilder)) - .def("get_subgraph", static_cast(&Operator::GetSubgraph)) - .def("get_dynamic_subgraph_builder", static_cast(&Operator::GetDynamicSubgraphBuilder)) - .def("get_dynamic_subgraph", static_cast(&Operator::GetDynamicSubgraph)); +#ifdef PADDLE_WITH_ASCEND_STRING + .def("get_subgraph_names", + static_cast &) const>(&Operator::GetSubgraphNames)) + .def("get_subgraph_builder", + static_cast(&Operator::GetSubgraphBuilder)) + .def("get_subgraph", + static_cast( + &Operator::GetSubgraph)) + .def("get_dynamic_subgraph_builder", + static_cast( + &Operator::GetDynamicSubgraphBuilder)) + .def("get_dynamic_subgraph", + static_cast(&Operator::GetDynamicSubgraph)); +#else + .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount) + .def("get_subgraph_names", &Operator::GetSubgraphNames) + .def("get_subgraph_builder", &Operator::GetSubgraphBuilder) + .def("get_subgraph", &Operator::GetSubgraph) + .def("get_dynamic_subgraph_builder", &Operator::GetDynamicSubgraphBuilder) + .def("get_dynamic_subgraph", &Operator::GetDynamicSubgraph); +#endif py::class_(*m, "GETensor") .def(py::init<>()) @@ -643,10 +778,15 @@ void BindAscendGraph(py::module *m) { Tensor::SetData) .def("set_data", (graphStatus (Tensor::*)(const uint8_t *, size_t)) & Tensor::SetData) +#ifdef PADDLE_WITH_ASCEND_STRING .def("set_data", - (graphStatus (Tensor::*)(const char*)) & Tensor::SetData) + (graphStatus (Tensor::*)(const char *)) & Tensor::SetData) +#else .def("set_data", - (graphStatus (Tensor::*)(const std::vector &)) & + (graphStatus (Tensor::*)(const std::string &)) & Tensor::SetData) +#endif + .def("set_data", + (graphStatus (Tensor::*)(const std::vector &)) & Tensor::SetData) .def("get_data", @@ -668,8 +808,8 @@ void BindAscendGraph(py::module *m) { .def(py::init(), py::arg("shape"), py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT) .def(py::init()) - .def("update", - (void (TensorDesc::*)(const Shape&, Format, DataType)) & TensorDesc::Update, + .def("update", (void (TensorDesc::*)(const Shape &, Format, DataType)) & + TensorDesc::Update, py::arg("shape"), py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT) .def("set_shape", &TensorDesc::SetShape) @@ -690,8 +830,16 @@ void BindAscendGraph(py::module *m) { .def("get_origin_format", &TensorDesc::GetOriginFormat) .def("set_data_type", &TensorDesc::SetDataType) .def("get_data_type", &TensorDesc::GetDataType) - .def("set_name", static_cast(&TensorDesc::SetName)) - .def("get_name", static_cast(&TensorDesc::GetName)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def("set_name", static_cast( + &TensorDesc::SetName)) + .def("get_name", + static_cast( + &TensorDesc::GetName)) +#else + .def("set_name", &TensorDesc::SetName) + .def("get_name", &TensorDesc::GetName) +#endif .def("set_size", &TensorDesc::SetSize) .def("get_size", &TensorDesc::GetSize) .def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt) @@ -709,19 +857,27 @@ void BindAscendGraph(py::module *m) { py::class_(*m, "GEAttrValue").def(py::init<>()); py::class_(*m, "GEOperatorFactory") - .def_static("create_operator", - static_cast(&ge::OperatorFactory::CreateOperator)) +#ifdef PADDLE_WITH_ASCEND_STRING + .def_static("create_operator", + static_cast( + &ge::OperatorFactory::CreateOperator)) +#else + .def("create_operator", &OperatorFactory::CreateOperator) +#endif .def("get_ops_type_list", []() -> py::tuple { - std::vector all_ops; + std::vector all_ops; graphStatus status = OperatorFactory::GetOpsTypeList(all_ops); return py::make_tuple(all_ops, status); }) - .def_static("is_exist_op", - static_cast(&OperatorFactory::IsExistOp)); - +#ifdef PADDLE_WITH_ASCEND_STRING + .def_static("is_exist_op", static_cast( + &OperatorFactory::IsExistOp)); +#else + .def("is_exist_op", &OperatorFactory::IsExistOp); +#endif } -} // end namespace pybind -} // end namespace paddle +} // namespace pybind +} // namespace paddle #endif -- GitLab