Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6eabbc80
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6eabbc80
编写于
1月 27, 2021
作者:
L
Leo Chen
提交者:
GitHub
1月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix compilation on ascend-20.1 (#30722)
fix compilation on ascend-20.1
上级
904cc443
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
314 addition
and
140 deletion
+314
-140
CMakeLists.txt
CMakeLists.txt
+5
-0
cmake/external/ascend.cmake
cmake/external/ascend.cmake
+4
-0
cmake/external/protobuf.cmake
cmake/external/protobuf.cmake
+6
-1
paddle/fluid/framework/fleet/ascend_wrapper.h
paddle/fluid/framework/fleet/ascend_wrapper.h
+25
-21
paddle/fluid/pybind/ascend_wrapper_py.cc
paddle/fluid/pybind/ascend_wrapper_py.cc
+274
-118
未找到文件。
CMakeLists.txt
浏览文件 @
6eabbc80
...
@@ -32,6 +32,7 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
...
@@ -32,6 +32,7 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option
(
WITH_XPU
"Compile PaddlePaddle with BAIDU KUNLUN XPU"
OFF
)
option
(
WITH_XPU
"Compile PaddlePaddle with BAIDU KUNLUN XPU"
OFF
)
option
(
WITH_WIN_DUMP_DBG
"Compile with windows core dump debug mode"
OFF
)
option
(
WITH_WIN_DUMP_DBG
"Compile with windows core dump debug mode"
OFF
)
option
(
WITH_ASCEND
"Compile PaddlePaddle with ASCEND"
OFF
)
option
(
WITH_ASCEND
"Compile PaddlePaddle with ASCEND"
OFF
)
option
(
WITH_ASCEND_CXX11
"Compile PaddlePaddle with ASCEND and CXX11 ABI"
OFF
)
if
(
WITH_GPU AND WITH_XPU
)
if
(
WITH_GPU AND WITH_XPU
)
message
(
FATAL_ERROR
"Error when compile GPU and XPU at the same time"
)
message
(
FATAL_ERROR
"Error when compile GPU and XPU at the same time"
)
endif
()
endif
()
...
@@ -61,6 +62,10 @@ if(WITH_MUSL)
...
@@ -61,6 +62,10 @@ if(WITH_MUSL)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy"
)
endif
()
endif
()
if
(
WITH_ASCEND AND NOT WITH_ASCEND_CXX11
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0"
)
endif
()
if
(
WIN32
)
if
(
WIN32
)
option
(
MSVC_STATIC_CRT
"use static C Runtime library by default"
ON
)
option
(
MSVC_STATIC_CRT
"use static C Runtime library by default"
ON
)
...
...
cmake/external/ascend.cmake
浏览文件 @
6eabbc80
...
@@ -42,6 +42,10 @@ set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
...
@@ -42,6 +42,10 @@ set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
set
(
atlas_acl_lib
${
ATLAS_RUNTIME_DIR
}
/libascendcl.so
)
set
(
atlas_acl_lib
${
ATLAS_RUNTIME_DIR
}
/libascendcl.so
)
INCLUDE_DIRECTORIES
(
${
ATLAS_RUNTIME_INC_DIR
}
)
INCLUDE_DIRECTORIES
(
${
ATLAS_RUNTIME_INC_DIR
}
)
if
(
EXISTS
${
ATLAS_RUNTIME_INC_DIR
}
/graph/ascend_string.h
)
add_definitions
(
-DPADDLE_WITH_ASCEND_STRING
)
endif
()
ADD_LIBRARY
(
ascend_ge SHARED IMPORTED GLOBAL
)
ADD_LIBRARY
(
ascend_ge SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ascend_ge PROPERTY IMPORTED_LOCATION
${
atlas_ge_runner_lib
}
)
SET_PROPERTY
(
TARGET ascend_ge PROPERTY IMPORTED_LOCATION
${
atlas_ge_runner_lib
}
)
...
...
cmake/external/protobuf.cmake
浏览文件 @
6eabbc80
...
@@ -198,8 +198,13 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
...
@@ -198,8 +198,13 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-Dprotobuf_MSVC_STATIC_RUNTIME=
${
MSVC_STATIC_CRT
}
"
)
"-Dprotobuf_MSVC_STATIC_RUNTIME=
${
MSVC_STATIC_CRT
}
"
)
ENDIF
()
ENDIF
()
if
(
WITH_ASCEND AND NOT WITH_ASCEND_CXX11
)
SET
(
PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git
)
SET
(
PROTOBUF_TAG v3.8.0
)
else
()
SET
(
PROTOBUF_REPOSITORY
${
GIT_URL
}
/protocolbuffers/protobuf.git
)
SET
(
PROTOBUF_REPOSITORY
${
GIT_URL
}
/protocolbuffers/protobuf.git
)
SET
(
PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546
)
SET
(
PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546
)
endif
()
cache_third_party
(
${
TARGET_NAME
}
cache_third_party
(
${
TARGET_NAME
}
REPOSITORY
${
PROTOBUF_REPOSITORY
}
REPOSITORY
${
PROTOBUF_REPOSITORY
}
...
@@ -234,7 +239,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
...
@@ -234,7 +239,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
)
)
ENDFUNCTION
()
ENDFUNCTION
()
SET
(
PROTOBUF_VERSION 3.
1
.0
)
SET
(
PROTOBUF_VERSION 3.
8
.0
)
IF
(
NOT PROTOBUF_FOUND
)
IF
(
NOT PROTOBUF_FOUND
)
build_protobuf
(
extern_protobuf FALSE
)
build_protobuf
(
extern_protobuf FALSE
)
...
...
paddle/fluid/framework/fleet/ascend_wrapper.h
浏览文件 @
6eabbc80
...
@@ -39,33 +39,37 @@ namespace framework {
...
@@ -39,33 +39,37 @@ namespace framework {
typedef
ge
::
Graph
AscendGraphDesc
;
typedef
ge
::
Graph
AscendGraphDesc
;
#ifdef PADDLE_WITH_ASCEND_STRING
using
AscendString
=
AscendString
;
#else
using
AscendString
=
std
::
string
;
#endif
class
AscendInstance
{
class
AscendInstance
{
public:
public:
virtual
~
AscendInstance
()
{}
virtual
~
AscendInstance
()
{}
AscendInstance
()
{}
AscendInstance
()
{}
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
GetDefaultInitOptions
()
{
std
::
map
<
AscendString
,
AscendString
>
GetDefaultInitOptions
()
{
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
init_options
;
std
::
map
<
AscendString
,
AscendString
>
init_options
;
init_options
[
"ge.exec.deviceId"
]
=
"0"
;
init_options
[
"ge.exec.deviceId"
]
=
"0"
;
init_options
[
"ge.graphRunMode"
]
=
"1"
;
init_options
[
"ge.graphRunMode"
]
=
"1"
;
return
init_options
;
return
init_options
;
}
}
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
GetDefaultInitSessionOptions
()
{
std
::
map
<
AscendString
,
AscendString
>
GetDefaultInitSessionOptions
()
{
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
init_options
;
std
::
map
<
AscendString
,
AscendString
>
init_options
;
init_options
[
"a"
]
=
"b"
;
init_options
[
"a"
]
=
"b"
;
init_options
[
"ge.trainFlag"
]
=
"1"
;
init_options
[
"ge.trainFlag"
]
=
"1"
;
return
init_options
;
return
init_options
;
}
}
ge
::
Status
InitGEForUT
(){
ge
::
Status
InitGEForUT
()
{
return
ge
::
GEInitialize
(
GetDefaultInitOptions
());
}
return
ge
::
GEInitialize
(
GetDefaultInitOptions
());
}
void
InitGlobalResouces
()
{
void
InitGlobalResouces
()
{
LOG
(
INFO
)
<<
"Begin InitGlobalResouces"
;
LOG
(
INFO
)
<<
"Begin InitGlobalResouces"
;
session_
.
reset
(
new
ge
::
Session
(
GetDefaultInitSessionOptions
()));
session_
.
reset
(
new
ge
::
Session
(
GetDefaultInitSessionOptions
()));
if
(
session_
==
nullptr
)
{
if
(
session_
==
nullptr
)
{
LOG
(
FATAL
)
<<
"new session error:"
<<
session_
;
LOG
(
FATAL
)
<<
"new session error:"
<<
session_
;
}
}
LOG
(
INFO
)
<<
"End InitGlobalResouces"
;
LOG
(
INFO
)
<<
"End InitGlobalResouces"
;
...
@@ -191,6 +195,6 @@ class AscendInstance {
...
@@ -191,6 +195,6 @@ class AscendInstance {
private:
private:
static
std
::
shared_ptr
<
AscendInstance
>
ascend_instance_
;
static
std
::
shared_ptr
<
AscendInstance
>
ascend_instance_
;
};
};
}
//
end
namespace framework
}
// namespace framework
}
//
end
namespace paddle
}
// namespace paddle
#endif
#endif
paddle/fluid/pybind/ascend_wrapper_py.cc
浏览文件 @
6eabbc80
...
@@ -32,9 +32,9 @@ limitations under the License. */
...
@@ -32,9 +32,9 @@ limitations under the License. */
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#include "paddle/fluid/platform/ascend_npu_info.h"
#include "paddle/fluid/platform/ascend_npu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
using
namespace
ge
;
// NOLINT
using
namespace
ge
;
// NOLINT
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
...
@@ -42,6 +42,12 @@ namespace py = pybind11;
...
@@ -42,6 +42,12 @@ namespace py = pybind11;
namespace
paddle
{
namespace
paddle
{
namespace
pybind
{
namespace
pybind
{
#ifdef PADDLE_WITH_ASCEND_STRING
using
AscendString
=
AscendString
;
#else
using
AscendString
=
std
::
string
;
#endif
void
BindAscendWrapper
(
py
::
module
*
m
)
{
void
BindAscendWrapper
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
AscendInstance
,
py
::
class_
<
framework
::
AscendInstance
,
std
::
shared_ptr
<
framework
::
AscendInstance
>>
(
*
m
,
"AscendInstance"
)
std
::
shared_ptr
<
framework
::
AscendInstance
>>
(
*
m
,
"AscendInstance"
)
...
@@ -51,24 +57,26 @@ void BindAscendWrapper(py::module *m) {
...
@@ -51,24 +57,26 @@ void BindAscendWrapper(py::module *m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"add_ascend_subgraph"
,
&
framework
::
AscendInstance
::
AddAscendSubgraph
,
.
def
(
"add_ascend_subgraph"
,
&
framework
::
AscendInstance
::
AddAscendSubgraph
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
// end AscendWrapper
}
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
convert_map
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
options
){
std
::
map
<
AscendString
,
AscendString
>
convert_map
(
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
rets
;
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
std
::
map
<
AscendString
,
AscendString
>
rets
;
for
(
auto
&
option
:
options
)
{
for
(
auto
&
option
:
options
)
{
ge
::
AscendString
key
=
option
.
first
.
c_str
();
AscendString
key
=
option
.
first
.
c_str
();
ge
::
AscendString
val
=
option
.
second
.
c_str
();
AscendString
val
=
option
.
second
.
c_str
();
rets
[
key
]
=
val
;
rets
[
key
]
=
val
;
}
}
return
rets
;
return
rets
;
}
}
ge
::
Status
ge_initialize
(
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
// NOLINT
ge
::
Status
ge_initialize
(
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
// NOLINT
py
::
gil_scoped_release
release
;
py
::
gil_scoped_release
release
;
auto
init_options
=
convert_map
(
options
);
auto
init_options
=
convert_map
(
options
);
ge
::
Status
res
=
ge
::
GEInitialize
(
init_options
);
ge
::
Status
res
=
ge
::
GEInitialize
(
init_options
);
PADDLE_ENFORCE_EQ
(
res
,
PADDLE_ENFORCE_EQ
(
res
,
ge
::
SUCCESS
,
ge
::
SUCCESS
,
platform
::
errors
::
Fatal
(
"ge init error:%d"
,
res
));
platform
::
errors
::
Fatal
(
"ge init error:%d"
,
res
));
py
::
gil_scoped_acquire
acquire
;
py
::
gil_scoped_acquire
acquire
;
return
res
;
return
res
;
}
}
...
@@ -97,9 +105,10 @@ enum AttrType {
...
@@ -97,9 +105,10 @@ enum AttrType {
AT_NAMEATTR
AT_NAMEATTR
};
};
void
BindAscendDevice
(
py
::
module
*
m
)
{
void
BindAscendDevice
(
py
::
module
*
m
)
{
py
::
class_
<
platform
::
ascend
::
NPUDevice
>
(
*
m
,
"NPUDevice"
)
py
::
class_
<
platform
::
ascend
::
NPUDevice
>
(
*
m
,
"NPUDevice"
)
.
def_static
(
"get_device_count"
,
.
def_static
(
"get_device_count"
,
static_cast
<
int
(
*
)()
>
(
&
platform
::
ascend
::
NPUDevice
::
GetDeviceCount
));
static_cast
<
int
(
*
)()
>
(
&
platform
::
ascend
::
NPUDevice
::
GetDeviceCount
));
}
}
...
@@ -107,7 +116,7 @@ void BindAscendGraph(py::module *m) {
...
@@ -107,7 +116,7 @@ void BindAscendGraph(py::module *m) {
m
->
def
(
"ge_initialize"
,
&
ge_initialize
,
"GEInitialize"
);
m
->
def
(
"ge_initialize"
,
&
ge_initialize
,
"GEInitialize"
);
m
->
def
(
"ge_finalize"
,
&
GEFinalize
,
"GEFinalize"
);
m
->
def
(
"ge_finalize"
,
&
GEFinalize
,
"GEFinalize"
);
//
枚举封装
//
enum
py
::
enum_
<
GraphRunMode
>
(
*
m
,
"GEGraphRunMode"
)
py
::
enum_
<
GraphRunMode
>
(
*
m
,
"GEGraphRunMode"
)
.
value
(
"PREDICTION"
,
GraphRunMode
::
PREDICTION
)
.
value
(
"PREDICTION"
,
GraphRunMode
::
PREDICTION
)
.
value
(
"TRAIN"
,
GraphRunMode
::
TRAIN
)
.
value
(
"TRAIN"
,
GraphRunMode
::
TRAIN
)
...
@@ -235,14 +244,15 @@ void BindAscendGraph(py::module *m) {
...
@@ -235,14 +244,15 @@ void BindAscendGraph(py::module *m) {
// 类封装
// 类封装
py
::
class_
<
Session
>
(
*
m
,
"GESession"
)
py
::
class_
<
Session
>
(
*
m
,
"GESession"
)
.
def
(
py
::
init
([](
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
.
def
(
py
::
init
([](
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
return
std
::
unique_ptr
<
ge
::
Session
>
(
new
ge
::
Session
(
convert_map
(
options
)));
return
std
::
unique_ptr
<
ge
::
Session
>
(
new
ge
::
Session
(
convert_map
(
options
)));
}))
}))
.
def
(
"add_graph"
,
(
ge
::
Status
(
Session
::*
)(
uint32_t
,
const
Graph
&
))
&
Session
::
AddGraph
)
.
def
(
"add_graph"
,
.
def
(
"add_graph"
,
(
ge
::
Status
(
Session
::*
)(
uint32_t
,
const
Graph
&
))
&
Session
::
AddGraph
)
[](
Session
&
ss
,
uint32_t
index
,
const
Graph
&
graph
,
.
def
(
"add_graph"
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
[](
Session
&
ss
,
uint32_t
index
,
const
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
){
return
ss
.
AddGraph
(
index
,
graph
,
convert_map
(
options
));
return
ss
.
AddGraph
(
index
,
graph
,
convert_map
(
options
));
})
})
.
def
(
"remove_graph"
,
&
Session
::
RemoveGraph
)
.
def
(
"remove_graph"
,
&
Session
::
RemoveGraph
)
...
@@ -256,8 +266,20 @@ void BindAscendGraph(py::module *m) {
...
@@ -256,8 +266,20 @@ void BindAscendGraph(py::module *m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"build_graph"
,
&
Session
::
BuildGraph
)
.
def
(
"build_graph"
,
&
Session
::
BuildGraph
)
.
def
(
"run_graph_async"
,
&
Session
::
RunGraphAsync
)
.
def
(
"run_graph_async"
,
&
Session
::
RunGraphAsync
)
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"register_call_back_func"
,
static_cast
<
ge
::
Status
(
ge
::
Session
::*
)(
// NOLINT
const
char
*
,
const
ge
::
Session
::
pCallBackFunc
&
)
>
(
&
ge
::
Session
::
RegisterCallBackFunc
))
#else
.
def
(
"register_call_back_func"
,
.
def
(
"register_call_back_func"
,
static_cast
<
ge
::
Status
(
ge
::
Session
::*
)(
const
char
*
,
const
ge
::
session
::
pCallBackFunc
&
)
>
(
&
ge
::
Session
::
RegisterCallBackFunc
))
(
Status
(
Session
::*
)(
// NOLINT
const
std
::
string
&
,
std
::
function
<
uint32_t
(
uint32_t
graph_id
,
const
std
::
map
<
std
::
string
,
ge
::
Tensor
>
&
params_list
)
>
))
&
Session
::
RegisterCallBackFunc
)
#endif
.
def
(
"is_graph_need_rebuild"
,
&
Session
::
IsGraphNeedRebuild
);
.
def
(
"is_graph_need_rebuild"
,
&
Session
::
IsGraphNeedRebuild
);
py
::
class_
<
Graph
>
(
*
m
,
"GEGraph"
)
py
::
class_
<
Graph
>
(
*
m
,
"GEGraph"
)
...
@@ -272,121 +294,189 @@ void BindAscendGraph(py::module *m) {
...
@@ -272,121 +294,189 @@ void BindAscendGraph(py::module *m) {
Graph
::
SetOutputs
)
Graph
::
SetOutputs
)
.
def
(
"set_outputs"
,
.
def
(
"set_outputs"
,
(
Graph
&
(
Graph
&
(
Graph
::*
)(
const
std
::
vector
<
std
::
pair
<
ge
::
Operator
,
ge
::
AscendString
>>
(
Graph
::*
)(
const
std
::
vector
<
std
::
pair
<
ge
::
Operator
,
AscendString
>>
&
))
&
&
))
&
Graph
::
SetOutputs
)
Graph
::
SetOutputs
)
.
def
(
"set_targets"
,
&
Graph
::
SetTargets
)
.
def
(
"set_targets"
,
&
Graph
::
SetTargets
)
.
def
(
"is_valid"
,
&
Graph
::
IsValid
)
.
def
(
"is_valid"
,
&
Graph
::
IsValid
)
.
def
(
"add_op"
,
&
Graph
::
AddOp
)
.
def
(
"add_op"
,
&
Graph
::
AddOp
)
.
def
(
"find_op_by_name"
,
.
def
(
"find_op_by_name"
,
[](
Graph
&
graph
,
const
char
*
name
)
->
py
::
tuple
{
[](
Graph
&
graph
,
const
char
*
name
)
->
py
::
tuple
{
ge
::
Operator
op
;
ge
::
Operator
op
;
graphStatus
status
=
graph
.
FindOpByName
(
name
,
op
);
graphStatus
status
=
graph
.
FindOpByName
(
name
,
op
);
return
py
::
make_tuple
(
op
,
status
);
return
py
::
make_tuple
(
op
,
status
);
})
})
.
def
(
"find_op_by_type"
,
.
def
(
"find_op_by_type"
,
[](
Graph
&
graph
,
const
char
*
type
)
->
py
::
tuple
{
[](
Graph
&
graph
,
const
char
*
type
)
->
py
::
tuple
{
std
::
vector
<
ge
::
Operator
>
ops
;
std
::
vector
<
ge
::
Operator
>
ops
;
graphStatus
status
=
graph
.
FindOpByType
(
type
,
ops
);
graphStatus
status
=
graph
.
FindOpByType
(
type
,
ops
);
return
py
::
make_tuple
(
ops
,
status
);
return
py
::
make_tuple
(
ops
,
status
);
})
})
.
def
(
"get_all_op_name"
,
.
def
(
"get_all_op_name"
,
[](
Graph
&
graph
)
->
py
::
tuple
{
[](
Graph
&
graph
)
->
py
::
tuple
{
std
::
vector
<
ge
::
AscendString
>
op_name
;
std
::
vector
<
AscendString
>
op_name
;
graphStatus
status
=
graph
.
GetAllOpName
(
op_name
);
graphStatus
status
=
graph
.
GetAllOpName
(
op_name
);
return
py
::
make_tuple
(
op_name
,
status
);
return
py
::
make_tuple
(
op_name
,
status
);
})
})
.
def
(
"save_to_file"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
const
char
*
)
const
>
(
&
ge
::
Graph
::
SaveToFile
))
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"load_from_file"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
const
char
*
)
>
(
&
Graph
::
LoadFromFile
))
.
def
(
"save_to_file"
,
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
ge
::
AscendString
&
)
const
>
(
&
Graph
::
GetName
))
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
const
char
*
)
const
>
(
&
ge
::
Graph
::
SaveToFile
))
.
def
(
"load_from_file"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
const
char
*
)
>
(
&
Graph
::
LoadFromFile
))
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
AscendString
&
)
const
>
(
&
Graph
::
GetName
))
#else
.
def
(
"save_to_file"
,
&
Graph
::
SaveToFile
)
.
def
(
"load_from_file"
,
&
Graph
::
LoadFromFile
)
.
def
(
"get_name"
,
&
Graph
::
GetName
)
#endif
.
def
(
"set_need_iteration"
,
&
Graph
::
SetNeedIteration
);
.
def
(
"set_need_iteration"
,
&
Graph
::
SetNeedIteration
);
py
::
class_
<
Operator
>
(
*
m
,
"GEOperator"
)
py
::
class_
<
Operator
>
(
*
m
,
"GEOperator"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
const
char
*>
())
.
def
(
py
::
init
<
const
char
*>
())
.
def
(
py
::
init
<
const
char
*
,
const
char
*>
())
.
def
(
py
::
init
<
const
char
*
,
const
char
*>
())
.
def
(
"is_empty"
,
&
Operator
::
IsEmpty
)
.
def
(
"is_empty"
,
&
Operator
::
IsEmpty
)
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"get_name"
,
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
ge
::
AscendString
&
)
const
>
(
&
Operator
::
GetName
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
AscendString
&
)
const
>
(
&
Operator
::
GetName
))
.
def
(
"get_op_type"
,
.
def
(
"get_op_type"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
ge
::
AscendString
&
)
const
>
(
&
Operator
::
GetOpType
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
AscendString
&
)
const
>
(
&
Operator
::
GetOpType
))
.
def
(
"set_input"
,
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
))
&
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
))
&
Operator
::
SetInput
)
Operator
::
SetInput
)
.
def
(
"set_input"
,
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
,
(
Operator
&
const
char
*
))
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
,
const
char
*
))
&
Operator
::
SetInput
)
Operator
::
SetInput
)
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
,
uint32_t
))
&
const
Operator
&
,
uint32_t
))
&
Operator
::
SetInput
)
Operator
::
SetInput
)
#else
.
def
(
"get_name"
,
&
Operator
::
GetName
)
.
def
(
"get_op_type"
,
&
Operator
::
GetOpType
)
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Operator
&
))
&
Operator
::
SetInput
)
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Operator
&
,
const
std
::
string
&
))
&
Operator
::
SetInput
)
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Operator
&
,
uint32_t
))
&
Operator
::
SetInput
)
#endif
.
def
(
"add_control_input"
,
&
Operator
::
AddControlInput
)
.
def
(
"add_control_input"
,
&
Operator
::
AddControlInput
)
.
def
(
"get_input_const_data"
,
.
def
(
"get_input_const_data"
,
[](
Operator
&
op
,
const
char
*
dst_name
)
->
py
::
tuple
{
[](
Operator
&
op
,
const
char
*
dst_name
)
->
py
::
tuple
{
Tensor
data
;
Tensor
data
;
graphStatus
res
=
op
.
GetInputConstData
(
dst_name
,
data
);
graphStatus
res
=
op
.
GetInputConstData
(
dst_name
,
data
);
return
py
::
make_tuple
(
data
,
res
);
return
py
::
make_tuple
(
data
,
res
);
})
})
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"get_input_desc"
,
.
def
(
"get_input_desc"
,
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetInputDesc
)
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetInputDesc
)
.
def
(
"get_input_desc"
,
.
def
(
"get_input_desc"
,
[](
Operator
&
op
,
const
std
::
string
&
name
)
{
[](
Operator
&
op
,
const
std
::
string
&
name
)
{
return
op
.
GetInputDescByName
(
name
.
c_str
());
return
op
.
GetInputDescByName
(
name
.
c_str
());
})
})
.
def
(
"get_dynamic_output_num"
,
static_cast
<
int
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetDynamicOutputNum
))
.
def
(
"get_dynamic_output_num"
,
.
def
(
"get_dynamic_input_num"
,
static_cast
<
int
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetDynamicInputNum
))
static_cast
<
int
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetDynamicOutputNum
))
.
def
(
"get_dynamic_input_num"
,
static_cast
<
int
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetDynamicInputNum
))
#else
.
def
(
"get_input_desc"
,
(
TensorDesc
(
Operator
::*
)(
const
std
::
string
&
)
const
)
&
Operator
::
GetInputDesc
)
.
def
(
"get_input_desc"
,
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetInputDesc
)
.
def
(
"get_dynamic_output_num"
,
&
Operator
::
GetDynamicOutputNum
)
.
def
(
"get_dynamic_input_num"
,
&
Operator
::
GetDynamicInputNum
)
#endif
.
def
(
"try_get_input_desc"
,
.
def
(
"try_get_input_desc"
,
[](
Operator
&
op
,
const
char
*
name
)
->
py
::
tuple
{
[](
Operator
&
op
,
const
char
*
name
)
->
py
::
tuple
{
TensorDesc
tensor_desc
;
TensorDesc
tensor_desc
;
graphStatus
status
=
op
.
TryGetInputDesc
(
name
,
tensor_desc
);
graphStatus
status
=
op
.
TryGetInputDesc
(
name
,
tensor_desc
);
return
py
::
make_tuple
(
tensor_desc
,
status
);
return
py
::
make_tuple
(
tensor_desc
,
status
);
})
})
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"update_input_desc"
,
.
def
(
"update_input_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateInputDesc
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
// NOLINT
const
char
*
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateInputDesc
))
.
def
(
"get_output_desc"
,
.
def
(
"get_output_desc"
,
[](
Operator
&
op
,
const
std
::
string
&
name
)
{
[](
Operator
&
op
,
const
std
::
string
&
name
)
{
return
op
.
GetOutputDescByName
(
name
.
c_str
());
return
op
.
GetOutputDescByName
(
name
.
c_str
());
})
})
.
def
(
"get_output_desc"
,
.
def
(
"get_output_desc"
,
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetOutputDesc
)
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetOutputDesc
)
.
def
(
"update_output_desc"
,
.
def
(
"update_output_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateOutputDesc
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
// NOLINT
const
char
*
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateOutputDesc
))
.
def
(
"get_dynamic_input_desc"
,
.
def
(
"get_dynamic_input_desc"
,
static_cast
<
ge
::
TensorDesc
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicInputDesc
))
static_cast
<
ge
::
TensorDesc
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicInputDesc
))
.
def
(
"update_dynamic_input_desc"
,
.
def
(
"update_dynamic_input_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateDynamicInputDesc
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateDynamicInputDesc
))
.
def
(
"get_dynamic_output_desc"
,
.
def
(
"get_dynamic_output_desc"
,
static_cast
<
ge
::
TensorDesc
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicOutputDesc
))
static_cast
<
ge
::
TensorDesc
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicOutputDesc
))
.
def
(
"update_dynamic_output_desc"
,
.
def
(
"update_dynamic_output_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateDynamicOutputDesc
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateDynamicOutputDesc
))
#else
.
def
(
"update_input_desc"
,
&
Operator
::
UpdateInputDesc
)
.
def
(
"get_output_desc"
,
(
TensorDesc
(
Operator
::*
)(
const
std
::
string
&
)
const
)
&
Operator
::
GetOutputDesc
)
.
def
(
"get_output_desc"
,
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetOutputDesc
)
.
def
(
"update_output_desc"
,
&
Operator
::
UpdateOutputDesc
)
.
def
(
"get_dynamic_input_desc"
,
&
Operator
::
GetDynamicInputDesc
)
.
def
(
"update_dynamic_input_desc"
,
&
Operator
::
UpdateDynamicInputDesc
)
.
def
(
"get_dynamic_output_desc"
,
&
Operator
::
GetDynamicOutputDesc
)
.
def
(
"update_dynamic_output_desc"
,
&
Operator
::
UpdateDynamicOutputDesc
)
#endif
.
def
(
"infer_shape_and_type"
,
&
Operator
::
InferShapeAndType
)
.
def
(
"infer_shape_and_type"
,
&
Operator
::
InferShapeAndType
)
.
def
(
"set_inference_context"
,
&
Operator
::
SetInferenceContext
)
.
def
(
"set_inference_context"
,
&
Operator
::
SetInferenceContext
)
.
def
(
"get_inference_context"
,
&
Operator
::
GetInferenceContext
)
.
def
(
"get_inference_context"
,
&
Operator
::
GetInferenceContext
)
.
def
(
"verify_all_attr"
,
&
Operator
::
VerifyAllAttr
)
.
def
(
"verify_all_attr"
,
&
Operator
::
VerifyAllAttr
)
.
def
(
"get_inputs_size"
,
&
Operator
::
GetInputsSize
)
.
def
(
"get_inputs_size"
,
&
Operator
::
GetInputsSize
)
.
def
(
"get_outputs_size"
,
&
Operator
::
GetOutputsSize
)
.
def
(
"get_outputs_size"
,
&
Operator
::
GetOutputsSize
)
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"get_all_attr_names_and_types"
,
.
def
(
"get_all_attr_names_and_types"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>&
)
const
>
(
&
Operator
::
GetAllAttrNamesAndTypes
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
// NOLINT
std
::
map
<
AscendString
,
AscendString
>
&
)
const
>
(
&
Operator
::
GetAllAttrNamesAndTypes
))
#else
.
def
(
"get_all_attr_names_and_types"
,
&
Operator
::
GetAllAttrNamesAndTypes
)
#endif
.
def
(
"set_attr_int64"
,
.
def
(
"set_attr_int64"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
int64_t
value
)
->
Operator
&
{
int64_t
value
)
->
Operator
&
{
int64_t
tar
=
(
int64_t
)
value
;
int64_t
tar
=
(
int64_t
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_int32"
,
.
def
(
"set_attr_int32"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
int32_t
value
)
->
Operator
&
{
int32_t
value
)
->
Operator
&
{
int32_t
tar
=
(
int32_t
)
value
;
int32_t
tar
=
(
int32_t
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_uint32"
,
.
def
(
"set_attr_uint32"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
uint32_t
value
)
->
Operator
&
{
uint32_t
value
)
->
Operator
&
{
uint32_t
tar
=
(
uint32_t
)
value
;
uint32_t
tar
=
(
uint32_t
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_int64"
,
.
def
(
"set_attr_vec_int64"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
int64_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
int64_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
int64_t
>
tar
;
std
::
vector
<
int64_t
>
tar
;
...
@@ -398,7 +488,7 @@ void BindAscendGraph(py::module *m) {
...
@@ -398,7 +488,7 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_int32"
,
.
def
(
"set_attr_vec_int32"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
int32_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
int32_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
int32_t
>
tar
;
std
::
vector
<
int32_t
>
tar
;
...
@@ -410,7 +500,7 @@ void BindAscendGraph(py::module *m) {
...
@@ -410,7 +500,7 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_uint32"
,
.
def
(
"set_attr_vec_uint32"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
uint32_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
uint32_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
uint32_t
>
tar
;
std
::
vector
<
uint32_t
>
tar
;
...
@@ -422,21 +512,20 @@ void BindAscendGraph(py::module *m) {
...
@@ -422,21 +512,20 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_list_int64"
,
.
def
(
"set_attr_list_int64"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
std
::
initializer_list
<
int64_t
>
&
attrValue
)
->
Operator
&
{
std
::
initializer_list
<
int64_t
>
&
attrValue
)
->
Operator
&
{
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
})
})
.
def
(
"set_attr_attrvalue"
,
.
def
(
"set_attr_attrvalue"
,
[](
Operator
&
op
,
const
char
*
name
,
AttrValue
&
attrValue
)
[](
Operator
&
op
,
const
char
*
name
,
AttrValue
&
attrValue
)
->
Operator
&
{
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
})
->
Operator
&
{
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
})
.
def
(
.
def
(
"set_attr_float"
,
"set_attr_float"
,
[](
Operator
&
op
,
const
char
*
name
,
float
value
)
->
Operator
&
{
[](
Operator
&
op
,
const
char
*
name
,
float
value
)
->
Operator
&
{
float
tar
=
static_cast
<
float
>
(
value
);
float
tar
=
static_cast
<
float
>
(
value
);
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_float"
,
.
def
(
"set_attr_vec_float"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
float
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
float
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
float
>
tar
;
std
::
vector
<
float
>
tar
;
...
@@ -447,22 +536,32 @@ void BindAscendGraph(py::module *m) {
...
@@ -447,22 +536,32 @@ void BindAscendGraph(py::module *m) {
}
}
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_string"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
#ifdef PADDLE_WITH_ASCEND_STRING
const
char
*
))
&
.
def
(
"set_attr_string"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
char
*
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_string"
,
.
def
(
"set_attr_vec_string"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
(
Operator
&
const
std
::
vector
<
ge
::
AscendString
>
&
))
&
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
AscendString
>
&
))
&
Operator
::
SetAttr
)
#else
.
def
(
"set_attr_string"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
std
::
string
&
))
&
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_string"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
std
::
vector
<
std
::
string
>
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
#endif
.
def
(
"set_attr_bool"
,
.
def
(
"set_attr_bool"
,
[](
Operator
&
op
,
const
char
*
name
,
bool
value
)
->
Operator
&
{
[](
Operator
&
op
,
const
char
*
name
,
bool
value
)
->
Operator
&
{
if
(
value
)
if
(
value
)
return
op
.
SetAttr
(
name
,
true
);
return
op
.
SetAttr
(
name
,
true
);
else
else
return
op
.
SetAttr
(
name
,
false
);
return
op
.
SetAttr
(
name
,
false
);
})
})
.
def
(
"set_attr_vec_bool"
,
.
def
(
"set_attr_vec_bool"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
bool
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
bool
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
bool
>
tar
;
std
::
vector
<
bool
>
tar
;
...
@@ -474,15 +573,25 @@ void BindAscendGraph(py::module *m) {
...
@@ -474,15 +573,25 @@ void BindAscendGraph(py::module *m) {
}
}
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"set_attr_tensor"
,
.
def
(
"set_attr_tensor"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Tensor
&
))
&
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Tensor
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_tensor"
,
.
def
(
"set_attr_vec_tensor"
,
(
Operator
&
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
Tensor
>
&
))
&
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
Tensor
>
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
#else
.
def
(
"set_attr_tensor"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Tensor
&
))
&
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_tensor"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
std
::
vector
<
Tensor
>
&
))
&
Operator
::
SetAttr
)
#endif
.
def
(
"set_attr_vec_uint8"
,
.
def
(
"set_attr_vec_uint8"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
uint8_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
uint8_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
uint8_t
>
tar
;
std
::
vector
<
uint8_t
>
tar
;
...
@@ -493,13 +602,21 @@ void BindAscendGraph(py::module *m) {
...
@@ -493,13 +602,21 @@ void BindAscendGraph(py::module *m) {
}
}
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"set_attr_vec_vec_int64"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
))
&
Operator
::
SetAttr
)
#else
.
def
(
"set_attr_vec_vec_int64"
,
.
def
(
"set_attr_vec_vec_int64"
,
(
Operator
&
(
Operator
&
(
Operator
::*
)(
const
char
*
,
(
Operator
::*
)(
const
std
::
string
&
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
))
&
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
#endif
.
def
(
"set_attr_vec_dtype"
,
.
def
(
"set_attr_vec_dtype"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
DataType
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
DataType
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
ge
::
DataType
>
tar
;
std
::
vector
<
ge
::
DataType
>
tar
;
...
@@ -511,15 +628,13 @@ void BindAscendGraph(py::module *m) {
...
@@ -511,15 +628,13 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_dtype"
,
.
def
(
"set_attr_dtype"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
DataType
&
value
)
->
Operator
&
{
const
DataType
&
value
)
->
Operator
&
{
ge
::
DataType
tar
=
(
ge
::
DataType
)
value
;
ge
::
DataType
tar
=
(
ge
::
DataType
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"get_attr"
,
.
def
(
"get_attr"
,
[](
Operator
&
op
,
const
char
*
name
,
[](
Operator
&
op
,
const
char
*
name
,
AttrType
type
)
->
py
::
tuple
{
AttrType
type
)
->
py
::
tuple
{
graphStatus
res
=
-
1
;
graphStatus
res
=
-
1
;
switch
(
type
)
{
switch
(
type
)
{
case
AT_INT64
:
{
case
AT_INT64
:
{
...
@@ -568,12 +683,12 @@ void BindAscendGraph(py::module *m) {
...
@@ -568,12 +683,12 @@ void BindAscendGraph(py::module *m) {
return
py
::
make_tuple
(
o_av
,
res
);
return
py
::
make_tuple
(
o_av
,
res
);
}
break
;
}
break
;
case
AT_STRING
:
{
case
AT_STRING
:
{
ge
::
AscendString
s_av
;
AscendString
s_av
;
res
=
op
.
GetAttr
(
name
,
s_av
);
res
=
op
.
GetAttr
(
name
,
s_av
);
return
py
::
make_tuple
(
s_av
,
res
);
return
py
::
make_tuple
(
s_av
,
res
);
}
break
;
}
break
;
case
AT_LIST_STRING
:
{
case
AT_LIST_STRING
:
{
std
::
vector
<
ge
::
AscendString
>
v_s_av
;
std
::
vector
<
AscendString
>
v_s_av
;
res
=
op
.
GetAttr
(
name
,
v_s_av
);
res
=
op
.
GetAttr
(
name
,
v_s_av
);
return
py
::
make_tuple
(
v_s_av
,
res
);
return
py
::
make_tuple
(
v_s_av
,
res
);
}
break
;
}
break
;
...
@@ -624,11 +739,31 @@ void BindAscendGraph(py::module *m) {
...
@@ -624,11 +739,31 @@ void BindAscendGraph(py::module *m) {
})
})
.
def
(
"break_connect"
,
&
Operator
::
BreakConnect
)
.
def
(
"break_connect"
,
&
Operator
::
BreakConnect
)
.
def
(
"get_subgraph_names_count"
,
&
Operator
::
GetSubgraphNamesCount
)
.
def
(
"get_subgraph_names_count"
,
&
Operator
::
GetSubgraphNamesCount
)
.
def
(
"get_subgraph_names"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
std
::
vector
<
ge
::
AscendString
>
&
)
const
>
(
&
Operator
::
GetSubgraphNames
))
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"get_subgraph_builder"
,
static_cast
<
ge
::
SubgraphBuilder
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetSubgraphBuilder
))
.
def
(
"get_subgraph_names"
,
.
def
(
"get_subgraph"
,
static_cast
<
ge
::
Graph
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetSubgraph
))
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
// NOLINT
.
def
(
"get_dynamic_subgraph_builder"
,
static_cast
<
ge
::
SubgraphBuilder
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicSubgraphBuilder
))
std
::
vector
<
AscendString
>
&
)
const
>
(
&
Operator
::
GetSubgraphNames
))
.
def
(
"get_dynamic_subgraph"
,
static_cast
<
ge
::
Graph
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicSubgraph
));
.
def
(
"get_subgraph_builder"
,
static_cast
<
ge
::
SubgraphBuilder
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetSubgraphBuilder
))
.
def
(
"get_subgraph"
,
static_cast
<
ge
::
Graph
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetSubgraph
))
.
def
(
"get_dynamic_subgraph_builder"
,
static_cast
<
ge
::
SubgraphBuilder
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicSubgraphBuilder
))
.
def
(
"get_dynamic_subgraph"
,
static_cast
<
ge
::
Graph
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicSubgraph
));
#else
.
def
(
"get_subgraph_names_count"
,
&
Operator
::
GetSubgraphNamesCount
)
.
def
(
"get_subgraph_names"
,
&
Operator
::
GetSubgraphNames
)
.
def
(
"get_subgraph_builder"
,
&
Operator
::
GetSubgraphBuilder
)
.
def
(
"get_subgraph"
,
&
Operator
::
GetSubgraph
)
.
def
(
"get_dynamic_subgraph_builder"
,
&
Operator
::
GetDynamicSubgraphBuilder
)
.
def
(
"get_dynamic_subgraph"
,
&
Operator
::
GetDynamicSubgraph
);
#endif
py
::
class_
<
Tensor
>
(
*
m
,
"GETensor"
)
py
::
class_
<
Tensor
>
(
*
m
,
"GETensor"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
...
@@ -643,10 +778,15 @@ void BindAscendGraph(py::module *m) {
...
@@ -643,10 +778,15 @@ void BindAscendGraph(py::module *m) {
Tensor
::
SetData
)
Tensor
::
SetData
)
.
def
(
"set_data"
,
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
uint8_t
*
,
size_t
))
&
Tensor
::
SetData
)
(
graphStatus
(
Tensor
::*
)(
const
uint8_t
*
,
size_t
))
&
Tensor
::
SetData
)
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"set_data"
,
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
char
*
))
&
Tensor
::
SetData
)
(
graphStatus
(
Tensor
::*
)(
const
char
*
))
&
Tensor
::
SetData
)
#else
.
def
(
"set_data"
,
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
std
::
vector
<
ge
::
AscendString
>
&
))
&
(
graphStatus
(
Tensor
::*
)(
const
std
::
string
&
))
&
Tensor
::
SetData
)
#endif
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
std
::
vector
<
AscendString
>
&
))
&
Tensor
::
SetData
)
Tensor
::
SetData
)
.
def
(
"get_data"
,
.
def
(
"get_data"
,
...
@@ -668,8 +808,8 @@ void BindAscendGraph(py::module *m) {
...
@@ -668,8 +808,8 @@ void BindAscendGraph(py::module *m) {
.
def
(
py
::
init
<
Shape
,
Format
,
DataType
>
(),
py
::
arg
(
"shape"
),
.
def
(
py
::
init
<
Shape
,
Format
,
DataType
>
(),
py
::
arg
(
"shape"
),
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
.
def
(
py
::
init
<
const
TensorDesc
&>
())
.
def
(
py
::
init
<
const
TensorDesc
&>
())
.
def
(
"update"
,
.
def
(
"update"
,
(
void
(
TensorDesc
::*
)(
const
Shape
&
,
Format
,
DataType
))
&
(
void
(
TensorDesc
::*
)(
const
Shape
&
,
Format
,
DataType
))
&
TensorDesc
::
Update
,
TensorDesc
::
Update
,
py
::
arg
(
"shape"
),
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"shape"
),
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
.
def
(
"set_shape"
,
&
TensorDesc
::
SetShape
)
.
def
(
"set_shape"
,
&
TensorDesc
::
SetShape
)
...
@@ -690,8 +830,16 @@ void BindAscendGraph(py::module *m) {
...
@@ -690,8 +830,16 @@ void BindAscendGraph(py::module *m) {
.
def
(
"get_origin_format"
,
&
TensorDesc
::
GetOriginFormat
)
.
def
(
"get_origin_format"
,
&
TensorDesc
::
GetOriginFormat
)
.
def
(
"set_data_type"
,
&
TensorDesc
::
SetDataType
)
.
def
(
"set_data_type"
,
&
TensorDesc
::
SetDataType
)
.
def
(
"get_data_type"
,
&
TensorDesc
::
GetDataType
)
.
def
(
"get_data_type"
,
&
TensorDesc
::
GetDataType
)
.
def
(
"set_name"
,
static_cast
<
void
(
ge
::
TensorDesc
::*
)(
const
char
*
)
>
(
&
TensorDesc
::
SetName
))
#ifdef PADDLE_WITH_ASCEND_STRING
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
TensorDesc
::*
)(
ge
::
AscendString
&
)
>
(
&
TensorDesc
::
GetName
))
.
def
(
"set_name"
,
static_cast
<
void
(
ge
::
TensorDesc
::*
)(
const
char
*
)
>
(
&
TensorDesc
::
SetName
))
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
TensorDesc
::*
)(
AscendString
&
)
>
(
&
TensorDesc
::
GetName
))
#else
.
def
(
"set_name"
,
&
TensorDesc
::
SetName
)
.
def
(
"get_name"
,
&
TensorDesc
::
GetName
)
#endif
.
def
(
"set_size"
,
&
TensorDesc
::
SetSize
)
.
def
(
"set_size"
,
&
TensorDesc
::
SetSize
)
.
def
(
"get_size"
,
&
TensorDesc
::
GetSize
)
.
def
(
"get_size"
,
&
TensorDesc
::
GetSize
)
.
def
(
"set_real_dim_cnt"
,
&
TensorDesc
::
SetRealDimCnt
)
.
def
(
"set_real_dim_cnt"
,
&
TensorDesc
::
SetRealDimCnt
)
...
@@ -709,19 +857,27 @@ void BindAscendGraph(py::module *m) {
...
@@ -709,19 +857,27 @@ void BindAscendGraph(py::module *m) {
py
::
class_
<
AttrValue
>
(
*
m
,
"GEAttrValue"
).
def
(
py
::
init
<>
());
py
::
class_
<
AttrValue
>
(
*
m
,
"GEAttrValue"
).
def
(
py
::
init
<>
());
py
::
class_
<
OperatorFactory
>
(
*
m
,
"GEOperatorFactory"
)
py
::
class_
<
OperatorFactory
>
(
*
m
,
"GEOperatorFactory"
)
#ifdef PADDLE_WITH_ASCEND_STRING
.
def_static
(
"create_operator"
,
.
def_static
(
"create_operator"
,
static_cast
<
ge
::
Operator
(
*
)(
const
char
*
,
const
char
*
)
>
(
&
ge
::
OperatorFactory
::
CreateOperator
))
static_cast
<
ge
::
Operator
(
*
)(
const
char
*
,
const
char
*
)
>
(
&
ge
::
OperatorFactory
::
CreateOperator
))
#else
.
def
(
"create_operator"
,
&
OperatorFactory
::
CreateOperator
)
#endif
.
def
(
"get_ops_type_list"
,
.
def
(
"get_ops_type_list"
,
[]()
->
py
::
tuple
{
[]()
->
py
::
tuple
{
std
::
vector
<
ge
::
AscendString
>
all_ops
;
std
::
vector
<
AscendString
>
all_ops
;
graphStatus
status
=
OperatorFactory
::
GetOpsTypeList
(
all_ops
);
graphStatus
status
=
OperatorFactory
::
GetOpsTypeList
(
all_ops
);
return
py
::
make_tuple
(
all_ops
,
status
);
return
py
::
make_tuple
(
all_ops
,
status
);
})
})
.
def_static
(
"is_exist_op"
,
#ifdef PADDLE_WITH_ASCEND_STRING
static_cast
<
bool
(
*
)(
const
char
*
)
>
(
&
OperatorFactory
::
IsExistOp
));
.
def_static
(
"is_exist_op"
,
static_cast
<
bool
(
*
)(
const
char
*
)
>
(
&
OperatorFactory
::
IsExistOp
));
#else
.
def
(
"is_exist_op"
,
&
OperatorFactory
::
IsExistOp
);
#endif
}
}
}
//
end
namespace pybind
}
// namespace pybind
}
//
end
namespace paddle
}
// namespace paddle
#endif
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录