未验证 提交 6eabbc80 编写于 作者: L Leo Chen 提交者: GitHub

fix compilation on ascend-20.1 (#30722)

fix compilation on ascend-20.1
上级 904cc443
...@@ -32,6 +32,7 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF) ...@@ -32,6 +32,7 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF) option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF) option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF) option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF)
option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF)
if (WITH_GPU AND WITH_XPU) if (WITH_GPU AND WITH_XPU)
message(FATAL_ERROR "Error when compile GPU and XPU at the same time") message(FATAL_ERROR "Error when compile GPU and XPU at the same time")
endif() endif()
...@@ -61,6 +62,10 @@ if(WITH_MUSL) ...@@ -61,6 +62,10 @@ if(WITH_MUSL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy")
endif() endif()
if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
if(WIN32) if(WIN32)
option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
......
...@@ -42,6 +42,10 @@ set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so) ...@@ -42,6 +42,10 @@ set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so) set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR}) INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
endif()
ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL) ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib}) SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})
......
...@@ -198,8 +198,13 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -198,8 +198,13 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}") "-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}")
ENDIF() ENDIF()
if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11)
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
SET(PROTOBUF_TAG v3.8.0)
else()
SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546) SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
endif()
cache_third_party(${TARGET_NAME} cache_third_party(${TARGET_NAME}
REPOSITORY ${PROTOBUF_REPOSITORY} REPOSITORY ${PROTOBUF_REPOSITORY}
...@@ -234,7 +239,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -234,7 +239,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
) )
ENDFUNCTION() ENDFUNCTION()
SET(PROTOBUF_VERSION 3.1.0) SET(PROTOBUF_VERSION 3.8.0)
IF(NOT PROTOBUF_FOUND) IF(NOT PROTOBUF_FOUND)
build_protobuf(extern_protobuf FALSE) build_protobuf(extern_protobuf FALSE)
......
...@@ -39,33 +39,37 @@ namespace framework { ...@@ -39,33 +39,37 @@ namespace framework {
typedef ge::Graph AscendGraphDesc; typedef ge::Graph AscendGraphDesc;
#ifdef PADDLE_WITH_ASCEND_STRING
using AscendString = AscendString;
#else
using AscendString = std::string;
#endif
class AscendInstance { class AscendInstance {
public: public:
virtual ~AscendInstance() {} virtual ~AscendInstance() {}
AscendInstance() {} AscendInstance() {}
std::map<ge::AscendString, ge::AscendString> GetDefaultInitOptions() { std::map<AscendString, AscendString> GetDefaultInitOptions() {
std::map<ge::AscendString, ge::AscendString> init_options; std::map<AscendString, AscendString> init_options;
init_options["ge.exec.deviceId"] = "0"; init_options["ge.exec.deviceId"] = "0";
init_options["ge.graphRunMode"] = "1"; init_options["ge.graphRunMode"] = "1";
return init_options; return init_options;
} }
std::map<ge::AscendString, ge::AscendString> GetDefaultInitSessionOptions() { std::map<AscendString, AscendString> GetDefaultInitSessionOptions() {
std::map<ge::AscendString, ge::AscendString> init_options; std::map<AscendString, AscendString> init_options;
init_options["a"] = "b"; init_options["a"] = "b";
init_options["ge.trainFlag"] = "1"; init_options["ge.trainFlag"] = "1";
return init_options; return init_options;
} }
ge::Status InitGEForUT(){ ge::Status InitGEForUT() { return ge::GEInitialize(GetDefaultInitOptions()); }
return ge::GEInitialize(GetDefaultInitOptions());
}
void InitGlobalResouces() { void InitGlobalResouces() {
LOG(INFO) << "Begin InitGlobalResouces"; LOG(INFO) << "Begin InitGlobalResouces";
session_.reset(new ge::Session(GetDefaultInitSessionOptions())); session_.reset(new ge::Session(GetDefaultInitSessionOptions()));
if (session_ == nullptr){ if (session_ == nullptr) {
LOG(FATAL) << "new session error:" << session_; LOG(FATAL) << "new session error:" << session_;
} }
LOG(INFO) << "End InitGlobalResouces"; LOG(INFO) << "End InitGlobalResouces";
...@@ -191,6 +195,6 @@ class AscendInstance { ...@@ -191,6 +195,6 @@ class AscendInstance {
private: private:
static std::shared_ptr<AscendInstance> ascend_instance_; static std::shared_ptr<AscendInstance> ascend_instance_;
}; };
} // end namespace framework } // namespace framework
} // end namespace paddle } // namespace paddle
#endif #endif
...@@ -32,9 +32,9 @@ limitations under the License. */ ...@@ -32,9 +32,9 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/fleet/ascend_wrapper.h" #include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#include "paddle/fluid/platform/ascend_npu_info.h" #include "paddle/fluid/platform/ascend_npu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
using namespace ge; // NOLINT using namespace ge; // NOLINT
namespace py = pybind11; namespace py = pybind11;
...@@ -42,6 +42,12 @@ namespace py = pybind11; ...@@ -42,6 +42,12 @@ namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
#ifdef PADDLE_WITH_ASCEND_STRING
using AscendString = AscendString;
#else
using AscendString = std::string;
#endif
void BindAscendWrapper(py::module *m) { void BindAscendWrapper(py::module *m) {
py::class_<framework::AscendInstance, py::class_<framework::AscendInstance,
std::shared_ptr<framework::AscendInstance>>(*m, "AscendInstance") std::shared_ptr<framework::AscendInstance>>(*m, "AscendInstance")
...@@ -51,24 +57,26 @@ void BindAscendWrapper(py::module *m) { ...@@ -51,24 +57,26 @@ void BindAscendWrapper(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("add_ascend_subgraph", &framework::AscendInstance::AddAscendSubgraph, .def("add_ascend_subgraph", &framework::AscendInstance::AddAscendSubgraph,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} // end AscendWrapper }
std::map<ge::AscendString, ge::AscendString> convert_map(const std::map<std::string, std::string>& options){ std::map<AscendString, AscendString> convert_map(
std::map<ge::AscendString, ge::AscendString> rets; const std::map<std::string, std::string> &options) {
std::map<AscendString, AscendString> rets;
for (auto &option : options) { for (auto &option : options) {
ge::AscendString key = option.first.c_str(); AscendString key = option.first.c_str();
ge::AscendString val = option.second.c_str(); AscendString val = option.second.c_str();
rets[key] = val; rets[key] = val;
} }
return rets; return rets;
} }
ge::Status ge_initialize(std::map<std::string, std::string> &options) { // NOLINT ge::Status ge_initialize(
std::map<std::string, std::string> &options) { // NOLINT
py::gil_scoped_release release; py::gil_scoped_release release;
auto init_options=convert_map(options); auto init_options = convert_map(options);
ge::Status res = ge::GEInitialize(init_options); ge::Status res = ge::GEInitialize(init_options);
PADDLE_ENFORCE_EQ(res, PADDLE_ENFORCE_EQ(res, ge::SUCCESS,
ge::SUCCESS, platform::errors::Fatal("ge init error:%d", res)); platform::errors::Fatal("ge init error:%d", res));
py::gil_scoped_acquire acquire; py::gil_scoped_acquire acquire;
return res; return res;
} }
...@@ -97,9 +105,10 @@ enum AttrType { ...@@ -97,9 +105,10 @@ enum AttrType {
AT_NAMEATTR AT_NAMEATTR
}; };
void BindAscendDevice(py::module* m) { void BindAscendDevice(py::module *m) {
py::class_<platform::ascend::NPUDevice>(*m, "NPUDevice") py::class_<platform::ascend::NPUDevice>(*m, "NPUDevice")
.def_static("get_device_count", .def_static(
"get_device_count",
static_cast<int (*)()>(&platform::ascend::NPUDevice::GetDeviceCount)); static_cast<int (*)()>(&platform::ascend::NPUDevice::GetDeviceCount));
} }
...@@ -107,7 +116,7 @@ void BindAscendGraph(py::module *m) { ...@@ -107,7 +116,7 @@ void BindAscendGraph(py::module *m) {
m->def("ge_initialize", &ge_initialize, "GEInitialize"); m->def("ge_initialize", &ge_initialize, "GEInitialize");
m->def("ge_finalize", &GEFinalize, "GEFinalize"); m->def("ge_finalize", &GEFinalize, "GEFinalize");
//枚举封装 // enum
py::enum_<GraphRunMode>(*m, "GEGraphRunMode") py::enum_<GraphRunMode>(*m, "GEGraphRunMode")
.value("PREDICTION", GraphRunMode::PREDICTION) .value("PREDICTION", GraphRunMode::PREDICTION)
.value("TRAIN", GraphRunMode::TRAIN) .value("TRAIN", GraphRunMode::TRAIN)
...@@ -235,14 +244,15 @@ void BindAscendGraph(py::module *m) { ...@@ -235,14 +244,15 @@ void BindAscendGraph(py::module *m) {
// 类封装 // 类封装
py::class_<Session>(*m, "GESession") py::class_<Session>(*m, "GESession")
.def(py::init([](const std::map<std::string, std::string> & options) { .def(py::init([](const std::map<std::string, std::string> &options) {
return std::unique_ptr<ge::Session>(new ge::Session(convert_map(options))); return std::unique_ptr<ge::Session>(
new ge::Session(convert_map(options)));
})) }))
.def("add_graph", (ge::Status (Session::*)(uint32_t, const Graph &)) &
Session::AddGraph)
.def("add_graph", .def("add_graph",
(ge::Status (Session::*)(uint32_t, const Graph &)) & Session::AddGraph) [](Session &ss, uint32_t index, const Graph &graph,
.def("add_graph", const std::map<std::string, std::string> &options) {
[](Session& ss, uint32_t index, const Graph & graph,
const std::map<std::string, std::string> &options){
return ss.AddGraph(index, graph, convert_map(options)); return ss.AddGraph(index, graph, convert_map(options));
}) })
.def("remove_graph", &Session::RemoveGraph) .def("remove_graph", &Session::RemoveGraph)
...@@ -256,8 +266,20 @@ void BindAscendGraph(py::module *m) { ...@@ -256,8 +266,20 @@ void BindAscendGraph(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("build_graph", &Session::BuildGraph) .def("build_graph", &Session::BuildGraph)
.def("run_graph_async", &Session::RunGraphAsync) .def("run_graph_async", &Session::RunGraphAsync)
#ifdef PADDLE_WITH_ASCEND_STRING
.def("register_call_back_func",
static_cast<ge::Status (ge::Session::*)( // NOLINT
const char *, const ge::Session::pCallBackFunc &)>(
&ge::Session::RegisterCallBackFunc))
#else
.def("register_call_back_func", .def("register_call_back_func",
static_cast<ge::Status (ge::Session::*)(const char*, const ge::session::pCallBackFunc&)>(&ge::Session::RegisterCallBackFunc)) (Status (Session::*)( // NOLINT
const std::string &,
std::function<uint32_t(
uint32_t graph_id,
const std::map<std::string, ge::Tensor> &params_list)>)) &
Session::RegisterCallBackFunc)
#endif
.def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild); .def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild);
py::class_<Graph>(*m, "GEGraph") py::class_<Graph>(*m, "GEGraph")
...@@ -272,121 +294,189 @@ void BindAscendGraph(py::module *m) { ...@@ -272,121 +294,189 @@ void BindAscendGraph(py::module *m) {
Graph::SetOutputs) Graph::SetOutputs)
.def("set_outputs", .def("set_outputs",
(Graph & (Graph &
(Graph::*)(const std::vector<std::pair<ge::Operator, ge::AscendString>> (Graph::*)(const std::vector<std::pair<ge::Operator, AscendString>>
&)) & &)) &
Graph::SetOutputs) Graph::SetOutputs)
.def("set_targets", &Graph::SetTargets) .def("set_targets", &Graph::SetTargets)
.def("is_valid", &Graph::IsValid) .def("is_valid", &Graph::IsValid)
.def("add_op", &Graph::AddOp) .def("add_op", &Graph::AddOp)
.def("find_op_by_name", .def("find_op_by_name",
[](Graph &graph, const char* name) -> py::tuple { [](Graph &graph, const char *name) -> py::tuple {
ge::Operator op; ge::Operator op;
graphStatus status = graph.FindOpByName(name, op); graphStatus status = graph.FindOpByName(name, op);
return py::make_tuple(op, status); return py::make_tuple(op, status);
}) })
.def("find_op_by_type", .def("find_op_by_type",
[](Graph &graph, const char * type) -> py::tuple { [](Graph &graph, const char *type) -> py::tuple {
std::vector<ge::Operator> ops; std::vector<ge::Operator> ops;
graphStatus status = graph.FindOpByType(type, ops); graphStatus status = graph.FindOpByType(type, ops);
return py::make_tuple(ops, status); return py::make_tuple(ops, status);
}) })
.def("get_all_op_name", .def("get_all_op_name",
[](Graph &graph) -> py::tuple { [](Graph &graph) -> py::tuple {
std::vector<ge::AscendString> op_name; std::vector<AscendString> op_name;
graphStatus status = graph.GetAllOpName(op_name); graphStatus status = graph.GetAllOpName(op_name);
return py::make_tuple(op_name, status); return py::make_tuple(op_name, status);
}) })
.def("save_to_file", static_cast<ge::graphStatus (ge::Graph::*)(const char *) const>(&ge::Graph::SaveToFile)) #ifdef PADDLE_WITH_ASCEND_STRING
.def("load_from_file", static_cast<ge::graphStatus (ge::Graph::*)(const char*)>(&Graph::LoadFromFile)) .def("save_to_file",
.def("get_name", static_cast<ge::graphStatus (ge::Graph::*)(ge::AscendString&) const>(&Graph::GetName)) static_cast<ge::graphStatus (ge::Graph::*)(const char *) const>(
&ge::Graph::SaveToFile))
.def("load_from_file",
static_cast<ge::graphStatus (ge::Graph::*)(const char *)>(
&Graph::LoadFromFile))
.def("get_name",
static_cast<ge::graphStatus (ge::Graph::*)(AscendString &) const>(
&Graph::GetName))
#else
.def("save_to_file", &Graph::SaveToFile)
.def("load_from_file", &Graph::LoadFromFile)
.def("get_name", &Graph::GetName)
#endif
.def("set_need_iteration", &Graph::SetNeedIteration); .def("set_need_iteration", &Graph::SetNeedIteration);
py::class_<Operator>(*m, "GEOperator") py::class_<Operator>(*m, "GEOperator")
.def(py::init<>()) .def(py::init<>())
.def(py::init<const char *>()) .def(py::init<const char *>())
.def(py::init<const char*, const char *>()) .def(py::init<const char *, const char *>())
.def("is_empty", &Operator::IsEmpty) .def("is_empty", &Operator::IsEmpty)
#ifdef PADDLE_WITH_ASCEND_STRING
.def("get_name", .def("get_name",
static_cast<ge::graphStatus (ge::Operator::*)(ge::AscendString&) const>(&Operator::GetName)) static_cast<ge::graphStatus (ge::Operator::*)(AscendString &) const>(
&Operator::GetName))
.def("get_op_type", .def("get_op_type",
static_cast<ge::graphStatus (ge::Operator::*)(ge::AscendString&) const>(&Operator::GetOpType)) static_cast<ge::graphStatus (ge::Operator::*)(AscendString &) const>(
&Operator::GetOpType))
.def("set_input", .def("set_input",
(Operator & (Operator::*)(const char*, const Operator &)) & (Operator & (Operator::*)(const char *, const Operator &)) &
Operator::SetInput) Operator::SetInput)
.def("set_input", .def("set_input",
(Operator & (Operator::*)(const char *, const Operator &, (Operator &
const char *)) & (Operator::*)(const char *, const Operator &, const char *)) &
Operator::SetInput) Operator::SetInput)
.def("set_input", (Operator & (Operator::*)(const char *, .def("set_input", (Operator & (Operator::*)(const char *,
const Operator &, uint32_t)) & const Operator &, uint32_t)) &
Operator::SetInput) Operator::SetInput)
#else
.def("get_name", &Operator::GetName)
.def("get_op_type", &Operator::GetOpType)
.def("set_input",
(Operator & (Operator::*)(const std::string &, const Operator &)) &
Operator::SetInput)
.def("set_input",
(Operator & (Operator::*)(const std::string &, const Operator &,
const std::string &)) &
Operator::SetInput)
.def("set_input", (Operator & (Operator::*)(const std::string &,
const Operator &, uint32_t)) &
Operator::SetInput)
#endif
.def("add_control_input", &Operator::AddControlInput) .def("add_control_input", &Operator::AddControlInput)
.def("get_input_const_data", .def("get_input_const_data",
[](Operator &op, const char* dst_name) -> py::tuple { [](Operator &op, const char *dst_name) -> py::tuple {
Tensor data; Tensor data;
graphStatus res = op.GetInputConstData(dst_name, data); graphStatus res = op.GetInputConstData(dst_name, data);
return py::make_tuple(data, res); return py::make_tuple(data, res);
}) })
#ifdef PADDLE_WITH_ASCEND_STRING
.def("get_input_desc", .def("get_input_desc",
(TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc) (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc)
.def("get_input_desc", .def("get_input_desc",
[](Operator& op, const std::string& name){ [](Operator &op, const std::string &name) {
return op.GetInputDescByName(name.c_str()); return op.GetInputDescByName(name.c_str());
}) })
.def("get_dynamic_output_num", static_cast<int (ge::Operator::*)(const char*) const>(&Operator::GetDynamicOutputNum)) .def("get_dynamic_output_num",
.def("get_dynamic_input_num", static_cast<int (ge::Operator::*)(const char*) const>(&Operator::GetDynamicInputNum)) static_cast<int (ge::Operator::*)(const char *) const>(
&Operator::GetDynamicOutputNum))
.def("get_dynamic_input_num",
static_cast<int (ge::Operator::*)(const char *) const>(
&Operator::GetDynamicInputNum))
#else
.def("get_input_desc",
(TensorDesc (Operator::*)(const std::string &) const) &
Operator::GetInputDesc)
.def("get_input_desc",
(TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc)
.def("get_dynamic_output_num", &Operator::GetDynamicOutputNum)
.def("get_dynamic_input_num", &Operator::GetDynamicInputNum)
#endif
.def("try_get_input_desc", .def("try_get_input_desc",
[](Operator &op, const char* name) -> py::tuple { [](Operator &op, const char *name) -> py::tuple {
TensorDesc tensor_desc; TensorDesc tensor_desc;
graphStatus status = op.TryGetInputDesc(name, tensor_desc); graphStatus status = op.TryGetInputDesc(name, tensor_desc);
return py::make_tuple(tensor_desc, status); return py::make_tuple(tensor_desc, status);
}) })
#ifdef PADDLE_WITH_ASCEND_STRING
.def("update_input_desc", .def("update_input_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, const TensorDesc&)>(&Operator::UpdateInputDesc)) static_cast<ge::graphStatus (ge::Operator::*)( // NOLINT
const char *, const TensorDesc &)>(&Operator::UpdateInputDesc))
.def("get_output_desc", .def("get_output_desc",
[](Operator& op, const std::string& name) { [](Operator &op, const std::string &name) {
return op.GetOutputDescByName(name.c_str()); return op.GetOutputDescByName(name.c_str());
}) })
.def("get_output_desc", .def("get_output_desc",
(TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc) (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc)
.def("update_output_desc", .def("update_output_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, const TensorDesc&)>(&Operator::UpdateOutputDesc)) static_cast<ge::graphStatus (ge::Operator::*)( // NOLINT
const char *, const TensorDesc &)>(&Operator::UpdateOutputDesc))
.def("get_dynamic_input_desc", .def("get_dynamic_input_desc",
static_cast<ge::TensorDesc (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicInputDesc)) static_cast<ge::TensorDesc (ge::Operator::*)(const char *, uint32_t)
const>(&Operator::GetDynamicInputDesc))
.def("update_dynamic_input_desc", .def("update_dynamic_input_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, uint32_t, const TensorDesc&)>(&Operator::UpdateDynamicInputDesc)) static_cast<ge::graphStatus (ge::Operator::*)(const char *, uint32_t,
const TensorDesc &)>(
&Operator::UpdateDynamicInputDesc))
.def("get_dynamic_output_desc", .def("get_dynamic_output_desc",
static_cast<ge::TensorDesc (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicOutputDesc)) static_cast<ge::TensorDesc (ge::Operator::*)(const char *, uint32_t)
const>(&Operator::GetDynamicOutputDesc))
.def("update_dynamic_output_desc", .def("update_dynamic_output_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, uint32_t, const TensorDesc&)>(&Operator::UpdateDynamicOutputDesc)) static_cast<ge::graphStatus (ge::Operator::*)(const char *, uint32_t,
const TensorDesc &)>(
&Operator::UpdateDynamicOutputDesc))
#else
.def("update_input_desc", &Operator::UpdateInputDesc)
.def("get_output_desc",
(TensorDesc (Operator::*)(const std::string &) const) &
Operator::GetOutputDesc)
.def("get_output_desc",
(TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc)
.def("update_output_desc", &Operator::UpdateOutputDesc)
.def("get_dynamic_input_desc", &Operator::GetDynamicInputDesc)
.def("update_dynamic_input_desc", &Operator::UpdateDynamicInputDesc)
.def("get_dynamic_output_desc", &Operator::GetDynamicOutputDesc)
.def("update_dynamic_output_desc", &Operator::UpdateDynamicOutputDesc)
#endif
.def("infer_shape_and_type", &Operator::InferShapeAndType) .def("infer_shape_and_type", &Operator::InferShapeAndType)
.def("set_inference_context", &Operator::SetInferenceContext) .def("set_inference_context", &Operator::SetInferenceContext)
.def("get_inference_context", &Operator::GetInferenceContext) .def("get_inference_context", &Operator::GetInferenceContext)
.def("verify_all_attr", &Operator::VerifyAllAttr) .def("verify_all_attr", &Operator::VerifyAllAttr)
.def("get_inputs_size", &Operator::GetInputsSize) .def("get_inputs_size", &Operator::GetInputsSize)
.def("get_outputs_size", &Operator::GetOutputsSize) .def("get_outputs_size", &Operator::GetOutputsSize)
#ifdef PADDLE_WITH_ASCEND_STRING
.def("get_all_attr_names_and_types", .def("get_all_attr_names_and_types",
static_cast<ge::graphStatus (ge::Operator::*)(std::map<ge::AscendString, ge::AscendString>&) const>(&Operator::GetAllAttrNamesAndTypes)) static_cast<ge::graphStatus (ge::Operator::*)( // NOLINT
std::map<AscendString, AscendString> &) const>(
&Operator::GetAllAttrNamesAndTypes))
#else
.def("get_all_attr_names_and_types", &Operator::GetAllAttrNamesAndTypes)
#endif
.def("set_attr_int64", .def("set_attr_int64",
[](Operator &op, const char* name, [](Operator &op, const char *name, int64_t value) -> Operator & {
int64_t value) -> Operator & {
int64_t tar = (int64_t)value; int64_t tar = (int64_t)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_int32", .def("set_attr_int32",
[](Operator &op, const char* name, [](Operator &op, const char *name, int32_t value) -> Operator & {
int32_t value) -> Operator & {
int32_t tar = (int32_t)value; int32_t tar = (int32_t)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_uint32", .def("set_attr_uint32",
[](Operator &op, const char* name, [](Operator &op, const char *name, uint32_t value) -> Operator & {
uint32_t value) -> Operator & {
uint32_t tar = (uint32_t)value; uint32_t tar = (uint32_t)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_int64", .def("set_attr_vec_int64",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const std::vector<int64_t> &value) -> Operator & { const std::vector<int64_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<int64_t> tar; std::vector<int64_t> tar;
...@@ -398,7 +488,7 @@ void BindAscendGraph(py::module *m) { ...@@ -398,7 +488,7 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_int32", .def("set_attr_vec_int32",
[](Operator &op, const char * name, [](Operator &op, const char *name,
const std::vector<int32_t> &value) -> Operator & { const std::vector<int32_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<int32_t> tar; std::vector<int32_t> tar;
...@@ -410,7 +500,7 @@ void BindAscendGraph(py::module *m) { ...@@ -410,7 +500,7 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_uint32", .def("set_attr_vec_uint32",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const std::vector<uint32_t> &value) -> Operator & { const std::vector<uint32_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<uint32_t> tar; std::vector<uint32_t> tar;
...@@ -422,21 +512,20 @@ void BindAscendGraph(py::module *m) { ...@@ -422,21 +512,20 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_list_int64", .def("set_attr_list_int64",
[](Operator &op, const char* name, [](Operator &op, const char *name,
std::initializer_list<int64_t> &attrValue) -> Operator & { std::initializer_list<int64_t> &attrValue) -> Operator & {
return op.SetAttr(name, std::move(attrValue)); return op.SetAttr(name, std::move(attrValue));
}) })
.def("set_attr_attrvalue", .def("set_attr_attrvalue",
[](Operator &op, const char* name, AttrValue &attrValue) [](Operator &op, const char *name, AttrValue &attrValue)
-> Operator & { return op.SetAttr(name, std::move(attrValue)); }) -> Operator & { return op.SetAttr(name, std::move(attrValue)); })
.def( .def("set_attr_float",
"set_attr_float", [](Operator &op, const char *name, float value) -> Operator & {
[](Operator &op, const char* name, float value) -> Operator & {
float tar = static_cast<float>(value); float tar = static_cast<float>(value);
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_float", .def("set_attr_vec_float",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const std::vector<float> &value) -> Operator & { const std::vector<float> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<float> tar; std::vector<float> tar;
...@@ -447,22 +536,32 @@ void BindAscendGraph(py::module *m) { ...@@ -447,22 +536,32 @@ void BindAscendGraph(py::module *m) {
} }
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_string", (Operator & (Operator::*)(const char*, #ifdef PADDLE_WITH_ASCEND_STRING
const char*)) & .def("set_attr_string",
(Operator & (Operator::*)(const char *, const char *)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_vec_string", .def("set_attr_vec_string",
(Operator & (Operator::*)(const char*, (Operator &
const std::vector<ge::AscendString> &)) & (Operator::*)(const char *, const std::vector<AscendString> &)) &
Operator::SetAttr)
#else
.def("set_attr_string", (Operator & (Operator::*)(const std::string &,
const std::string &)) &
Operator::SetAttr)
.def("set_attr_vec_string",
(Operator & (Operator::*)(const std::string &,
const std::vector<std::string> &)) &
Operator::SetAttr) Operator::SetAttr)
#endif
.def("set_attr_bool", .def("set_attr_bool",
[](Operator &op, const char* name, bool value) -> Operator & { [](Operator &op, const char *name, bool value) -> Operator & {
if (value) if (value)
return op.SetAttr(name, true); return op.SetAttr(name, true);
else else
return op.SetAttr(name, false); return op.SetAttr(name, false);
}) })
.def("set_attr_vec_bool", .def("set_attr_vec_bool",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const std::vector<bool> &value) -> Operator & { const std::vector<bool> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<bool> tar; std::vector<bool> tar;
...@@ -474,15 +573,25 @@ void BindAscendGraph(py::module *m) { ...@@ -474,15 +573,25 @@ void BindAscendGraph(py::module *m) {
} }
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
#ifdef PADDLE_WITH_ASCEND_STRING
.def("set_attr_tensor", .def("set_attr_tensor",
(Operator & (Operator::*)(const char* , const Tensor &)) & (Operator & (Operator::*)(const char *, const Tensor &)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_vec_tensor", .def("set_attr_vec_tensor",
(Operator & (Operator &
(Operator::*)(const char *, const std::vector<Tensor> &)) & (Operator::*)(const char *, const std::vector<Tensor> &)) &
Operator::SetAttr) Operator::SetAttr)
#else
.def("set_attr_tensor",
(Operator & (Operator::*)(const std::string &, const Tensor &)) &
Operator::SetAttr)
.def("set_attr_vec_tensor",
(Operator &
(Operator::*)(const std::string &, const std::vector<Tensor> &)) &
Operator::SetAttr)
#endif
.def("set_attr_vec_uint8", .def("set_attr_vec_uint8",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const std::vector<uint8_t> &value) -> Operator & { const std::vector<uint8_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<uint8_t> tar; std::vector<uint8_t> tar;
...@@ -493,13 +602,21 @@ void BindAscendGraph(py::module *m) { ...@@ -493,13 +602,21 @@ void BindAscendGraph(py::module *m) {
} }
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
#ifdef PADDLE_WITH_ASCEND_STRING
.def("set_attr_vec_vec_int64",
(Operator &
(Operator::*)(const char *,
const std::vector<std::vector<int64_t>> &)) &
Operator::SetAttr)
#else
.def("set_attr_vec_vec_int64", .def("set_attr_vec_vec_int64",
(Operator & (Operator &
(Operator::*)(const char*, (Operator::*)(const std::string &,
const std::vector<std::vector<int64_t>> &)) & const std::vector<std::vector<int64_t>> &)) &
Operator::SetAttr) Operator::SetAttr)
#endif
.def("set_attr_vec_dtype", .def("set_attr_vec_dtype",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const std::vector<DataType> &value) -> Operator & { const std::vector<DataType> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<ge::DataType> tar; std::vector<ge::DataType> tar;
...@@ -511,15 +628,13 @@ void BindAscendGraph(py::module *m) { ...@@ -511,15 +628,13 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_dtype", .def("set_attr_dtype",
[](Operator &op, const char* name, [](Operator &op, const char *name,
const DataType &value) -> Operator & { const DataType &value) -> Operator & {
ge::DataType tar = (ge::DataType)value; ge::DataType tar = (ge::DataType)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("get_attr", .def("get_attr",
[](Operator &op, const char* name, [](Operator &op, const char *name, AttrType type) -> py::tuple {
AttrType type) -> py::tuple {
graphStatus res = -1; graphStatus res = -1;
switch (type) { switch (type) {
case AT_INT64: { case AT_INT64: {
...@@ -568,12 +683,12 @@ void BindAscendGraph(py::module *m) { ...@@ -568,12 +683,12 @@ void BindAscendGraph(py::module *m) {
return py::make_tuple(o_av, res); return py::make_tuple(o_av, res);
} break; } break;
case AT_STRING: { case AT_STRING: {
ge::AscendString s_av; AscendString s_av;
res = op.GetAttr(name, s_av); res = op.GetAttr(name, s_av);
return py::make_tuple(s_av, res); return py::make_tuple(s_av, res);
} break; } break;
case AT_LIST_STRING: { case AT_LIST_STRING: {
std::vector<ge::AscendString> v_s_av; std::vector<AscendString> v_s_av;
res = op.GetAttr(name, v_s_av); res = op.GetAttr(name, v_s_av);
return py::make_tuple(v_s_av, res); return py::make_tuple(v_s_av, res);
} break; } break;
...@@ -624,11 +739,31 @@ void BindAscendGraph(py::module *m) { ...@@ -624,11 +739,31 @@ void BindAscendGraph(py::module *m) {
}) })
.def("break_connect", &Operator::BreakConnect) .def("break_connect", &Operator::BreakConnect)
.def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount) .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount)
.def("get_subgraph_names", static_cast<ge::graphStatus (ge::Operator::*)(std::vector<ge::AscendString> &) const>(&Operator::GetSubgraphNames)) #ifdef PADDLE_WITH_ASCEND_STRING
.def("get_subgraph_builder", static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char*) const>(&Operator::GetSubgraphBuilder)) .def("get_subgraph_names",
.def("get_subgraph", static_cast<ge::Graph (ge::Operator::*)(const char*) const>(&Operator::GetSubgraph)) static_cast<ge::graphStatus (ge::Operator::*)( // NOLINT
.def("get_dynamic_subgraph_builder", static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicSubgraphBuilder)) std::vector<AscendString> &) const>(&Operator::GetSubgraphNames))
.def("get_dynamic_subgraph", static_cast<ge::Graph (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicSubgraph)); .def("get_subgraph_builder",
static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char *)
const>(&Operator::GetSubgraphBuilder))
.def("get_subgraph",
static_cast<ge::Graph (ge::Operator::*)(const char *) const>(
&Operator::GetSubgraph))
.def("get_dynamic_subgraph_builder",
static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char *,
uint32_t) const>(
&Operator::GetDynamicSubgraphBuilder))
.def("get_dynamic_subgraph",
static_cast<ge::Graph (ge::Operator::*)(const char *, uint32_t)
const>(&Operator::GetDynamicSubgraph));
#else
.def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount)
.def("get_subgraph_names", &Operator::GetSubgraphNames)
.def("get_subgraph_builder", &Operator::GetSubgraphBuilder)
.def("get_subgraph", &Operator::GetSubgraph)
.def("get_dynamic_subgraph_builder", &Operator::GetDynamicSubgraphBuilder)
.def("get_dynamic_subgraph", &Operator::GetDynamicSubgraph);
#endif
py::class_<Tensor>(*m, "GETensor") py::class_<Tensor>(*m, "GETensor")
.def(py::init<>()) .def(py::init<>())
...@@ -643,10 +778,15 @@ void BindAscendGraph(py::module *m) { ...@@ -643,10 +778,15 @@ void BindAscendGraph(py::module *m) {
Tensor::SetData) Tensor::SetData)
.def("set_data", .def("set_data",
(graphStatus (Tensor::*)(const uint8_t *, size_t)) & Tensor::SetData) (graphStatus (Tensor::*)(const uint8_t *, size_t)) & Tensor::SetData)
#ifdef PADDLE_WITH_ASCEND_STRING
.def("set_data", .def("set_data",
(graphStatus (Tensor::*)(const char*)) & Tensor::SetData) (graphStatus (Tensor::*)(const char *)) & Tensor::SetData)
#else
.def("set_data", .def("set_data",
(graphStatus (Tensor::*)(const std::vector<ge::AscendString> &)) & (graphStatus (Tensor::*)(const std::string &)) & Tensor::SetData)
#endif
.def("set_data",
(graphStatus (Tensor::*)(const std::vector<AscendString> &)) &
Tensor::SetData) Tensor::SetData)
.def("get_data", .def("get_data",
...@@ -668,8 +808,8 @@ void BindAscendGraph(py::module *m) { ...@@ -668,8 +808,8 @@ void BindAscendGraph(py::module *m) {
.def(py::init<Shape, Format, DataType>(), py::arg("shape"), .def(py::init<Shape, Format, DataType>(), py::arg("shape"),
py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT) py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT)
.def(py::init<const TensorDesc &>()) .def(py::init<const TensorDesc &>())
.def("update", .def("update", (void (TensorDesc::*)(const Shape &, Format, DataType)) &
(void (TensorDesc::*)(const Shape&, Format, DataType)) & TensorDesc::Update, TensorDesc::Update,
py::arg("shape"), py::arg("format") = FORMAT_ND, py::arg("shape"), py::arg("format") = FORMAT_ND,
py::arg("dt") = DT_FLOAT) py::arg("dt") = DT_FLOAT)
.def("set_shape", &TensorDesc::SetShape) .def("set_shape", &TensorDesc::SetShape)
...@@ -690,8 +830,16 @@ void BindAscendGraph(py::module *m) { ...@@ -690,8 +830,16 @@ void BindAscendGraph(py::module *m) {
.def("get_origin_format", &TensorDesc::GetOriginFormat) .def("get_origin_format", &TensorDesc::GetOriginFormat)
.def("set_data_type", &TensorDesc::SetDataType) .def("set_data_type", &TensorDesc::SetDataType)
.def("get_data_type", &TensorDesc::GetDataType) .def("get_data_type", &TensorDesc::GetDataType)
.def("set_name", static_cast<void (ge::TensorDesc::*)(const char*)>(&TensorDesc::SetName)) #ifdef PADDLE_WITH_ASCEND_STRING
.def("get_name", static_cast<ge::graphStatus (ge::TensorDesc::*)(ge::AscendString&)>(&TensorDesc::GetName)) .def("set_name", static_cast<void (ge::TensorDesc::*)(const char *)>(
&TensorDesc::SetName))
.def("get_name",
static_cast<ge::graphStatus (ge::TensorDesc::*)(AscendString &)>(
&TensorDesc::GetName))
#else
.def("set_name", &TensorDesc::SetName)
.def("get_name", &TensorDesc::GetName)
#endif
.def("set_size", &TensorDesc::SetSize) .def("set_size", &TensorDesc::SetSize)
.def("get_size", &TensorDesc::GetSize) .def("get_size", &TensorDesc::GetSize)
.def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt) .def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt)
...@@ -709,19 +857,27 @@ void BindAscendGraph(py::module *m) { ...@@ -709,19 +857,27 @@ void BindAscendGraph(py::module *m) {
py::class_<AttrValue>(*m, "GEAttrValue").def(py::init<>()); py::class_<AttrValue>(*m, "GEAttrValue").def(py::init<>());
py::class_<OperatorFactory>(*m, "GEOperatorFactory") py::class_<OperatorFactory>(*m, "GEOperatorFactory")
#ifdef PADDLE_WITH_ASCEND_STRING
.def_static("create_operator", .def_static("create_operator",
static_cast<ge::Operator (*)(const char*, const char*)>(&ge::OperatorFactory::CreateOperator)) static_cast<ge::Operator (*)(const char *, const char *)>(
&ge::OperatorFactory::CreateOperator))
#else
.def("create_operator", &OperatorFactory::CreateOperator)
#endif
.def("get_ops_type_list", .def("get_ops_type_list",
[]() -> py::tuple { []() -> py::tuple {
std::vector<ge::AscendString> all_ops; std::vector<AscendString> all_ops;
graphStatus status = OperatorFactory::GetOpsTypeList(all_ops); graphStatus status = OperatorFactory::GetOpsTypeList(all_ops);
return py::make_tuple(all_ops, status); return py::make_tuple(all_ops, status);
}) })
.def_static("is_exist_op", #ifdef PADDLE_WITH_ASCEND_STRING
static_cast<bool (*)(const char*)>(&OperatorFactory::IsExistOp)); .def_static("is_exist_op", static_cast<bool (*)(const char *)>(
&OperatorFactory::IsExistOp));
#else
.def("is_exist_op", &OperatorFactory::IsExistOp);
#endif
} }
} // end namespace pybind } // namespace pybind
} // end namespace paddle } // namespace paddle
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册