Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1882f2ce
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1882f2ce
编写于
1月 15, 2021
作者:
G
gongweibao
提交者:
GitHub
1月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix compilcation on CANN20.1 and older (#30494)
Fix compilcation on CANN20.1 and older
上级
6dd52c5b
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
200 addition
and
162 deletion
+200
-162
CMakeLists.txt
CMakeLists.txt
+0
-1
cmake/external/ascend.cmake
cmake/external/ascend.cmake
+31
-43
cmake/external/cryptopp.cmake
cmake/external/cryptopp.cmake
+2
-2
cmake/external/dlpack.cmake
cmake/external/dlpack.cmake
+1
-1
cmake/external/gflags.cmake
cmake/external/gflags.cmake
+3
-3
cmake/external/glog.cmake
cmake/external/glog.cmake
+3
-3
cmake/external/grpc.cmake
cmake/external/grpc.cmake
+1
-1
cmake/external/openblas.cmake
cmake/external/openblas.cmake
+1
-1
cmake/external/protobuf.cmake
cmake/external/protobuf.cmake
+4
-4
cmake/external/pybind11.cmake
cmake/external/pybind11.cmake
+2
-2
cmake/external/warpctc.cmake
cmake/external/warpctc.cmake
+2
-2
cmake/external/xbyak.cmake
cmake/external/xbyak.cmake
+1
-1
cmake/external/xxhash.cmake
cmake/external/xxhash.cmake
+1
-1
cmake/external/zlib.cmake
cmake/external/zlib.cmake
+2
-2
paddle/fluid/framework/fleet/CMakeLists.txt
paddle/fluid/framework/fleet/CMakeLists.txt
+1
-1
paddle/fluid/framework/fleet/ascend_wrapper.h
paddle/fluid/framework/fleet/ascend_wrapper.h
+22
-9
paddle/fluid/pybind/ascend_wrapper_py.cc
paddle/fluid/pybind/ascend_wrapper_py.cc
+110
-85
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+13
-0
未找到文件。
CMakeLists.txt
浏览文件 @
1882f2ce
...
@@ -326,7 +326,6 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
...
@@ -326,7 +326,6 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
set
(
CMAKE_CXX_FLAGS_RELWITHDEBINFO
"-O3 -g -DNDEBUG"
)
set
(
CMAKE_CXX_FLAGS_RELWITHDEBINFO
"-O3 -g -DNDEBUG"
)
set
(
CMAKE_C_FLAGS_RELWITHDEBINFO
"-O3 -g -DNDEBUG"
)
set
(
CMAKE_C_FLAGS_RELWITHDEBINFO
"-O3 -g -DNDEBUG"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0"
)
if
(
ON_INFER
)
if
(
ON_INFER
)
# you can trun off the paddle fluid and inference lib by set ON_INFER=OFF
# you can trun off the paddle fluid and inference lib by set ON_INFER=OFF
...
...
cmake/external/ascend.cmake
浏览文件 @
1882f2ce
...
@@ -12,50 +12,38 @@
...
@@ -12,50 +12,38 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
INCLUDE
(
ExternalProject
)
SET
(
ASCEND_PROJECT
"extern_ascend"
)
#NOTE: Logic is from
IF
((
NOT DEFINED ASCEND_VER
)
OR
(
NOT DEFINED ASCEND_URL
))
# https://github.com/mindspore-ai/graphengine/blob/master/CMakeLists.txt
MESSAGE
(
STATUS
"use pre defined download url"
)
if
(
DEFINED ENV{ASCEND_CUSTOM_PATH}
)
SET
(
ASCEND_VER
"0.1.1"
CACHE STRING
""
FORCE
)
set
(
ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}
)
SET
(
ASCEND_NAME
"ascend"
CACHE STRING
""
FORCE
)
else
()
SET
(
ASCEND_URL
"http://paddle-ascend.bj.bcebos.com/ascend.tar.gz"
CACHE STRING
""
FORCE
)
set
(
ASCEND_DIR /usr/local/Ascend
)
ENDIF
()
endif
()
MESSAGE
(
STATUS
"ASCEND_NAME:
${
ASCEND_NAME
}
, ASCEND_URL:
${
ASCEND_URL
}
"
)
SET
(
ASCEND_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/ascend"
)
SET
(
ASCEND_DOWNLOAD_DIR
"
${
ASCEND_SOURCE_DIR
}
/src/
${
ASCEND_PROJECT
}
"
)
SET
(
ASCEND_DST_DIR
"ascend"
)
SET
(
ASCEND_INSTALL_ROOT
"
${
THIRD_PARTY_PATH
}
/install"
)
SET
(
ASCEND_INSTALL_DIR
${
ASCEND_INSTALL_ROOT
}
/
${
ASCEND_DST_DIR
}
)
SET
(
ASCEND_ROOT
${
ASCEND_INSTALL_DIR
}
)
SET
(
ASCEND_INC_DIR
${
ASCEND_ROOT
}
/include
)
SET
(
ASCEND_LIB_DIR
${
ASCEND_ROOT
}
/lib
)
SET
(
ASCEND_LIB
${
ASCEND_LIB_DIR
}
/libge_runner.so
)
SET
(
ASCEND_GRAPH_LIB
${
ASCEND_LIB_DIR
}
/libgraph.so
)
SET
(
CMAKE_INSTALL_RPATH
"
${
CMAKE_INSTALL_RPATH
}
"
"
${
ASCEND_ROOT
}
/lib"
)
INCLUDE_DIRECTORIES
(
${
ASCEND_INC_DIR
}
)
set
(
ASCEND_DRIVER_DIR
${
ASCEND_DIR
}
/driver/lib64
)
FILE
(
WRITE
${
ASCEND_DOWNLOAD_DIR
}
/CMakeLists.txt
set
(
ASCEND_DRIVER_COMMON_DIR
${
ASCEND_DIR
}
/driver/lib64/common
)
"PROJECT(ASCEND)
\n
"
set
(
ASCEND_DRIVER_SHARE_DIR
${
ASCEND_DIR
}
/driver/lib64/share
)
"cmake_minimum_required(VERSION 3.0)
\n
"
set
(
ASCEND_RUNTIME_DIR
${
ASCEND_DIR
}
/fwkacllib/lib64
)
"install(DIRECTORY
${
ASCEND_NAME
}
/include
${
ASCEND_NAME
}
/lib
\n
"
set
(
ASCEND_ATC_DIR
${
ASCEND_DIR
}
/atc/lib64
)
" DESTINATION
${
ASCEND_DST_DIR
}
)
\n
"
)
set
(
ASCEND_ACL_DIR
${
ASCEND_DIR
}
/acllib/lib64
)
ExternalProject_Add
(
set
(
STATIC_ACL_LIB
${
ASCEND_ACL_DIR
}
)
${
ASCEND_PROJECT
}
${
EXTERNAL_PROJECT_LOG_ARGS
}
PREFIX
${
ASCEND_SOURCE_DIR
}
DOWNLOAD_DIR
${
ASCEND_DOWNLOAD_DIR
}
DOWNLOAD_COMMAND wget --no-check-certificate
${
ASCEND_URL
}
-c -q -O
${
ASCEND_NAME
}
.tar.gz
&& tar zxvf
${
ASCEND_NAME
}
.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
ASCEND_INSTALL_ROOT
}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
ASCEND_INSTALL_ROOT
}
)
ADD_LIBRARY
(
ascend SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ascend PROPERTY IMPORTED_LOCATION
${
ASCEND_LIB
}
)
ADD_LIBRARY
(
ascend_graph SHARED IMPORTED GLOBAL
)
set
(
ASCEND_MS_RUNTIME_PATH
${
ASCEND_RUNTIME_DIR
}
${
ASCEND_ACL_DIR
}
${
ASCEND_ATC_DIR
}
)
SET_PROPERTY
(
TARGET ascend_graph PROPERTY IMPORTED_LOCATION
${
ASCEND_GRAPH_LIB
}
)
set
(
ASCEND_MS_DRIVER_PATH
${
ASCEND_DRIVER_DIR
}
${
ASCEND_DRIVER_COMMON_DIR
}
)
ADD_DEPENDENCIES
(
ascend ascend_graph
${
ASCEND_PROJECT
}
)
set
(
ATLAS_RUNTIME_DIR
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64
)
set
(
ATLAS_RUNTIME_INC_DIR
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/include
)
set
(
ATLAS_ACL_DIR
${
ASCEND_DIR
}
/ascend-toolkit/latest/acllib/lib64
)
set
(
ATLAS_ATC_DIR
${
ASCEND_DIR
}
/ascend-toolkit/latest/atc/lib64
)
set
(
ATLAS_MS_RUNTIME_PATH
${
ATLAS_RUNTIME_DIR
}
${
ATLAS_ACL_DIR
}
${
ATLAS_ATC_DIR
}
)
set
(
atlas_graph
${
ATLAS_RUNTIME_DIR
}
/libgraph.so
)
set
(
atlas_ge_runner
${
ATLAS_RUNTIME_DIR
}
/libge_runner.so
)
INCLUDE_DIRECTORIES
(
${
ATLAS_RUNTIME_INC_DIR
}
)
ADD_LIBRARY
(
ascend_ge SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ascend_ge PROPERTY IMPORTED_LOCATION
${
atlas_ge_runner
}
)
ADD_LIBRARY
(
ascend_graph SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ascend_graph PROPERTY IMPORTED_LOCATION
${
atlas_graph
}
)
add_custom_target
(
extern_ascend DEPENDS ascend_ge ascend_graph
)
cmake/external/cryptopp.cmake
浏览文件 @
1882f2ce
...
@@ -17,7 +17,7 @@ INCLUDE(ExternalProject)
...
@@ -17,7 +17,7 @@ INCLUDE(ExternalProject)
SET
(
CRYPTOPP_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/cryptopp
)
SET
(
CRYPTOPP_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/cryptopp
)
SET
(
CRYPTOPP_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/cryptopp
)
SET
(
CRYPTOPP_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/cryptopp
)
SET
(
CRYPTOPP_INCLUDE_DIR
"
${
CRYPTOPP_INSTALL_DIR
}
/include"
CACHE PATH
"cryptopp include directory."
FORCE
)
SET
(
CRYPTOPP_INCLUDE_DIR
"
${
CRYPTOPP_INSTALL_DIR
}
/include"
CACHE PATH
"cryptopp include directory."
FORCE
)
SET
(
CRYPTOPP_REPOSITORY
https://gitee.com/tianjianhe
/cryptopp.git
)
SET
(
CRYPTOPP_REPOSITORY
${
GIT_URL
}
/weidai11
/cryptopp.git
)
SET
(
CRYPTOPP_TAG CRYPTOPP_8_2_0
)
SET
(
CRYPTOPP_TAG CRYPTOPP_8_2_0
)
IF
(
WIN32
)
IF
(
WIN32
)
...
@@ -33,7 +33,7 @@ set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS}
...
@@ -33,7 +33,7 @@ set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS}
-DCMAKE_INSTALL_LIBDIR=
${
CRYPTOPP_INSTALL_DIR
}
/lib
-DCMAKE_INSTALL_LIBDIR=
${
CRYPTOPP_INSTALL_DIR
}
/lib
-DCMAKE_INSTALL_PREFIX=
${
CRYPTOPP_INSTALL_DIR
}
-DCMAKE_INSTALL_PREFIX=
${
CRYPTOPP_INSTALL_DIR
}
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0"
-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
...
...
cmake/external/dlpack.cmake
浏览文件 @
1882f2ce
...
@@ -17,7 +17,7 @@ include(ExternalProject)
...
@@ -17,7 +17,7 @@ include(ExternalProject)
set
(
DLPACK_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/dlpack
)
set
(
DLPACK_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/dlpack
)
set
(
DLPACK_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/dlpack/src/extern_dlpack
)
set
(
DLPACK_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/dlpack/src/extern_dlpack
)
set
(
DLPACK_REPOSITORY
https://gitee.com/tianjianhe
/dlpack.git
)
set
(
DLPACK_REPOSITORY
${
GIT_URL
}
/dmlc
/dlpack.git
)
set
(
DLPACK_TAG v0.2
)
set
(
DLPACK_TAG v0.2
)
cache_third_party
(
extern_dlpack
cache_third_party
(
extern_dlpack
...
...
cmake/external/gflags.cmake
浏览文件 @
1882f2ce
...
@@ -18,8 +18,8 @@ SET(GFLAGS_PREFIX_DIR ${THIRD_PARTY_PATH}/gflags)
...
@@ -18,8 +18,8 @@ SET(GFLAGS_PREFIX_DIR ${THIRD_PARTY_PATH}/gflags)
SET
(
GFLAGS_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/gflags/src/extern_gflags
)
SET
(
GFLAGS_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/gflags/src/extern_gflags
)
SET
(
GFLAGS_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/gflags
)
SET
(
GFLAGS_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/gflags
)
SET
(
GFLAGS_INCLUDE_DIR
"
${
GFLAGS_INSTALL_DIR
}
/include"
CACHE PATH
"gflags include directory."
FORCE
)
SET
(
GFLAGS_INCLUDE_DIR
"
${
GFLAGS_INSTALL_DIR
}
/include"
CACHE PATH
"gflags include directory."
FORCE
)
set
(
GFLAGS_REPOSITORY
https://gitee.com/tianjianhe
/gflags.git
)
set
(
GFLAGS_REPOSITORY
${
GIT_URL
}
/gflags
/gflags.git
)
set
(
GFLAGS_TAG
77592648e3f3be87d6c7123eb81cbad75f9aef5a
)
set
(
GFLAGS_TAG
"v2.2.2"
)
IF
(
WIN32
)
IF
(
WIN32
)
set
(
GFLAGS_LIBRARIES
"
${
GFLAGS_INSTALL_DIR
}
/lib/gflags_static.lib"
CACHE FILEPATH
"GFLAGS_LIBRARIES"
FORCE
)
set
(
GFLAGS_LIBRARIES
"
${
GFLAGS_INSTALL_DIR
}
/lib/gflags_static.lib"
CACHE FILEPATH
"GFLAGS_LIBRARIES"
FORCE
)
ELSE
(
WIN32
)
ELSE
(
WIN32
)
...
@@ -48,7 +48,7 @@ ExternalProject_Add(
...
@@ -48,7 +48,7 @@ ExternalProject_Add(
INSTALL_COMMAND
${
INSTALL_COMMAND
}
INSTALL_COMMAND
${
INSTALL_COMMAND
}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0"
-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
...
...
cmake/external/glog.cmake
浏览文件 @
1882f2ce
...
@@ -18,8 +18,8 @@ SET(GLOG_PREFIX_DIR ${THIRD_PARTY_PATH}/glog)
...
@@ -18,8 +18,8 @@ SET(GLOG_PREFIX_DIR ${THIRD_PARTY_PATH}/glog)
SET
(
GLOG_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/glog/src/extern_glog
)
SET
(
GLOG_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/glog/src/extern_glog
)
SET
(
GLOG_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/glog
)
SET
(
GLOG_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/glog
)
SET
(
GLOG_INCLUDE_DIR
"
${
GLOG_INSTALL_DIR
}
/include"
CACHE PATH
"glog include directory."
FORCE
)
SET
(
GLOG_INCLUDE_DIR
"
${
GLOG_INSTALL_DIR
}
/include"
CACHE PATH
"glog include directory."
FORCE
)
SET
(
GLOG_REPOSITORY
https://gitee.com/tianjianh
e/glog.git
)
SET
(
GLOG_REPOSITORY
${
GIT_URL
}
/googl
e/glog.git
)
SET
(
GLOG_TAG v0.
3.5
)
SET
(
GLOG_TAG v0.
4.0
)
IF
(
WIN32
)
IF
(
WIN32
)
SET
(
GLOG_LIBRARIES
"
${
GLOG_INSTALL_DIR
}
/lib/glog.lib"
CACHE FILEPATH
"glog library."
FORCE
)
SET
(
GLOG_LIBRARIES
"
${
GLOG_INSTALL_DIR
}
/lib/glog.lib"
CACHE FILEPATH
"glog library."
FORCE
)
...
@@ -47,7 +47,7 @@ ExternalProject_Add(
...
@@ -47,7 +47,7 @@ ExternalProject_Add(
SOURCE_DIR
${
GLOG_SOURCE_DIR
}
SOURCE_DIR
${
GLOG_SOURCE_DIR
}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
"-DCMAKE_CXX_FLAGS=
${
GLOG_CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0"
-DCMAKE_CXX_FLAGS=
${
GLOG_CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
...
...
cmake/external/grpc.cmake
浏览文件 @
1882f2ce
...
@@ -28,7 +28,7 @@ IF(APPLE)
...
@@ -28,7 +28,7 @@ IF(APPLE)
SET
(
GRPC_INSTALL_CMD make prefix=
${
GRPC_INSTALL_DIR
}
install
)
SET
(
GRPC_INSTALL_CMD make prefix=
${
GRPC_INSTALL_DIR
}
install
)
ELSE
()
ELSE
()
SET
(
GRPC_CFLAGS
"-Wno-error -std=c11
${
CLFAGS
}
"
)
SET
(
GRPC_CFLAGS
"-Wno-error -std=c11
${
CLFAGS
}
"
)
SET
(
GRPC_CXXFLAGS
"-Wno-error -std=c++11
${
CXXFLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0
"
)
SET
(
GRPC_CXXFLAGS
"-Wno-error -std=c++11
${
CXXFLAGS
}
"
)
SET
(
BUILD_CMD make CFLAGS=
${
GRPC_CFLAGS
}
CXXFLAGS=
${
GRPC_CXXFLAGS
}
HAS_SYSTEM_PROTOBUF=false -s -j
${
NUM_OF_PROCESSOR
}
static grpc_cpp_plugin
)
SET
(
BUILD_CMD make CFLAGS=
${
GRPC_CFLAGS
}
CXXFLAGS=
${
GRPC_CXXFLAGS
}
HAS_SYSTEM_PROTOBUF=false -s -j
${
NUM_OF_PROCESSOR
}
static grpc_cpp_plugin
)
SET
(
GRPC_INSTALL_CMD make prefix=
${
GRPC_INSTALL_DIR
}
install CFLAGS=
${
GRPC_CFLAGS
}
CXXFLAGS=
${
GRPC_CXXFLAGS
}
)
SET
(
GRPC_INSTALL_CMD make prefix=
${
GRPC_INSTALL_DIR
}
install CFLAGS=
${
GRPC_CFLAGS
}
CXXFLAGS=
${
GRPC_CXXFLAGS
}
)
ENDIF
()
ENDIF
()
...
...
cmake/external/openblas.cmake
浏览文件 @
1882f2ce
...
@@ -17,7 +17,7 @@ INCLUDE(ExternalProject)
...
@@ -17,7 +17,7 @@ INCLUDE(ExternalProject)
SET
(
CBLAS_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/openblas
)
SET
(
CBLAS_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/openblas
)
SET
(
CBLAS_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/openblas/src/extern_openblas
)
SET
(
CBLAS_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/openblas/src/extern_openblas
)
SET
(
CBLAS_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/openblas
)
SET
(
CBLAS_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/openblas
)
SET
(
CBLAS_REPOSITORY
https://gitee.com/tianjianhe
/OpenBLAS.git
)
SET
(
CBLAS_REPOSITORY
${
GIT_URL
}
/xianyi
/OpenBLAS.git
)
SET
(
CBLAS_TAG v0.3.7
)
SET
(
CBLAS_TAG v0.3.7
)
if
(
WITH_MIPS
)
if
(
WITH_MIPS
)
SET
(
CBLAS_TAG v0.3.13
)
SET
(
CBLAS_TAG v0.3.13
)
...
...
cmake/external/protobuf.cmake
浏览文件 @
1882f2ce
...
@@ -183,7 +183,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
...
@@ -183,7 +183,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
"
"-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
"
"-DCMAKE_C_FLAGS_DEBUG=
${
CMAKE_C_FLAGS_DEBUG
}
"
"-DCMAKE_C_FLAGS_DEBUG=
${
CMAKE_C_FLAGS_DEBUG
}
"
"-DCMAKE_C_FLAGS_RELEASE=
${
CMAKE_C_FLAGS_RELEASE
}
"
"-DCMAKE_C_FLAGS_RELEASE=
${
CMAKE_C_FLAGS_RELEASE
}
"
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0
"
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
"
"-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
"
"-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
"
"-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
"
"-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
"
"-Dprotobuf_WITH_ZLIB=ON"
"-Dprotobuf_WITH_ZLIB=ON"
...
@@ -198,8 +198,8 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
...
@@ -198,8 +198,8 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-Dprotobuf_MSVC_STATIC_RUNTIME=
${
MSVC_STATIC_CRT
}
"
)
"-Dprotobuf_MSVC_STATIC_RUNTIME=
${
MSVC_STATIC_CRT
}
"
)
ENDIF
()
ENDIF
()
SET
(
PROTOBUF_REPOSITORY
https://gitee.com/tianjianhe
/protobuf.git
)
SET
(
PROTOBUF_REPOSITORY
${
GIT_URL
}
/protocolbuffers
/protobuf.git
)
SET
(
PROTOBUF_TAG
v3.8.0
)
SET
(
PROTOBUF_TAG
9f75c5aa851cd877fb0d93ccc31b8567a6706546
)
cache_third_party
(
${
TARGET_NAME
}
cache_third_party
(
${
TARGET_NAME
}
REPOSITORY
${
PROTOBUF_REPOSITORY
}
REPOSITORY
${
PROTOBUF_REPOSITORY
}
...
@@ -234,7 +234,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
...
@@ -234,7 +234,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
)
)
ENDFUNCTION
()
ENDFUNCTION
()
#
SET(PROTOBUF_VERSION 3.1.0)
SET
(
PROTOBUF_VERSION 3.1.0
)
IF
(
NOT PROTOBUF_FOUND
)
IF
(
NOT PROTOBUF_FOUND
)
build_protobuf
(
extern_protobuf FALSE
)
build_protobuf
(
extern_protobuf FALSE
)
...
...
cmake/external/pybind11.cmake
浏览文件 @
1882f2ce
...
@@ -16,8 +16,8 @@ include(ExternalProject)
...
@@ -16,8 +16,8 @@ include(ExternalProject)
set
(
PYBIND_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/pybind
)
set
(
PYBIND_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/pybind
)
set
(
PYBIND_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/pybind/src/extern_pybind
)
set
(
PYBIND_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/pybind/src/extern_pybind
)
SET
(
PYBIND_REPOSITORY
https://gitee.com/tianjianhe
/pybind11.git
)
SET
(
PYBIND_REPOSITORY
${
GIT_URL
}
/pybind
/pybind11.git
)
SET
(
PYBIND_TAG v2.
6.0
)
SET
(
PYBIND_TAG v2.
4.3
)
cache_third_party
(
extern_pybind
cache_third_party
(
extern_pybind
REPOSITORY
${
PYBIND_REPOSITORY
}
REPOSITORY
${
PYBIND_REPOSITORY
}
...
...
cmake/external/warpctc.cmake
浏览文件 @
1882f2ce
...
@@ -19,7 +19,7 @@ SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
...
@@ -19,7 +19,7 @@ SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET
(
WARPCTC_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/warpctc
)
SET
(
WARPCTC_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/warpctc
)
set
(
WARPCTC_REPOSITORY https://gitee.com/tianjianhe/warp-ctc.git
)
set
(
WARPCTC_REPOSITORY https://gitee.com/tianjianhe/warp-ctc.git
)
set
(
WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8
)
set
(
WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8
)
# set(WARPCTC_TAG bc29dcfff07ced1c7a19a4ecee48e5ad583cef8e
)
set
(
WARPCTC_REPOSITORY
${
GIT_URL
}
/baidu-research/warp-ctc.git
)
SET
(
WARPCTC_INCLUDE_DIR
"
${
WARPCTC_INSTALL_DIR
}
/include"
SET
(
WARPCTC_INCLUDE_DIR
"
${
WARPCTC_INSTALL_DIR
}
/include"
CACHE PATH
"Warp-ctc Directory"
FORCE
)
CACHE PATH
"Warp-ctc Directory"
FORCE
)
...
@@ -53,7 +53,7 @@ ExternalProject_Add(
...
@@ -53,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_C_FLAGS_DEBUG=
${
CMAKE_C_FLAGS_DEBUG
}
-DCMAKE_C_FLAGS_DEBUG=
${
CMAKE_C_FLAGS_DEBUG
}
-DCMAKE_C_FLAGS_RELEASE=
${
CMAKE_C_FLAGS_RELEASE
}
-DCMAKE_C_FLAGS_RELEASE=
${
CMAKE_C_FLAGS_RELEASE
}
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0
"
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
"
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_INSTALL_PREFIX=
${
WARPCTC_INSTALL_DIR
}
-DCMAKE_INSTALL_PREFIX=
${
WARPCTC_INSTALL_DIR
}
...
...
cmake/external/xbyak.cmake
浏览文件 @
1882f2ce
...
@@ -19,7 +19,7 @@ set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak)
...
@@ -19,7 +19,7 @@ set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak)
SET
(
XBYAK_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/xbyak/src/extern_xbyak
)
SET
(
XBYAK_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/xbyak/src/extern_xbyak
)
set
(
XBYAK_INSTALL_ROOT
${
THIRD_PARTY_PATH
}
/install/xbyak
)
set
(
XBYAK_INSTALL_ROOT
${
THIRD_PARTY_PATH
}
/install/xbyak
)
set
(
XBYAK_INC_DIR
${
XBYAK_INSTALL_ROOT
}
/include
)
set
(
XBYAK_INC_DIR
${
XBYAK_INSTALL_ROOT
}
/include
)
set
(
XBYAK_REPOSITORY
https://gitee.com/tianjianhe
/xbyak.git
)
set
(
XBYAK_REPOSITORY
${
GIT_URL
}
/herumi
/xbyak.git
)
set
(
XBYAK_TAG v5.661
)
# Jul 26th
set
(
XBYAK_TAG v5.661
)
# Jul 26th
include_directories
(
${
XBYAK_INC_DIR
}
)
include_directories
(
${
XBYAK_INC_DIR
}
)
...
...
cmake/external/xxhash.cmake
浏览文件 @
1882f2ce
...
@@ -18,7 +18,7 @@ set(XXHASH_PREFIX_DIR ${THIRD_PARTY_PATH}/xxhash)
...
@@ -18,7 +18,7 @@ set(XXHASH_PREFIX_DIR ${THIRD_PARTY_PATH}/xxhash)
set
(
XXHASH_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/xxhash/src/extern_xxhash
)
set
(
XXHASH_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/xxhash/src/extern_xxhash
)
set
(
XXHASH_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/xxhash
)
set
(
XXHASH_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/xxhash
)
set
(
XXHASH_INCLUDE_DIR
"
${
XXHASH_INSTALL_DIR
}
/include"
)
set
(
XXHASH_INCLUDE_DIR
"
${
XXHASH_INSTALL_DIR
}
/include"
)
set
(
XXHASH_REPOSITORY
https://gitee.com/tianjianhe
/xxHash.git
)
set
(
XXHASH_REPOSITORY
${
GIT_URL
}
/Cyan4973
/xxHash.git
)
set
(
XXHASH_TAG v0.6.5
)
set
(
XXHASH_TAG v0.6.5
)
cache_third_party
(
extern_xxhash
cache_third_party
(
extern_xxhash
...
...
cmake/external/zlib.cmake
浏览文件 @
1882f2ce
...
@@ -19,7 +19,7 @@ SET(ZLIB_SOURCE_DIR ${THIRD_PARTY_PATH}/zlib/src/extern_zlib)
...
@@ -19,7 +19,7 @@ SET(ZLIB_SOURCE_DIR ${THIRD_PARTY_PATH}/zlib/src/extern_zlib)
SET
(
ZLIB_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/zlib
)
SET
(
ZLIB_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/zlib
)
SET
(
ZLIB_ROOT
${
ZLIB_INSTALL_DIR
}
CACHE FILEPATH
"zlib root directory."
FORCE
)
SET
(
ZLIB_ROOT
${
ZLIB_INSTALL_DIR
}
CACHE FILEPATH
"zlib root directory."
FORCE
)
SET
(
ZLIB_INCLUDE_DIR
"
${
ZLIB_INSTALL_DIR
}
/include"
CACHE PATH
"zlib include directory."
FORCE
)
SET
(
ZLIB_INCLUDE_DIR
"
${
ZLIB_INSTALL_DIR
}
/include"
CACHE PATH
"zlib include directory."
FORCE
)
set
(
ZLIB_REPOSITORY
https://gitee.com/tianjianhe
/zlib.git
)
set
(
ZLIB_REPOSITORY
${
GIT_URL
}
/madler
/zlib.git
)
set
(
ZLIB_TAG v1.2.8
)
set
(
ZLIB_TAG v1.2.8
)
INCLUDE_DIRECTORIES
(
${
ZLIB_INCLUDE_DIR
}
)
# For zlib code to include its own headers.
INCLUDE_DIRECTORIES
(
${
ZLIB_INCLUDE_DIR
}
)
# For zlib code to include its own headers.
...
@@ -41,7 +41,7 @@ ExternalProject_Add(
...
@@ -41,7 +41,7 @@ ExternalProject_Add(
CMAKE_ARGS -DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
CMAKE_ARGS -DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
"-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0"
-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-DCMAKE_INSTALL_PREFIX=
${
ZLIB_INSTALL_DIR
}
-DCMAKE_INSTALL_PREFIX=
${
ZLIB_INSTALL_DIR
}
-DBUILD_SHARED_LIBS=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
...
...
paddle/fluid/framework/fleet/CMakeLists.txt
浏览文件 @
1882f2ce
...
@@ -33,5 +33,5 @@ cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_conte
...
@@ -33,5 +33,5 @@ cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_conte
cc_test
(
test_fleet_cc SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell
)
cc_test
(
test_fleet_cc SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell
)
if
(
WITH_ASCEND
)
if
(
WITH_ASCEND
)
cc_library
(
ascend_wrapper SRCS ascend_wrapper.cc DEPS framework_proto lod_tensor ascend ascend_graph
)
cc_library
(
ascend_wrapper SRCS ascend_wrapper.cc DEPS framework_proto lod_tensor ascend
_ge
ascend_graph
)
endif
(
WITH_ASCEND
)
endif
(
WITH_ASCEND
)
paddle/fluid/framework/fleet/ascend_wrapper.h
浏览文件 @
1882f2ce
...
@@ -37,7 +37,6 @@ limitations under the License. */
...
@@ -37,7 +37,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
// typedef std::vector<std::string> AscendGraphDesc;
typedef
ge
::
Graph
AscendGraphDesc
;
typedef
ge
::
Graph
AscendGraphDesc
;
class
AscendInstance
{
class
AscendInstance
{
...
@@ -45,17 +44,31 @@ class AscendInstance {
...
@@ -45,17 +44,31 @@ class AscendInstance {
virtual
~
AscendInstance
()
{}
virtual
~
AscendInstance
()
{}
AscendInstance
()
{}
AscendInstance
()
{}
std
::
map
<
std
::
string
,
std
::
string
>
GetDefaultInitSessionOptions
()
{
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
GetDefaultInitOptions
()
{
std
::
map
<
std
::
string
,
std
::
string
>
init_options
;
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
init_options
;
init_options
[
"a"
]
=
"b"
;
init_options
[
"ge.exec.deviceId"
]
=
"0"
;
init_options
[
"ge.trainFlag"
]
=
"1"
;
init_options
[
"ge.graphRunMode"
]
=
"1"
;
return
init_options
;
return
init_options
;
}
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
GetDefaultInitSessionOptions
()
{
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
init_options
;
init_options
[
"a"
]
=
"b"
;
init_options
[
"ge.trainFlag"
]
=
"1"
;
return
init_options
;
}
ge
::
Status
InitGEForUT
(){
return
ge
::
GEInitialize
(
GetDefaultInitOptions
());
}
}
// add other parameters here to init
void
InitGlobalResouces
()
{
void
InitGlobalResouces
()
{
session_
.
reset
(
new
ge
::
Session
(
GetDefaultInitSessionOptions
()));
LOG
(
INFO
)
<<
"Begin InitGlobalResouces"
;
VLOG
(
1
)
<<
"InitGlobalResouces Done"
;
session_
.
reset
(
new
ge
::
Session
(
GetDefaultInitSessionOptions
()));
if
(
session_
==
nullptr
){
LOG
(
FATAL
)
<<
"new session error:"
<<
session_
;
}
LOG
(
INFO
)
<<
"End InitGlobalResouces"
;
}
}
static
std
::
shared_ptr
<
AscendInstance
>
GetInstance
()
{
static
std
::
shared_ptr
<
AscendInstance
>
GetInstance
()
{
...
...
paddle/fluid/pybind/ascend_wrapper_py.cc
浏览文件 @
1882f2ce
...
@@ -33,6 +33,7 @@ limitations under the License. */
...
@@ -33,6 +33,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
#include "paddle/fluid/platform/enforce.h"
using
namespace
ge
;
// NOLINT
using
namespace
ge
;
// NOLINT
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
...
@@ -51,9 +52,22 @@ void BindAscendWrapper(py::module *m) {
...
@@ -51,9 +52,22 @@ void BindAscendWrapper(py::module *m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
// end AscendWrapper
}
// end AscendWrapper
Status
ge_initialize
(
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
// NOLINT
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
convert_map
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
options
){
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>
rets
;
for
(
auto
&
option
:
options
)
{
ge
::
AscendString
key
=
option
.
first
.
c_str
();
ge
::
AscendString
val
=
option
.
second
.
c_str
();
rets
[
key
]
=
val
;
}
return
rets
;
}
ge
::
Status
ge_initialize
(
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
// NOLINT
py
::
gil_scoped_release
release
;
py
::
gil_scoped_release
release
;
Status
res
=
GEInitialize
(
options
);
auto
init_options
=
convert_map
(
options
);
ge
::
Status
res
=
ge
::
GEInitialize
(
init_options
);
PADDLE_ENFORCE_EQ
(
res
,
ge
::
SUCCESS
,
platform
::
errors
::
Fatal
(
"ge init error:%d"
,
res
));
py
::
gil_scoped_acquire
acquire
;
py
::
gil_scoped_acquire
acquire
;
return
res
;
return
res
;
}
}
...
@@ -214,36 +228,34 @@ void BindAscendGraph(py::module *m) {
...
@@ -214,36 +228,34 @@ void BindAscendGraph(py::module *m) {
// 类封装
// 类封装
py
::
class_
<
Session
>
(
*
m
,
"GESession"
)
py
::
class_
<
Session
>
(
*
m
,
"GESession"
)
.
def
(
py
::
init
<
const
std
::
map
<
std
::
string
,
std
::
string
>
&>
())
.
def
(
py
::
init
([](
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
)
{
return
std
::
unique_ptr
<
ge
::
Session
>
(
new
ge
::
Session
(
convert_map
(
options
)));
}))
.
def
(
"add_graph"
,
.
def
(
"add_graph"
,
(
Status
(
Session
::*
)(
uint32_t
,
const
Graph
&
))
&
Session
::
AddGraph
)
(
ge
::
Status
(
Session
::*
)(
uint32_t
,
const
Graph
&
))
&
Session
::
AddGraph
)
.
def
(
"add_graph"
,
.
def
(
"add_graph"
,
(
Status
(
Session
::*
)(
uint32_t
,
const
Graph
&
,
[](
Session
&
ss
,
uint32_t
index
,
const
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
))
&
const
std
::
map
<
std
::
string
,
std
::
string
>
&
options
){
Session
::
AddGraph
)
return
ss
.
AddGraph
(
index
,
graph
,
convert_map
(
options
));
})
.
def
(
"remove_graph"
,
&
Session
::
RemoveGraph
)
.
def
(
"remove_graph"
,
&
Session
::
RemoveGraph
)
.
def
(
"run_graph"
,
.
def
(
"run_graph"
,
[](
Session
&
ss
,
uint32_t
graphId
,
[](
Session
&
ss
,
uint32_t
graphId
,
const
std
::
vector
<
Tensor
>
&
inputs
)
->
py
::
tuple
{
const
std
::
vector
<
Tensor
>
&
inputs
)
->
py
::
tuple
{
std
::
vector
<
Tensor
>
outputs
;
std
::
vector
<
Tensor
>
outputs
;
Status
res
=
ss
.
RunGraph
(
graphId
,
inputs
,
outputs
);
ge
::
Status
res
=
ss
.
RunGraph
(
graphId
,
inputs
,
outputs
);
return
py
::
make_tuple
(
outputs
,
res
);
return
py
::
make_tuple
(
outputs
,
res
);
},
},
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"build_graph"
,
&
Session
::
BuildGraph
)
.
def
(
"build_graph"
,
&
Session
::
BuildGraph
)
.
def
(
"run_graph_async"
,
&
Session
::
RunGraphAsync
)
.
def
(
"run_graph_async"
,
&
Session
::
RunGraphAsync
)
.
def
(
"register_call_back_func"
,
.
def
(
"register_call_back_func"
,
(
Status
(
Session
::*
)(
// NOLINT
static_cast
<
ge
::
Status
(
ge
::
Session
::*
)(
const
char
*
,
const
ge
::
session
::
pCallBackFunc
&
)
>
(
&
ge
::
Session
::
RegisterCallBackFunc
))
const
std
::
string
&
,
std
::
function
<
uint32_t
(
uint32_t
graph_id
,
const
std
::
map
<
std
::
string
,
ge
::
Tensor
>
&
params_list
)
>
))
&
Session
::
RegisterCallBackFunc
)
.
def
(
"is_graph_need_rebuild"
,
&
Session
::
IsGraphNeedRebuild
);
.
def
(
"is_graph_need_rebuild"
,
&
Session
::
IsGraphNeedRebuild
);
py
::
class_
<
Graph
>
(
*
m
,
"GEGraph"
)
py
::
class_
<
Graph
>
(
*
m
,
"GEGraph"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
const
std
::
string
&
>
())
.
def
(
py
::
init
<
const
char
*
>
())
.
def
(
"set_inputs"
,
&
Graph
::
SetInputs
)
.
def
(
"set_inputs"
,
&
Graph
::
SetInputs
)
.
def
(
"set_outputs"
,
(
Graph
&
(
Graph
::*
)(
const
std
::
vector
<
Operator
>
&
))
&
.
def
(
"set_outputs"
,
(
Graph
&
(
Graph
::*
)(
const
std
::
vector
<
Operator
>
&
))
&
Graph
::
SetOutputs
)
Graph
::
SetOutputs
)
...
@@ -253,110 +265,121 @@ void BindAscendGraph(py::module *m) {
...
@@ -253,110 +265,121 @@ void BindAscendGraph(py::module *m) {
Graph
::
SetOutputs
)
Graph
::
SetOutputs
)
.
def
(
"set_outputs"
,
.
def
(
"set_outputs"
,
(
Graph
&
(
Graph
&
(
Graph
::*
)(
const
std
::
vector
<
std
::
pair
<
ge
::
Operator
,
std
::
s
tring
>>
(
Graph
::*
)(
const
std
::
vector
<
std
::
pair
<
ge
::
Operator
,
ge
::
AscendS
tring
>>
&
))
&
&
))
&
Graph
::
SetOutputs
)
Graph
::
SetOutputs
)
.
def
(
"set_targets"
,
&
Graph
::
SetTargets
)
.
def
(
"set_targets"
,
&
Graph
::
SetTargets
)
.
def
(
"is_valid"
,
&
Graph
::
IsValid
)
.
def
(
"is_valid"
,
&
Graph
::
IsValid
)
.
def
(
"add_op"
,
&
Graph
::
AddOp
)
.
def
(
"add_op"
,
&
Graph
::
AddOp
)
.
def
(
"find_op_by_name"
,
.
def
(
"find_op_by_name"
,
[](
Graph
&
graph
,
const
std
::
string
&
name
)
->
py
::
tuple
{
[](
Graph
&
graph
,
const
char
*
name
)
->
py
::
tuple
{
ge
::
Operator
op
;
ge
::
Operator
op
;
graphStatus
status
=
graph
.
FindOpByName
(
name
,
op
);
graphStatus
status
=
graph
.
FindOpByName
(
name
,
op
);
return
py
::
make_tuple
(
op
,
status
);
return
py
::
make_tuple
(
op
,
status
);
})
})
.
def
(
"find_op_by_type"
,
.
def
(
"find_op_by_type"
,
[](
Graph
&
graph
,
const
std
::
string
&
type
)
->
py
::
tuple
{
[](
Graph
&
graph
,
const
char
*
type
)
->
py
::
tuple
{
std
::
vector
<
ge
::
Operator
>
ops
;
std
::
vector
<
ge
::
Operator
>
ops
;
graphStatus
status
=
graph
.
FindOpByType
(
type
,
ops
);
graphStatus
status
=
graph
.
FindOpByType
(
type
,
ops
);
return
py
::
make_tuple
(
ops
,
status
);
return
py
::
make_tuple
(
ops
,
status
);
})
})
.
def
(
"get_all_op_name"
,
.
def
(
"get_all_op_name"
,
[](
Graph
&
graph
)
->
py
::
tuple
{
[](
Graph
&
graph
)
->
py
::
tuple
{
std
::
vector
<
std
::
s
tring
>
op_name
;
std
::
vector
<
ge
::
AscendS
tring
>
op_name
;
graphStatus
status
=
graph
.
GetAllOpName
(
op_name
);
graphStatus
status
=
graph
.
GetAllOpName
(
op_name
);
return
py
::
make_tuple
(
op_name
,
status
);
return
py
::
make_tuple
(
op_name
,
status
);
})
})
.
def
(
"save_to_file"
,
&
Graph
::
SaveToFile
)
.
def
(
"save_to_file"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
const
char
*
)
const
>
(
&
ge
::
Graph
::
SaveToFile
)
)
.
def
(
"load_from_file"
,
&
Graph
::
LoadFromFile
)
.
def
(
"load_from_file"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
const
char
*
)
>
(
&
Graph
::
LoadFromFile
)
)
.
def
(
"get_name"
,
&
Graph
::
GetName
)
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Graph
::*
)(
ge
::
AscendString
&
)
const
>
(
&
Graph
::
GetName
)
)
.
def
(
"set_need_iteration"
,
&
Graph
::
SetNeedIteration
);
.
def
(
"set_need_iteration"
,
&
Graph
::
SetNeedIteration
);
py
::
class_
<
Operator
>
(
*
m
,
"GEOperator"
)
py
::
class_
<
Operator
>
(
*
m
,
"GEOperator"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
const
std
::
string
&
>
())
.
def
(
py
::
init
<
const
char
*
>
())
.
def
(
py
::
init
<
const
std
::
string
&
,
const
std
::
string
&
>
())
.
def
(
py
::
init
<
const
char
*
,
const
char
*
>
())
.
def
(
"is_empty"
,
&
Operator
::
IsEmpty
)
.
def
(
"is_empty"
,
&
Operator
::
IsEmpty
)
.
def
(
"get_name"
,
&
Operator
::
GetName
)
.
def
(
"get_name"
,
.
def
(
"get_op_type"
,
&
Operator
::
GetOpType
)
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
ge
::
AscendString
&
)
const
>
(
&
Operator
::
GetName
))
.
def
(
"get_op_type"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
ge
::
AscendString
&
)
const
>
(
&
Operator
::
GetOpType
))
.
def
(
"set_input"
,
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Operator
&
))
&
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
))
&
Operator
::
SetInput
)
Operator
::
SetInput
)
.
def
(
"set_input"
,
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Operator
&
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
,
const
std
::
string
&
))
&
const
char
*
))
&
Operator
::
SetInput
)
Operator
::
SetInput
)
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
.
def
(
"set_input"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Operator
&
,
uint32_t
))
&
const
Operator
&
,
uint32_t
))
&
Operator
::
SetInput
)
Operator
::
SetInput
)
.
def
(
"add_control_input"
,
&
Operator
::
AddControlInput
)
.
def
(
"add_control_input"
,
&
Operator
::
AddControlInput
)
.
def
(
"get_input_const_data"
,
.
def
(
"get_input_const_data"
,
[](
Operator
&
op
,
const
std
::
string
&
dst_name
)
->
py
::
tuple
{
[](
Operator
&
op
,
const
char
*
dst_name
)
->
py
::
tuple
{
Tensor
data
;
Tensor
data
;
graphStatus
res
=
op
.
GetInputConstData
(
dst_name
,
data
);
graphStatus
res
=
op
.
GetInputConstData
(
dst_name
,
data
);
return
py
::
make_tuple
(
data
,
res
);
return
py
::
make_tuple
(
data
,
res
);
})
})
.
def
(
"get_input_desc"
,
.
def
(
"get_input_desc"
,
(
TensorDesc
(
Operator
::*
)(
const
std
::
string
&
)
const
)
&
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetInputDesc
)
Operator
::
GetInputDesc
)
.
def
(
"get_input_desc"
,
.
def
(
"get_input_desc"
,
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetInputDesc
)
[](
Operator
&
op
,
const
std
::
string
&
name
){
.
def
(
"get_dynamic_output_num"
,
&
Operator
::
GetDynamicOutputNum
)
return
op
.
GetInputDescByName
(
name
.
c_str
());
.
def
(
"get_dynamic_input_num"
,
&
Operator
::
GetDynamicInputNum
)
})
.
def
(
"get_dynamic_output_num"
,
static_cast
<
int
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetDynamicOutputNum
))
.
def
(
"get_dynamic_input_num"
,
static_cast
<
int
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetDynamicInputNum
))
.
def
(
"try_get_input_desc"
,
.
def
(
"try_get_input_desc"
,
[](
Operator
&
op
,
const
std
::
string
&
name
)
->
py
::
tuple
{
[](
Operator
&
op
,
const
char
*
name
)
->
py
::
tuple
{
TensorDesc
tensor_desc
;
TensorDesc
tensor_desc
;
graphStatus
status
=
op
.
TryGetInputDesc
(
name
,
tensor_desc
);
graphStatus
status
=
op
.
TryGetInputDesc
(
name
,
tensor_desc
);
return
py
::
make_tuple
(
tensor_desc
,
status
);
return
py
::
make_tuple
(
tensor_desc
,
status
);
})
})
.
def
(
"update_input_desc"
,
&
Operator
::
UpdateInputDesc
)
.
def
(
"update_input_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateInputDesc
))
.
def
(
"get_output_desc"
,
.
def
(
"get_output_desc"
,
(
TensorDesc
(
Operator
::*
)(
const
std
::
string
&
)
const
)
&
[](
Operator
&
op
,
const
std
::
string
&
name
)
{
Operator
::
GetOutputDesc
)
return
op
.
GetOutputDescByName
(
name
.
c_str
());
})
.
def
(
"get_output_desc"
,
.
def
(
"get_output_desc"
,
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetOutputDesc
)
(
TensorDesc
(
Operator
::*
)(
uint32_t
)
const
)
&
Operator
::
GetOutputDesc
)
.
def
(
"update_output_desc"
,
&
Operator
::
UpdateOutputDesc
)
.
def
(
"update_output_desc"
,
.
def
(
"get_dynamic_input_desc"
,
&
Operator
::
GetDynamicInputDesc
)
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateOutputDesc
))
.
def
(
"update_dynamic_input_desc"
,
&
Operator
::
UpdateDynamicInputDesc
)
.
def
(
"get_dynamic_input_desc"
,
.
def
(
"get_dynamic_output_desc"
,
&
Operator
::
GetDynamicOutputDesc
)
static_cast
<
ge
::
TensorDesc
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicInputDesc
))
.
def
(
"update_dynamic_output_desc"
,
&
Operator
::
UpdateDynamicOutputDesc
)
.
def
(
"update_dynamic_input_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateDynamicInputDesc
))
.
def
(
"get_dynamic_output_desc"
,
static_cast
<
ge
::
TensorDesc
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicOutputDesc
))
.
def
(
"update_dynamic_output_desc"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
,
const
TensorDesc
&
)
>
(
&
Operator
::
UpdateDynamicOutputDesc
))
.
def
(
"infer_shape_and_type"
,
&
Operator
::
InferShapeAndType
)
.
def
(
"infer_shape_and_type"
,
&
Operator
::
InferShapeAndType
)
.
def
(
"set_inference_context"
,
&
Operator
::
SetInferenceContext
)
.
def
(
"set_inference_context"
,
&
Operator
::
SetInferenceContext
)
.
def
(
"get_inference_context"
,
&
Operator
::
GetInferenceContext
)
.
def
(
"get_inference_context"
,
&
Operator
::
GetInferenceContext
)
.
def
(
"verify_all_attr"
,
&
Operator
::
VerifyAllAttr
)
.
def
(
"verify_all_attr"
,
&
Operator
::
VerifyAllAttr
)
.
def
(
"get_inputs_size"
,
&
Operator
::
GetInputsSize
)
.
def
(
"get_inputs_size"
,
&
Operator
::
GetInputsSize
)
.
def
(
"get_outputs_size"
,
&
Operator
::
GetOutputsSize
)
.
def
(
"get_outputs_size"
,
&
Operator
::
GetOutputsSize
)
.
def
(
"get_all_attr_names_and_types"
,
&
Operator
::
GetAllAttrNamesAndTypes
)
.
def
(
"get_all_attr_names_and_types"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
std
::
map
<
ge
::
AscendString
,
ge
::
AscendString
>&
)
const
>
(
&
Operator
::
GetAllAttrNamesAndTypes
))
.
def
(
"set_attr_int64"
,
.
def
(
"set_attr_int64"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
int64_t
value
)
->
Operator
&
{
int64_t
value
)
->
Operator
&
{
int64_t
tar
=
(
int64_t
)
value
;
int64_t
tar
=
(
int64_t
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_int32"
,
.
def
(
"set_attr_int32"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
int32_t
value
)
->
Operator
&
{
int32_t
value
)
->
Operator
&
{
int32_t
tar
=
(
int32_t
)
value
;
int32_t
tar
=
(
int32_t
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_uint32"
,
.
def
(
"set_attr_uint32"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
uint32_t
value
)
->
Operator
&
{
uint32_t
value
)
->
Operator
&
{
uint32_t
tar
=
(
uint32_t
)
value
;
uint32_t
tar
=
(
uint32_t
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_int64"
,
.
def
(
"set_attr_vec_int64"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
int64_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
int64_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
int64_t
>
tar
;
std
::
vector
<
int64_t
>
tar
;
...
@@ -368,7 +391,7 @@ void BindAscendGraph(py::module *m) {
...
@@ -368,7 +391,7 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_int32"
,
.
def
(
"set_attr_vec_int32"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
int32_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
int32_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
int32_t
>
tar
;
std
::
vector
<
int32_t
>
tar
;
...
@@ -380,7 +403,7 @@ void BindAscendGraph(py::module *m) {
...
@@ -380,7 +403,7 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_uint32"
,
.
def
(
"set_attr_vec_uint32"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
uint32_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
uint32_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
uint32_t
>
tar
;
std
::
vector
<
uint32_t
>
tar
;
...
@@ -392,21 +415,21 @@ void BindAscendGraph(py::module *m) {
...
@@ -392,21 +415,21 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_list_int64"
,
.
def
(
"set_attr_list_int64"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
std
::
initializer_list
<
int64_t
>
&
attrValue
)
->
Operator
&
{
std
::
initializer_list
<
int64_t
>
&
attrValue
)
->
Operator
&
{
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
})
})
.
def
(
"set_attr_attrvalue"
,
.
def
(
"set_attr_attrvalue"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
AttrValue
&
attrValue
)
[](
Operator
&
op
,
const
char
*
name
,
AttrValue
&
attrValue
)
->
Operator
&
{
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
})
->
Operator
&
{
return
op
.
SetAttr
(
name
,
std
::
move
(
attrValue
));
})
.
def
(
.
def
(
"set_attr_float"
,
"set_attr_float"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
float
value
)
->
Operator
&
{
[](
Operator
&
op
,
const
char
*
name
,
float
value
)
->
Operator
&
{
float
tar
=
static_cast
<
float
>
(
value
);
float
tar
=
static_cast
<
float
>
(
value
);
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_vec_float"
,
.
def
(
"set_attr_vec_float"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
float
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
float
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
float
>
tar
;
std
::
vector
<
float
>
tar
;
...
@@ -417,22 +440,22 @@ void BindAscendGraph(py::module *m) {
...
@@ -417,22 +440,22 @@ void BindAscendGraph(py::module *m) {
}
}
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_string"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
.
def
(
"set_attr_string"
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
std
::
string
&
))
&
const
char
*
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_string"
,
.
def
(
"set_attr_vec_string"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
std
::
s
tring
>
&
))
&
const
std
::
vector
<
ge
::
AscendS
tring
>
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_bool"
,
.
def
(
"set_attr_bool"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
bool
value
)
->
Operator
&
{
[](
Operator
&
op
,
const
char
*
name
,
bool
value
)
->
Operator
&
{
if
(
value
)
if
(
value
)
return
op
.
SetAttr
(
name
,
true
);
return
op
.
SetAttr
(
name
,
true
);
else
else
return
op
.
SetAttr
(
name
,
false
);
return
op
.
SetAttr
(
name
,
false
);
})
})
.
def
(
"set_attr_vec_bool"
,
.
def
(
"set_attr_vec_bool"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
bool
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
bool
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
bool
>
tar
;
std
::
vector
<
bool
>
tar
;
...
@@ -445,14 +468,14 @@ void BindAscendGraph(py::module *m) {
...
@@ -445,14 +468,14 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_tensor"
,
.
def
(
"set_attr_tensor"
,
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
Tensor
&
))
&
(
Operator
&
(
Operator
::*
)(
const
char
*
,
const
Tensor
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_tensor"
,
.
def
(
"set_attr_vec_tensor"
,
(
Operator
&
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
const
std
::
vector
<
Tensor
>
&
))
&
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
Tensor
>
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_uint8"
,
.
def
(
"set_attr_vec_uint8"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
uint8_t
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
uint8_t
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
uint8_t
>
tar
;
std
::
vector
<
uint8_t
>
tar
;
...
@@ -465,11 +488,11 @@ void BindAscendGraph(py::module *m) {
...
@@ -465,11 +488,11 @@ void BindAscendGraph(py::module *m) {
})
})
.
def
(
"set_attr_vec_vec_int64"
,
.
def
(
"set_attr_vec_vec_int64"
,
(
Operator
&
(
Operator
&
(
Operator
::*
)(
const
std
::
string
&
,
(
Operator
::*
)(
const
char
*
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
))
&
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
))
&
Operator
::
SetAttr
)
Operator
::
SetAttr
)
.
def
(
"set_attr_vec_dtype"
,
.
def
(
"set_attr_vec_dtype"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
std
::
vector
<
DataType
>
&
value
)
->
Operator
&
{
const
std
::
vector
<
DataType
>
&
value
)
->
Operator
&
{
int
len
=
value
.
size
();
int
len
=
value
.
size
();
std
::
vector
<
ge
::
DataType
>
tar
;
std
::
vector
<
ge
::
DataType
>
tar
;
...
@@ -481,14 +504,14 @@ void BindAscendGraph(py::module *m) {
...
@@ -481,14 +504,14 @@ void BindAscendGraph(py::module *m) {
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"set_attr_dtype"
,
.
def
(
"set_attr_dtype"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
const
DataType
&
value
)
->
Operator
&
{
const
DataType
&
value
)
->
Operator
&
{
ge
::
DataType
tar
=
(
ge
::
DataType
)
value
;
ge
::
DataType
tar
=
(
ge
::
DataType
)
value
;
return
op
.
SetAttr
(
name
,
tar
);
return
op
.
SetAttr
(
name
,
tar
);
})
})
.
def
(
"get_attr"
,
.
def
(
"get_attr"
,
[](
Operator
&
op
,
const
std
::
string
&
name
,
[](
Operator
&
op
,
const
char
*
name
,
AttrType
type
)
->
py
::
tuple
{
AttrType
type
)
->
py
::
tuple
{
graphStatus
res
=
-
1
;
graphStatus
res
=
-
1
;
switch
(
type
)
{
switch
(
type
)
{
...
@@ -538,12 +561,12 @@ void BindAscendGraph(py::module *m) {
...
@@ -538,12 +561,12 @@ void BindAscendGraph(py::module *m) {
return
py
::
make_tuple
(
o_av
,
res
);
return
py
::
make_tuple
(
o_av
,
res
);
}
break
;
}
break
;
case
AT_STRING
:
{
case
AT_STRING
:
{
std
::
s
tring
s_av
;
ge
::
AscendS
tring
s_av
;
res
=
op
.
GetAttr
(
name
,
s_av
);
res
=
op
.
GetAttr
(
name
,
s_av
);
return
py
::
make_tuple
(
s_av
,
res
);
return
py
::
make_tuple
(
s_av
,
res
);
}
break
;
}
break
;
case
AT_LIST_STRING
:
{
case
AT_LIST_STRING
:
{
std
::
vector
<
std
::
s
tring
>
v_s_av
;
std
::
vector
<
ge
::
AscendS
tring
>
v_s_av
;
res
=
op
.
GetAttr
(
name
,
v_s_av
);
res
=
op
.
GetAttr
(
name
,
v_s_av
);
return
py
::
make_tuple
(
v_s_av
,
res
);
return
py
::
make_tuple
(
v_s_av
,
res
);
}
break
;
}
break
;
...
@@ -594,11 +617,11 @@ void BindAscendGraph(py::module *m) {
...
@@ -594,11 +617,11 @@ void BindAscendGraph(py::module *m) {
})
})
.
def
(
"break_connect"
,
&
Operator
::
BreakConnect
)
.
def
(
"break_connect"
,
&
Operator
::
BreakConnect
)
.
def
(
"get_subgraph_names_count"
,
&
Operator
::
GetSubgraphNamesCount
)
.
def
(
"get_subgraph_names_count"
,
&
Operator
::
GetSubgraphNamesCount
)
.
def
(
"get_subgraph_names"
,
&
Operator
::
GetSubgraphNames
)
.
def
(
"get_subgraph_names"
,
static_cast
<
ge
::
graphStatus
(
ge
::
Operator
::*
)(
std
::
vector
<
ge
::
AscendString
>
&
)
const
>
(
&
Operator
::
GetSubgraphNames
)
)
.
def
(
"get_subgraph_builder"
,
&
Operator
::
GetSubgraphBuilder
)
.
def
(
"get_subgraph_builder"
,
static_cast
<
ge
::
SubgraphBuilder
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetSubgraphBuilder
)
)
.
def
(
"get_subgraph"
,
&
Operator
::
GetSubgraph
)
.
def
(
"get_subgraph"
,
static_cast
<
ge
::
Graph
(
ge
::
Operator
::*
)(
const
char
*
)
const
>
(
&
Operator
::
GetSubgraph
)
)
.
def
(
"get_dynamic_subgraph_builder"
,
&
Operator
::
GetDynamicSubgraphBuilder
)
.
def
(
"get_dynamic_subgraph_builder"
,
static_cast
<
ge
::
SubgraphBuilder
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicSubgraphBuilder
)
)
.
def
(
"get_dynamic_subgraph"
,
&
Operator
::
GetDynamicSubgraph
);
.
def
(
"get_dynamic_subgraph"
,
static_cast
<
ge
::
Graph
(
ge
::
Operator
::*
)(
const
char
*
,
uint32_t
)
const
>
(
&
Operator
::
GetDynamicSubgraph
)
);
py
::
class_
<
Tensor
>
(
*
m
,
"GETensor"
)
py
::
class_
<
Tensor
>
(
*
m
,
"GETensor"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
...
@@ -614,9 +637,9 @@ void BindAscendGraph(py::module *m) {
...
@@ -614,9 +637,9 @@ void BindAscendGraph(py::module *m) {
.
def
(
"set_data"
,
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
uint8_t
*
,
size_t
))
&
Tensor
::
SetData
)
(
graphStatus
(
Tensor
::*
)(
const
uint8_t
*
,
size_t
))
&
Tensor
::
SetData
)
.
def
(
"set_data"
,
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
std
::
string
&
))
&
Tensor
::
SetData
)
(
graphStatus
(
Tensor
::*
)(
const
char
*
))
&
Tensor
::
SetData
)
.
def
(
"set_data"
,
.
def
(
"set_data"
,
(
graphStatus
(
Tensor
::*
)(
const
std
::
vector
<
std
::
s
tring
>
&
))
&
(
graphStatus
(
Tensor
::*
)(
const
std
::
vector
<
ge
::
AscendS
tring
>
&
))
&
Tensor
::
SetData
)
Tensor
::
SetData
)
.
def
(
"get_data"
,
.
def
(
"get_data"
,
...
@@ -639,7 +662,7 @@ void BindAscendGraph(py::module *m) {
...
@@ -639,7 +662,7 @@ void BindAscendGraph(py::module *m) {
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
.
def
(
py
::
init
<
const
TensorDesc
&>
())
.
def
(
py
::
init
<
const
TensorDesc
&>
())
.
def
(
"update"
,
.
def
(
"update"
,
(
void
(
TensorDesc
::*
)(
Shape
,
Format
,
DataType
))
&
TensorDesc
::
Update
,
(
void
(
TensorDesc
::*
)(
const
Shape
&
,
Format
,
DataType
))
&
TensorDesc
::
Update
,
py
::
arg
(
"shape"
),
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"shape"
),
py
::
arg
(
"format"
)
=
FORMAT_ND
,
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
py
::
arg
(
"dt"
)
=
DT_FLOAT
)
.
def
(
"set_shape"
,
&
TensorDesc
::
SetShape
)
.
def
(
"set_shape"
,
&
TensorDesc
::
SetShape
)
...
@@ -660,8 +683,8 @@ void BindAscendGraph(py::module *m) {
...
@@ -660,8 +683,8 @@ void BindAscendGraph(py::module *m) {
.
def
(
"get_origin_format"
,
&
TensorDesc
::
GetOriginFormat
)
.
def
(
"get_origin_format"
,
&
TensorDesc
::
GetOriginFormat
)
.
def
(
"set_data_type"
,
&
TensorDesc
::
SetDataType
)
.
def
(
"set_data_type"
,
&
TensorDesc
::
SetDataType
)
.
def
(
"get_data_type"
,
&
TensorDesc
::
GetDataType
)
.
def
(
"get_data_type"
,
&
TensorDesc
::
GetDataType
)
.
def
(
"set_name"
,
&
TensorDesc
::
SetName
)
.
def
(
"set_name"
,
static_cast
<
void
(
ge
::
TensorDesc
::*
)(
const
char
*
)
>
(
&
TensorDesc
::
SetName
)
)
.
def
(
"get_name"
,
&
TensorDesc
::
GetName
)
.
def
(
"get_name"
,
static_cast
<
ge
::
graphStatus
(
ge
::
TensorDesc
::*
)(
ge
::
AscendString
&
)
>
(
&
TensorDesc
::
GetName
)
)
.
def
(
"set_size"
,
&
TensorDesc
::
SetSize
)
.
def
(
"set_size"
,
&
TensorDesc
::
SetSize
)
.
def
(
"get_size"
,
&
TensorDesc
::
GetSize
)
.
def
(
"get_size"
,
&
TensorDesc
::
GetSize
)
.
def
(
"set_real_dim_cnt"
,
&
TensorDesc
::
SetRealDimCnt
)
.
def
(
"set_real_dim_cnt"
,
&
TensorDesc
::
SetRealDimCnt
)
...
@@ -679,14 +702,16 @@ void BindAscendGraph(py::module *m) {
...
@@ -679,14 +702,16 @@ void BindAscendGraph(py::module *m) {
py
::
class_
<
AttrValue
>
(
*
m
,
"GEAttrValue"
).
def
(
py
::
init
<>
());
py
::
class_
<
AttrValue
>
(
*
m
,
"GEAttrValue"
).
def
(
py
::
init
<>
());
py
::
class_
<
OperatorFactory
>
(
*
m
,
"GEOperatorFactory"
)
py
::
class_
<
OperatorFactory
>
(
*
m
,
"GEOperatorFactory"
)
.
def
(
"create_operator"
,
&
OperatorFactory
::
CreateOperator
)
.
def_static
(
"create_operator"
,
static_cast
<
ge
::
Operator
(
*
)(
const
char
*
,
const
char
*
)
>
(
&
ge
::
OperatorFactory
::
CreateOperator
))
.
def
(
"get_ops_type_list"
,
.
def
(
"get_ops_type_list"
,
[]()
->
py
::
tuple
{
[]()
->
py
::
tuple
{
std
::
vector
<
std
::
s
tring
>
all_ops
;
std
::
vector
<
ge
::
AscendS
tring
>
all_ops
;
graphStatus
status
=
OperatorFactory
::
GetOpsTypeList
(
all_ops
);
graphStatus
status
=
OperatorFactory
::
GetOpsTypeList
(
all_ops
);
return
py
::
make_tuple
(
all_ops
,
status
);
return
py
::
make_tuple
(
all_ops
,
status
);
})
})
.
def
(
"is_exist_op"
,
&
OperatorFactory
::
IsExistOp
);
.
def_static
(
"is_exist_op"
,
static_cast
<
bool
(
*
)(
const
char
*
)
>
(
&
OperatorFactory
::
IsExistOp
));
}
}
}
// end namespace pybind
}
// end namespace pybind
...
...
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
1882f2ce
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <fstream>
#include <fstream>
#include <iostream>
#include <iostream>
#include <string>
#include <string>
#include <unistd.h>
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -23,6 +24,9 @@
...
@@ -23,6 +24,9 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_ASCEND
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// determined by the OP`s proto automatically, i.e., all the inputs registered
// determined by the OP`s proto automatically, i.e., all the inputs registered
...
@@ -444,6 +448,11 @@ int main(int argc, char* argv[]) {
...
@@ -444,6 +448,11 @@ int main(int argc, char* argv[]) {
return
-
1
;
return
-
1
;
}
}
#ifdef PADDLE_WITH_ASCEND
auto
ascend_ptr
=
paddle
::
framework
::
AscendInstance
::
GetInstance
();
ascend_ptr
->
InitGEForUT
();
#endif
std
::
vector
<
std
::
string
>
headers
{
"
\"
paddle/fluid/imperative/tracer.h
\"
"
};
std
::
vector
<
std
::
string
>
headers
{
"
\"
paddle/fluid/imperative/tracer.h
\"
"
};
std
::
ofstream
out
(
argv
[
1
],
std
::
ios
::
out
);
std
::
ofstream
out
(
argv
[
1
],
std
::
ios
::
out
);
...
@@ -473,5 +482,9 @@ int main(int argc, char* argv[]) {
...
@@ -473,5 +482,9 @@ int main(int argc, char* argv[]) {
<<
"} // namespace paddle
\n
"
;
<<
"} // namespace paddle
\n
"
;
out
.
close
();
out
.
close
();
#ifdef PADDLE_WITH_ASCEND
ge
::
GEFinalize
();
#endif
return
0
;
return
0
;
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录