未验证 提交 1882f2ce 编写于 作者: G gongweibao 提交者: GitHub

Fix compilcation on CANN20.1 and older (#30494)

Fix compilcation on CANN20.1 and older 
上级 6dd52c5b
...@@ -326,7 +326,6 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build") ...@@ -326,7 +326,6 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
if(ON_INFER) if(ON_INFER)
# you can trun off the paddle fluid and inference lib by set ON_INFER=OFF # you can trun off the paddle fluid and inference lib by set ON_INFER=OFF
......
...@@ -12,50 +12,38 @@ ...@@ -12,50 +12,38 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
INCLUDE(ExternalProject)
SET(ASCEND_PROJECT "extern_ascend") #NOTE: Logic is from
IF((NOT DEFINED ASCEND_VER) OR (NOT DEFINED ASCEND_URL)) # https://github.com/mindspore-ai/graphengine/blob/master/CMakeLists.txt
MESSAGE(STATUS "use pre defined download url") if(DEFINED ENV{ASCEND_CUSTOM_PATH})
SET(ASCEND_VER "0.1.1" CACHE STRING "" FORCE) set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
SET(ASCEND_NAME "ascend" CACHE STRING "" FORCE) else()
SET(ASCEND_URL "http://paddle-ascend.bj.bcebos.com/ascend.tar.gz" CACHE STRING "" FORCE) set(ASCEND_DIR /usr/local/Ascend)
ENDIF() endif()
MESSAGE(STATUS "ASCEND_NAME: ${ASCEND_NAME}, ASCEND_URL: ${ASCEND_URL}")
SET(ASCEND_SOURCE_DIR "${THIRD_PARTY_PATH}/ascend")
SET(ASCEND_DOWNLOAD_DIR "${ASCEND_SOURCE_DIR}/src/${ASCEND_PROJECT}")
SET(ASCEND_DST_DIR "ascend")
SET(ASCEND_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(ASCEND_INSTALL_DIR ${ASCEND_INSTALL_ROOT}/${ASCEND_DST_DIR})
SET(ASCEND_ROOT ${ASCEND_INSTALL_DIR})
SET(ASCEND_INC_DIR ${ASCEND_ROOT}/include)
SET(ASCEND_LIB_DIR ${ASCEND_ROOT}/lib)
SET(ASCEND_LIB ${ASCEND_LIB_DIR}/libge_runner.so)
SET(ASCEND_GRAPH_LIB ${ASCEND_LIB_DIR}/libgraph.so)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${ASCEND_ROOT}/lib")
INCLUDE_DIRECTORIES(${ASCEND_INC_DIR}) set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
FILE(WRITE ${ASCEND_DOWNLOAD_DIR}/CMakeLists.txt set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
"PROJECT(ASCEND)\n" set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
"cmake_minimum_required(VERSION 3.0)\n" set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
"install(DIRECTORY ${ASCEND_NAME}/include ${ASCEND_NAME}/lib \n" set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
" DESTINATION ${ASCEND_DST_DIR})\n") set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
ExternalProject_Add( set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
${ASCEND_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${ASCEND_SOURCE_DIR}
DOWNLOAD_DIR ${ASCEND_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${ASCEND_URL} -c -q -O ${ASCEND_NAME}.tar.gz
&& tar zxvf ${ASCEND_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${ASCEND_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ASCEND_INSTALL_ROOT}
)
ADD_LIBRARY(ascend SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend PROPERTY IMPORTED_LOCATION ${ASCEND_LIB})
ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL) set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${ASCEND_GRAPH_LIB}) set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
ADD_DEPENDENCIES(ascend ascend_graph ${ASCEND_PROJECT}) set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})
set(atlas_graph ${ATLAS_RUNTIME_DIR}/libgraph.so)
set(atlas_ge_runner ${ATLAS_RUNTIME_DIR}/libge_runner.so)
INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner})
ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph})
add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph)
...@@ -17,7 +17,7 @@ INCLUDE(ExternalProject) ...@@ -17,7 +17,7 @@ INCLUDE(ExternalProject)
SET(CRYPTOPP_PREFIX_DIR ${THIRD_PARTY_PATH}/cryptopp) SET(CRYPTOPP_PREFIX_DIR ${THIRD_PARTY_PATH}/cryptopp)
SET(CRYPTOPP_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cryptopp) SET(CRYPTOPP_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cryptopp)
SET(CRYPTOPP_INCLUDE_DIR "${CRYPTOPP_INSTALL_DIR}/include" CACHE PATH "cryptopp include directory." FORCE) SET(CRYPTOPP_INCLUDE_DIR "${CRYPTOPP_INSTALL_DIR}/include" CACHE PATH "cryptopp include directory." FORCE)
SET(CRYPTOPP_REPOSITORY https://gitee.com/tianjianhe/cryptopp.git) SET(CRYPTOPP_REPOSITORY ${GIT_URL}/weidai11/cryptopp.git)
SET(CRYPTOPP_TAG CRYPTOPP_8_2_0) SET(CRYPTOPP_TAG CRYPTOPP_8_2_0)
IF(WIN32) IF(WIN32)
...@@ -33,7 +33,7 @@ set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS} ...@@ -33,7 +33,7 @@ set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS}
-DCMAKE_INSTALL_LIBDIR=${CRYPTOPP_INSTALL_DIR}/lib -DCMAKE_INSTALL_LIBDIR=${CRYPTOPP_INSTALL_DIR}/lib
-DCMAKE_INSTALL_PREFIX=${CRYPTOPP_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${CRYPTOPP_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -17,7 +17,7 @@ include(ExternalProject) ...@@ -17,7 +17,7 @@ include(ExternalProject)
set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack) set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack)
set(DLPACK_SOURCE_DIR ${THIRD_PARTY_PATH}/dlpack/src/extern_dlpack) set(DLPACK_SOURCE_DIR ${THIRD_PARTY_PATH}/dlpack/src/extern_dlpack)
set(DLPACK_REPOSITORY https://gitee.com/tianjianhe/dlpack.git) set(DLPACK_REPOSITORY ${GIT_URL}/dmlc/dlpack.git)
set(DLPACK_TAG v0.2) set(DLPACK_TAG v0.2)
cache_third_party(extern_dlpack cache_third_party(extern_dlpack
......
...@@ -18,8 +18,8 @@ SET(GFLAGS_PREFIX_DIR ${THIRD_PARTY_PATH}/gflags) ...@@ -18,8 +18,8 @@ SET(GFLAGS_PREFIX_DIR ${THIRD_PARTY_PATH}/gflags)
SET(GFLAGS_SOURCE_DIR ${THIRD_PARTY_PATH}/gflags/src/extern_gflags) SET(GFLAGS_SOURCE_DIR ${THIRD_PARTY_PATH}/gflags/src/extern_gflags)
SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags)
SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE) SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE)
set(GFLAGS_REPOSITORY https://gitee.com/tianjianhe/gflags.git) set(GFLAGS_REPOSITORY ${GIT_URL}/gflags/gflags.git)
set(GFLAGS_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a) set(GFLAGS_TAG "v2.2.2")
IF(WIN32) IF(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ELSE(WIN32) ELSE(WIN32)
...@@ -48,7 +48,7 @@ ExternalProject_Add( ...@@ -48,7 +48,7 @@ ExternalProject_Add(
INSTALL_COMMAND ${INSTALL_COMMAND} INSTALL_COMMAND ${INSTALL_COMMAND}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
......
...@@ -18,8 +18,8 @@ SET(GLOG_PREFIX_DIR ${THIRD_PARTY_PATH}/glog) ...@@ -18,8 +18,8 @@ SET(GLOG_PREFIX_DIR ${THIRD_PARTY_PATH}/glog)
SET(GLOG_SOURCE_DIR ${THIRD_PARTY_PATH}/glog/src/extern_glog) SET(GLOG_SOURCE_DIR ${THIRD_PARTY_PATH}/glog/src/extern_glog)
SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog) SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog)
SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE) SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE)
SET(GLOG_REPOSITORY https://gitee.com/tianjianhe/glog.git) SET(GLOG_REPOSITORY ${GIT_URL}/google/glog.git)
SET(GLOG_TAG v0.3.5) SET(GLOG_TAG v0.4.0)
IF(WIN32) IF(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/glog.lib" CACHE FILEPATH "glog library." FORCE) SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/glog.lib" CACHE FILEPATH "glog library." FORCE)
...@@ -47,7 +47,7 @@ ExternalProject_Add( ...@@ -47,7 +47,7 @@ ExternalProject_Add(
SOURCE_DIR ${GLOG_SOURCE_DIR} SOURCE_DIR ${GLOG_SOURCE_DIR}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
"-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
......
...@@ -28,7 +28,7 @@ IF(APPLE) ...@@ -28,7 +28,7 @@ IF(APPLE)
SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install) SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install)
ELSE() ELSE()
SET(GRPC_CFLAGS "-Wno-error -std=c11 ${CLFAGS}") SET(GRPC_CFLAGS "-Wno-error -std=c11 ${CLFAGS}")
SET(GRPC_CXXFLAGS "-Wno-error -std=c++11 ${CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") SET(GRPC_CXXFLAGS "-Wno-error -std=c++11 ${CXXFLAGS}")
SET(BUILD_CMD make CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS} HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin) SET(BUILD_CMD make CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS} HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin)
SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS}) SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS})
ENDIF() ENDIF()
......
...@@ -17,7 +17,7 @@ INCLUDE(ExternalProject) ...@@ -17,7 +17,7 @@ INCLUDE(ExternalProject)
SET(CBLAS_PREFIX_DIR ${THIRD_PARTY_PATH}/openblas) SET(CBLAS_PREFIX_DIR ${THIRD_PARTY_PATH}/openblas)
SET(CBLAS_SOURCE_DIR ${THIRD_PARTY_PATH}/openblas/src/extern_openblas) SET(CBLAS_SOURCE_DIR ${THIRD_PARTY_PATH}/openblas/src/extern_openblas)
SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_REPOSITORY https://gitee.com/tianjianhe/OpenBLAS.git) SET(CBLAS_REPOSITORY ${GIT_URL}/xianyi/OpenBLAS.git)
SET(CBLAS_TAG v0.3.7) SET(CBLAS_TAG v0.3.7)
if(WITH_MIPS) if(WITH_MIPS)
SET(CBLAS_TAG v0.3.13) SET(CBLAS_TAG v0.3.13)
......
...@@ -183,7 +183,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -183,7 +183,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}" "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0" "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-Dprotobuf_WITH_ZLIB=ON" "-Dprotobuf_WITH_ZLIB=ON"
...@@ -198,8 +198,8 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -198,8 +198,8 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}") "-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}")
ENDIF() ENDIF()
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git) SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
SET(PROTOBUF_TAG v3.8.0) SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
cache_third_party(${TARGET_NAME} cache_third_party(${TARGET_NAME}
REPOSITORY ${PROTOBUF_REPOSITORY} REPOSITORY ${PROTOBUF_REPOSITORY}
...@@ -234,7 +234,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -234,7 +234,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
) )
ENDFUNCTION() ENDFUNCTION()
# SET(PROTOBUF_VERSION 3.1.0) SET(PROTOBUF_VERSION 3.1.0)
IF(NOT PROTOBUF_FOUND) IF(NOT PROTOBUF_FOUND)
build_protobuf(extern_protobuf FALSE) build_protobuf(extern_protobuf FALSE)
......
...@@ -16,8 +16,8 @@ include(ExternalProject) ...@@ -16,8 +16,8 @@ include(ExternalProject)
set(PYBIND_PREFIX_DIR ${THIRD_PARTY_PATH}/pybind) set(PYBIND_PREFIX_DIR ${THIRD_PARTY_PATH}/pybind)
set(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind/src/extern_pybind) set(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind/src/extern_pybind)
SET(PYBIND_REPOSITORY https://gitee.com/tianjianhe/pybind11.git) SET(PYBIND_REPOSITORY ${GIT_URL}/pybind/pybind11.git)
SET(PYBIND_TAG v2.6.0) SET(PYBIND_TAG v2.4.3)
cache_third_party(extern_pybind cache_third_party(extern_pybind
REPOSITORY ${PYBIND_REPOSITORY} REPOSITORY ${PYBIND_REPOSITORY}
......
...@@ -19,7 +19,7 @@ SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) ...@@ -19,7 +19,7 @@ SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY https://gitee.com/tianjianhe/warp-ctc.git) set(WARPCTC_REPOSITORY https://gitee.com/tianjianhe/warp-ctc.git)
set(WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8) set(WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8)
# set(WARPCTC_TAG bc29dcfff07ced1c7a19a4ecee48e5ad583cef8e) set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE) CACHE PATH "Warp-ctc Directory" FORCE)
...@@ -53,7 +53,7 @@ ExternalProject_Add( ...@@ -53,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0" "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
......
...@@ -19,7 +19,7 @@ set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak) ...@@ -19,7 +19,7 @@ set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak)
SET(XBYAK_SOURCE_DIR ${THIRD_PARTY_PATH}/xbyak/src/extern_xbyak) SET(XBYAK_SOURCE_DIR ${THIRD_PARTY_PATH}/xbyak/src/extern_xbyak)
set(XBYAK_INSTALL_ROOT ${THIRD_PARTY_PATH}/install/xbyak) set(XBYAK_INSTALL_ROOT ${THIRD_PARTY_PATH}/install/xbyak)
set(XBYAK_INC_DIR ${XBYAK_INSTALL_ROOT}/include) set(XBYAK_INC_DIR ${XBYAK_INSTALL_ROOT}/include)
set(XBYAK_REPOSITORY https://gitee.com/tianjianhe/xbyak.git) set(XBYAK_REPOSITORY ${GIT_URL}/herumi/xbyak.git)
set(XBYAK_TAG v5.661) # Jul 26th set(XBYAK_TAG v5.661) # Jul 26th
include_directories(${XBYAK_INC_DIR}) include_directories(${XBYAK_INC_DIR})
......
...@@ -18,7 +18,7 @@ set(XXHASH_PREFIX_DIR ${THIRD_PARTY_PATH}/xxhash) ...@@ -18,7 +18,7 @@ set(XXHASH_PREFIX_DIR ${THIRD_PARTY_PATH}/xxhash)
set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash/src/extern_xxhash) set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash/src/extern_xxhash)
set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash) set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash)
set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include") set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include")
set(XXHASH_REPOSITORY https://gitee.com/tianjianhe/xxHash.git) set(XXHASH_REPOSITORY ${GIT_URL}/Cyan4973/xxHash.git)
set(XXHASH_TAG v0.6.5) set(XXHASH_TAG v0.6.5)
cache_third_party(extern_xxhash cache_third_party(extern_xxhash
......
...@@ -19,7 +19,7 @@ SET(ZLIB_SOURCE_DIR ${THIRD_PARTY_PATH}/zlib/src/extern_zlib) ...@@ -19,7 +19,7 @@ SET(ZLIB_SOURCE_DIR ${THIRD_PARTY_PATH}/zlib/src/extern_zlib)
SET(ZLIB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/zlib) SET(ZLIB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/zlib)
SET(ZLIB_ROOT ${ZLIB_INSTALL_DIR} CACHE FILEPATH "zlib root directory." FORCE) SET(ZLIB_ROOT ${ZLIB_INSTALL_DIR} CACHE FILEPATH "zlib root directory." FORCE)
SET(ZLIB_INCLUDE_DIR "${ZLIB_INSTALL_DIR}/include" CACHE PATH "zlib include directory." FORCE) SET(ZLIB_INCLUDE_DIR "${ZLIB_INSTALL_DIR}/include" CACHE PATH "zlib include directory." FORCE)
set(ZLIB_REPOSITORY https://gitee.com/tianjianhe/zlib.git) set(ZLIB_REPOSITORY ${GIT_URL}/madler/zlib.git)
set(ZLIB_TAG v1.2.8) set(ZLIB_TAG v1.2.8)
INCLUDE_DIRECTORIES(${ZLIB_INCLUDE_DIR}) # For zlib code to include its own headers. INCLUDE_DIRECTORIES(${ZLIB_INCLUDE_DIR}) # For zlib code to include its own headers.
...@@ -41,7 +41,7 @@ ExternalProject_Add( ...@@ -41,7 +41,7 @@ ExternalProject_Add(
CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_INSTALL_PREFIX=${ZLIB_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${ZLIB_INSTALL_DIR}
-DBUILD_SHARED_LIBS=OFF -DBUILD_SHARED_LIBS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
......
...@@ -33,5 +33,5 @@ cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_conte ...@@ -33,5 +33,5 @@ cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_conte
cc_test(test_fleet_cc SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell) cc_test(test_fleet_cc SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
if(WITH_ASCEND) if(WITH_ASCEND)
cc_library(ascend_wrapper SRCS ascend_wrapper.cc DEPS framework_proto lod_tensor ascend ascend_graph) cc_library(ascend_wrapper SRCS ascend_wrapper.cc DEPS framework_proto lod_tensor ascend_ge ascend_graph)
endif(WITH_ASCEND) endif(WITH_ASCEND)
...@@ -37,7 +37,6 @@ limitations under the License. */ ...@@ -37,7 +37,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// typedef std::vector<std::string> AscendGraphDesc;
typedef ge::Graph AscendGraphDesc; typedef ge::Graph AscendGraphDesc;
class AscendInstance { class AscendInstance {
...@@ -45,17 +44,31 @@ class AscendInstance { ...@@ -45,17 +44,31 @@ class AscendInstance {
virtual ~AscendInstance() {} virtual ~AscendInstance() {}
AscendInstance() {} AscendInstance() {}
std::map<std::string, std::string> GetDefaultInitSessionOptions() { std::map<ge::AscendString, ge::AscendString> GetDefaultInitOptions() {
std::map<std::string, std::string> init_options; std::map<ge::AscendString, ge::AscendString> init_options;
init_options["a"] = "b"; init_options["ge.exec.deviceId"] = "0";
init_options["ge.trainFlag"] = "1"; init_options["ge.graphRunMode"] = "1";
return init_options; return init_options;
}
std::map<ge::AscendString, ge::AscendString> GetDefaultInitSessionOptions() {
std::map<ge::AscendString, ge::AscendString> init_options;
init_options["a"] = "b";
init_options["ge.trainFlag"] = "1";
return init_options;
}
ge::Status InitGEForUT(){
return ge::GEInitialize(GetDefaultInitOptions());
} }
// add other parameters here to init
void InitGlobalResouces() { void InitGlobalResouces() {
session_.reset(new ge::Session(GetDefaultInitSessionOptions())); LOG(INFO) << "Begin InitGlobalResouces";
VLOG(1) << "InitGlobalResouces Done"; session_.reset(new ge::Session(GetDefaultInitSessionOptions()));
if (session_ == nullptr){
LOG(FATAL) << "new session error:" << session_;
}
LOG(INFO) << "End InitGlobalResouces";
} }
static std::shared_ptr<AscendInstance> GetInstance() { static std::shared_ptr<AscendInstance> GetInstance() {
......
...@@ -33,6 +33,7 @@ limitations under the License. */ ...@@ -33,6 +33,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/fleet/ascend_wrapper.h" #include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h" #include "paddle/fluid/pybind/ascend_wrapper_py.h"
#include "paddle/fluid/platform/enforce.h"
using namespace ge; // NOLINT using namespace ge; // NOLINT
namespace py = pybind11; namespace py = pybind11;
...@@ -51,9 +52,22 @@ void BindAscendWrapper(py::module *m) { ...@@ -51,9 +52,22 @@ void BindAscendWrapper(py::module *m) {
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} // end AscendWrapper } // end AscendWrapper
Status ge_initialize(std::map<std::string, std::string> &options) { // NOLINT std::map<ge::AscendString, ge::AscendString> convert_map(const std::map<std::string, std::string>& options){
std::map<ge::AscendString, ge::AscendString> rets;
for (auto &option : options) {
ge::AscendString key = option.first.c_str();
ge::AscendString val = option.second.c_str();
rets[key] = val;
}
return rets;
}
ge::Status ge_initialize(std::map<std::string, std::string> &options) { // NOLINT
py::gil_scoped_release release; py::gil_scoped_release release;
Status res = GEInitialize(options); auto init_options=convert_map(options);
ge::Status res = ge::GEInitialize(init_options);
PADDLE_ENFORCE_EQ(res,
ge::SUCCESS, platform::errors::Fatal("ge init error:%d", res));
py::gil_scoped_acquire acquire; py::gil_scoped_acquire acquire;
return res; return res;
} }
...@@ -214,36 +228,34 @@ void BindAscendGraph(py::module *m) { ...@@ -214,36 +228,34 @@ void BindAscendGraph(py::module *m) {
// 类封装 // 类封装
py::class_<Session>(*m, "GESession") py::class_<Session>(*m, "GESession")
.def(py::init<const std::map<std::string, std::string> &>()) .def(py::init([](const std::map<std::string, std::string> & options) {
return std::unique_ptr<ge::Session>(new ge::Session(convert_map(options)));
}))
.def("add_graph", .def("add_graph",
(Status (Session::*)(uint32_t, const Graph &)) & Session::AddGraph) (ge::Status (Session::*)(uint32_t, const Graph &)) & Session::AddGraph)
.def("add_graph", .def("add_graph",
(Status (Session::*)(uint32_t, const Graph &, [](Session& ss, uint32_t index, const Graph & graph,
const std::map<std::string, std::string> &)) & const std::map<std::string, std::string> &options){
Session::AddGraph) return ss.AddGraph(index, graph, convert_map(options));
})
.def("remove_graph", &Session::RemoveGraph) .def("remove_graph", &Session::RemoveGraph)
.def("run_graph", .def("run_graph",
[](Session &ss, uint32_t graphId, [](Session &ss, uint32_t graphId,
const std::vector<Tensor> &inputs) -> py::tuple { const std::vector<Tensor> &inputs) -> py::tuple {
std::vector<Tensor> outputs; std::vector<Tensor> outputs;
Status res = ss.RunGraph(graphId, inputs, outputs); ge::Status res = ss.RunGraph(graphId, inputs, outputs);
return py::make_tuple(outputs, res); return py::make_tuple(outputs, res);
}, },
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("build_graph", &Session::BuildGraph) .def("build_graph", &Session::BuildGraph)
.def("run_graph_async", &Session::RunGraphAsync) .def("run_graph_async", &Session::RunGraphAsync)
.def("register_call_back_func", .def("register_call_back_func",
(Status (Session::*)( // NOLINT static_cast<ge::Status (ge::Session::*)(const char*, const ge::session::pCallBackFunc&)>(&ge::Session::RegisterCallBackFunc))
const std::string &,
std::function<uint32_t(
uint32_t graph_id,
const std::map<std::string, ge::Tensor> &params_list)>)) &
Session::RegisterCallBackFunc)
.def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild); .def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild);
py::class_<Graph>(*m, "GEGraph") py::class_<Graph>(*m, "GEGraph")
.def(py::init<>()) .def(py::init<>())
.def(py::init<const std::string &>()) .def(py::init<const char *>())
.def("set_inputs", &Graph::SetInputs) .def("set_inputs", &Graph::SetInputs)
.def("set_outputs", (Graph & (Graph::*)(const std::vector<Operator> &)) & .def("set_outputs", (Graph & (Graph::*)(const std::vector<Operator> &)) &
Graph::SetOutputs) Graph::SetOutputs)
...@@ -253,110 +265,121 @@ void BindAscendGraph(py::module *m) { ...@@ -253,110 +265,121 @@ void BindAscendGraph(py::module *m) {
Graph::SetOutputs) Graph::SetOutputs)
.def("set_outputs", .def("set_outputs",
(Graph & (Graph &
(Graph::*)(const std::vector<std::pair<ge::Operator, std::string>> (Graph::*)(const std::vector<std::pair<ge::Operator, ge::AscendString>>
&)) & &)) &
Graph::SetOutputs) Graph::SetOutputs)
.def("set_targets", &Graph::SetTargets) .def("set_targets", &Graph::SetTargets)
.def("is_valid", &Graph::IsValid) .def("is_valid", &Graph::IsValid)
.def("add_op", &Graph::AddOp) .def("add_op", &Graph::AddOp)
.def("find_op_by_name", .def("find_op_by_name",
[](Graph &graph, const std::string &name) -> py::tuple { [](Graph &graph, const char* name) -> py::tuple {
ge::Operator op; ge::Operator op;
graphStatus status = graph.FindOpByName(name, op); graphStatus status = graph.FindOpByName(name, op);
return py::make_tuple(op, status); return py::make_tuple(op, status);
}) })
.def("find_op_by_type", .def("find_op_by_type",
[](Graph &graph, const std::string &type) -> py::tuple { [](Graph &graph, const char * type) -> py::tuple {
std::vector<ge::Operator> ops; std::vector<ge::Operator> ops;
graphStatus status = graph.FindOpByType(type, ops); graphStatus status = graph.FindOpByType(type, ops);
return py::make_tuple(ops, status); return py::make_tuple(ops, status);
}) })
.def("get_all_op_name", .def("get_all_op_name",
[](Graph &graph) -> py::tuple { [](Graph &graph) -> py::tuple {
std::vector<std::string> op_name; std::vector<ge::AscendString> op_name;
graphStatus status = graph.GetAllOpName(op_name); graphStatus status = graph.GetAllOpName(op_name);
return py::make_tuple(op_name, status); return py::make_tuple(op_name, status);
}) })
.def("save_to_file", &Graph::SaveToFile) .def("save_to_file", static_cast<ge::graphStatus (ge::Graph::*)(const char *) const>(&ge::Graph::SaveToFile))
.def("load_from_file", &Graph::LoadFromFile) .def("load_from_file", static_cast<ge::graphStatus (ge::Graph::*)(const char*)>(&Graph::LoadFromFile))
.def("get_name", &Graph::GetName) .def("get_name", static_cast<ge::graphStatus (ge::Graph::*)(ge::AscendString&) const>(&Graph::GetName))
.def("set_need_iteration", &Graph::SetNeedIteration); .def("set_need_iteration", &Graph::SetNeedIteration);
py::class_<Operator>(*m, "GEOperator") py::class_<Operator>(*m, "GEOperator")
.def(py::init<>()) .def(py::init<>())
.def(py::init<const std::string &>()) .def(py::init<const char *>())
.def(py::init<const std::string &, const std::string &>()) .def(py::init<const char*, const char *>())
.def("is_empty", &Operator::IsEmpty) .def("is_empty", &Operator::IsEmpty)
.def("get_name", &Operator::GetName) .def("get_name",
.def("get_op_type", &Operator::GetOpType) static_cast<ge::graphStatus (ge::Operator::*)(ge::AscendString&) const>(&Operator::GetName))
.def("get_op_type",
static_cast<ge::graphStatus (ge::Operator::*)(ge::AscendString&) const>(&Operator::GetOpType))
.def("set_input", .def("set_input",
(Operator & (Operator::*)(const std::string &, const Operator &)) & (Operator & (Operator::*)(const char*, const Operator &)) &
Operator::SetInput) Operator::SetInput)
.def("set_input", .def("set_input",
(Operator & (Operator::*)(const std::string &, const Operator &, (Operator & (Operator::*)(const char *, const Operator &,
const std::string &)) & const char *)) &
Operator::SetInput) Operator::SetInput)
.def("set_input", (Operator & (Operator::*)(const std::string &, .def("set_input", (Operator & (Operator::*)(const char *,
const Operator &, uint32_t)) & const Operator &, uint32_t)) &
Operator::SetInput) Operator::SetInput)
.def("add_control_input", &Operator::AddControlInput) .def("add_control_input", &Operator::AddControlInput)
.def("get_input_const_data", .def("get_input_const_data",
[](Operator &op, const std::string &dst_name) -> py::tuple { [](Operator &op, const char* dst_name) -> py::tuple {
Tensor data; Tensor data;
graphStatus res = op.GetInputConstData(dst_name, data); graphStatus res = op.GetInputConstData(dst_name, data);
return py::make_tuple(data, res); return py::make_tuple(data, res);
}) })
.def("get_input_desc", .def("get_input_desc",
(TensorDesc (Operator::*)(const std::string &) const) & (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc)
Operator::GetInputDesc)
.def("get_input_desc", .def("get_input_desc",
(TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc) [](Operator& op, const std::string& name){
.def("get_dynamic_output_num", &Operator::GetDynamicOutputNum) return op.GetInputDescByName(name.c_str());
.def("get_dynamic_input_num", &Operator::GetDynamicInputNum) })
.def("get_dynamic_output_num", static_cast<int (ge::Operator::*)(const char*) const>(&Operator::GetDynamicOutputNum))
.def("get_dynamic_input_num", static_cast<int (ge::Operator::*)(const char*) const>(&Operator::GetDynamicInputNum))
.def("try_get_input_desc", .def("try_get_input_desc",
[](Operator &op, const std::string &name) -> py::tuple { [](Operator &op, const char* name) -> py::tuple {
TensorDesc tensor_desc; TensorDesc tensor_desc;
graphStatus status = op.TryGetInputDesc(name, tensor_desc); graphStatus status = op.TryGetInputDesc(name, tensor_desc);
return py::make_tuple(tensor_desc, status); return py::make_tuple(tensor_desc, status);
}) })
.def("update_input_desc", &Operator::UpdateInputDesc) .def("update_input_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, const TensorDesc&)>(&Operator::UpdateInputDesc))
.def("get_output_desc", .def("get_output_desc",
(TensorDesc (Operator::*)(const std::string &) const) & [](Operator& op, const std::string& name) {
Operator::GetOutputDesc) return op.GetOutputDescByName(name.c_str());
})
.def("get_output_desc", .def("get_output_desc",
(TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc) (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc)
.def("update_output_desc", &Operator::UpdateOutputDesc) .def("update_output_desc",
.def("get_dynamic_input_desc", &Operator::GetDynamicInputDesc) static_cast<ge::graphStatus (ge::Operator::*)(const char*, const TensorDesc&)>(&Operator::UpdateOutputDesc))
.def("update_dynamic_input_desc", &Operator::UpdateDynamicInputDesc) .def("get_dynamic_input_desc",
.def("get_dynamic_output_desc", &Operator::GetDynamicOutputDesc) static_cast<ge::TensorDesc (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicInputDesc))
.def("update_dynamic_output_desc", &Operator::UpdateDynamicOutputDesc) .def("update_dynamic_input_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, uint32_t, const TensorDesc&)>(&Operator::UpdateDynamicInputDesc))
.def("get_dynamic_output_desc",
static_cast<ge::TensorDesc (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicOutputDesc))
.def("update_dynamic_output_desc",
static_cast<ge::graphStatus (ge::Operator::*)(const char*, uint32_t, const TensorDesc&)>(&Operator::UpdateDynamicOutputDesc))
.def("infer_shape_and_type", &Operator::InferShapeAndType) .def("infer_shape_and_type", &Operator::InferShapeAndType)
.def("set_inference_context", &Operator::SetInferenceContext) .def("set_inference_context", &Operator::SetInferenceContext)
.def("get_inference_context", &Operator::GetInferenceContext) .def("get_inference_context", &Operator::GetInferenceContext)
.def("verify_all_attr", &Operator::VerifyAllAttr) .def("verify_all_attr", &Operator::VerifyAllAttr)
.def("get_inputs_size", &Operator::GetInputsSize) .def("get_inputs_size", &Operator::GetInputsSize)
.def("get_outputs_size", &Operator::GetOutputsSize) .def("get_outputs_size", &Operator::GetOutputsSize)
.def("get_all_attr_names_and_types", &Operator::GetAllAttrNamesAndTypes) .def("get_all_attr_names_and_types",
static_cast<ge::graphStatus (ge::Operator::*)(std::map<ge::AscendString, ge::AscendString>&) const>(&Operator::GetAllAttrNamesAndTypes))
.def("set_attr_int64", .def("set_attr_int64",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
int64_t value) -> Operator & { int64_t value) -> Operator & {
int64_t tar = (int64_t)value; int64_t tar = (int64_t)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_int32", .def("set_attr_int32",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
int32_t value) -> Operator & { int32_t value) -> Operator & {
int32_t tar = (int32_t)value; int32_t tar = (int32_t)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_uint32", .def("set_attr_uint32",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
uint32_t value) -> Operator & { uint32_t value) -> Operator & {
uint32_t tar = (uint32_t)value; uint32_t tar = (uint32_t)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_int64", .def("set_attr_vec_int64",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const std::vector<int64_t> &value) -> Operator & { const std::vector<int64_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<int64_t> tar; std::vector<int64_t> tar;
...@@ -368,7 +391,7 @@ void BindAscendGraph(py::module *m) { ...@@ -368,7 +391,7 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_int32", .def("set_attr_vec_int32",
[](Operator &op, const std::string &name, [](Operator &op, const char * name,
const std::vector<int32_t> &value) -> Operator & { const std::vector<int32_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<int32_t> tar; std::vector<int32_t> tar;
...@@ -380,7 +403,7 @@ void BindAscendGraph(py::module *m) { ...@@ -380,7 +403,7 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_uint32", .def("set_attr_vec_uint32",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const std::vector<uint32_t> &value) -> Operator & { const std::vector<uint32_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<uint32_t> tar; std::vector<uint32_t> tar;
...@@ -392,21 +415,21 @@ void BindAscendGraph(py::module *m) { ...@@ -392,21 +415,21 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_list_int64", .def("set_attr_list_int64",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
std::initializer_list<int64_t> &attrValue) -> Operator & { std::initializer_list<int64_t> &attrValue) -> Operator & {
return op.SetAttr(name, std::move(attrValue)); return op.SetAttr(name, std::move(attrValue));
}) })
.def("set_attr_attrvalue", .def("set_attr_attrvalue",
[](Operator &op, const std::string &name, AttrValue &attrValue) [](Operator &op, const char* name, AttrValue &attrValue)
-> Operator & { return op.SetAttr(name, std::move(attrValue)); }) -> Operator & { return op.SetAttr(name, std::move(attrValue)); })
.def( .def(
"set_attr_float", "set_attr_float",
[](Operator &op, const std::string &name, float value) -> Operator & { [](Operator &op, const char* name, float value) -> Operator & {
float tar = static_cast<float>(value); float tar = static_cast<float>(value);
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_vec_float", .def("set_attr_vec_float",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const std::vector<float> &value) -> Operator & { const std::vector<float> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<float> tar; std::vector<float> tar;
...@@ -417,22 +440,22 @@ void BindAscendGraph(py::module *m) { ...@@ -417,22 +440,22 @@ void BindAscendGraph(py::module *m) {
} }
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_string", (Operator & (Operator::*)(const std::string &, .def("set_attr_string", (Operator & (Operator::*)(const char*,
const std::string &)) & const char*)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_vec_string", .def("set_attr_vec_string",
(Operator & (Operator::*)(const std::string &, (Operator & (Operator::*)(const char*,
const std::vector<std::string> &)) & const std::vector<ge::AscendString> &)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_bool", .def("set_attr_bool",
[](Operator &op, const std::string &name, bool value) -> Operator & { [](Operator &op, const char* name, bool value) -> Operator & {
if (value) if (value)
return op.SetAttr(name, true); return op.SetAttr(name, true);
else else
return op.SetAttr(name, false); return op.SetAttr(name, false);
}) })
.def("set_attr_vec_bool", .def("set_attr_vec_bool",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const std::vector<bool> &value) -> Operator & { const std::vector<bool> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<bool> tar; std::vector<bool> tar;
...@@ -445,14 +468,14 @@ void BindAscendGraph(py::module *m) { ...@@ -445,14 +468,14 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_tensor", .def("set_attr_tensor",
(Operator & (Operator::*)(const std::string &, const Tensor &)) & (Operator & (Operator::*)(const char* , const Tensor &)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_vec_tensor", .def("set_attr_vec_tensor",
(Operator & (Operator &
(Operator::*)(const std::string &, const std::vector<Tensor> &)) & (Operator::*)(const char *, const std::vector<Tensor> &)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_vec_uint8", .def("set_attr_vec_uint8",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const std::vector<uint8_t> &value) -> Operator & { const std::vector<uint8_t> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<uint8_t> tar; std::vector<uint8_t> tar;
...@@ -465,11 +488,11 @@ void BindAscendGraph(py::module *m) { ...@@ -465,11 +488,11 @@ void BindAscendGraph(py::module *m) {
}) })
.def("set_attr_vec_vec_int64", .def("set_attr_vec_vec_int64",
(Operator & (Operator &
(Operator::*)(const std::string &, (Operator::*)(const char*,
const std::vector<std::vector<int64_t>> &)) & const std::vector<std::vector<int64_t>> &)) &
Operator::SetAttr) Operator::SetAttr)
.def("set_attr_vec_dtype", .def("set_attr_vec_dtype",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const std::vector<DataType> &value) -> Operator & { const std::vector<DataType> &value) -> Operator & {
int len = value.size(); int len = value.size();
std::vector<ge::DataType> tar; std::vector<ge::DataType> tar;
...@@ -481,14 +504,14 @@ void BindAscendGraph(py::module *m) { ...@@ -481,14 +504,14 @@ void BindAscendGraph(py::module *m) {
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("set_attr_dtype", .def("set_attr_dtype",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
const DataType &value) -> Operator & { const DataType &value) -> Operator & {
ge::DataType tar = (ge::DataType)value; ge::DataType tar = (ge::DataType)value;
return op.SetAttr(name, tar); return op.SetAttr(name, tar);
}) })
.def("get_attr", .def("get_attr",
[](Operator &op, const std::string &name, [](Operator &op, const char* name,
AttrType type) -> py::tuple { AttrType type) -> py::tuple {
graphStatus res = -1; graphStatus res = -1;
switch (type) { switch (type) {
...@@ -538,12 +561,12 @@ void BindAscendGraph(py::module *m) { ...@@ -538,12 +561,12 @@ void BindAscendGraph(py::module *m) {
return py::make_tuple(o_av, res); return py::make_tuple(o_av, res);
} break; } break;
case AT_STRING: { case AT_STRING: {
std::string s_av; ge::AscendString s_av;
res = op.GetAttr(name, s_av); res = op.GetAttr(name, s_av);
return py::make_tuple(s_av, res); return py::make_tuple(s_av, res);
} break; } break;
case AT_LIST_STRING: { case AT_LIST_STRING: {
std::vector<std::string> v_s_av; std::vector<ge::AscendString> v_s_av;
res = op.GetAttr(name, v_s_av); res = op.GetAttr(name, v_s_av);
return py::make_tuple(v_s_av, res); return py::make_tuple(v_s_av, res);
} break; } break;
...@@ -594,11 +617,11 @@ void BindAscendGraph(py::module *m) { ...@@ -594,11 +617,11 @@ void BindAscendGraph(py::module *m) {
}) })
.def("break_connect", &Operator::BreakConnect) .def("break_connect", &Operator::BreakConnect)
.def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount) .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount)
.def("get_subgraph_names", &Operator::GetSubgraphNames) .def("get_subgraph_names", static_cast<ge::graphStatus (ge::Operator::*)(std::vector<ge::AscendString> &) const>(&Operator::GetSubgraphNames))
.def("get_subgraph_builder", &Operator::GetSubgraphBuilder) .def("get_subgraph_builder", static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char*) const>(&Operator::GetSubgraphBuilder))
.def("get_subgraph", &Operator::GetSubgraph) .def("get_subgraph", static_cast<ge::Graph (ge::Operator::*)(const char*) const>(&Operator::GetSubgraph))
.def("get_dynamic_subgraph_builder", &Operator::GetDynamicSubgraphBuilder) .def("get_dynamic_subgraph_builder", static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicSubgraphBuilder))
.def("get_dynamic_subgraph", &Operator::GetDynamicSubgraph); .def("get_dynamic_subgraph", static_cast<ge::Graph (ge::Operator::*)(const char*, uint32_t) const>(&Operator::GetDynamicSubgraph));
py::class_<Tensor>(*m, "GETensor") py::class_<Tensor>(*m, "GETensor")
.def(py::init<>()) .def(py::init<>())
...@@ -614,9 +637,9 @@ void BindAscendGraph(py::module *m) { ...@@ -614,9 +637,9 @@ void BindAscendGraph(py::module *m) {
.def("set_data", .def("set_data",
(graphStatus (Tensor::*)(const uint8_t *, size_t)) & Tensor::SetData) (graphStatus (Tensor::*)(const uint8_t *, size_t)) & Tensor::SetData)
.def("set_data", .def("set_data",
(graphStatus (Tensor::*)(const std::string &)) & Tensor::SetData) (graphStatus (Tensor::*)(const char*)) & Tensor::SetData)
.def("set_data", .def("set_data",
(graphStatus (Tensor::*)(const std::vector<std::string> &)) & (graphStatus (Tensor::*)(const std::vector<ge::AscendString> &)) &
Tensor::SetData) Tensor::SetData)
.def("get_data", .def("get_data",
...@@ -639,7 +662,7 @@ void BindAscendGraph(py::module *m) { ...@@ -639,7 +662,7 @@ void BindAscendGraph(py::module *m) {
py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT) py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT)
.def(py::init<const TensorDesc &>()) .def(py::init<const TensorDesc &>())
.def("update", .def("update",
(void (TensorDesc::*)(Shape, Format, DataType)) & TensorDesc::Update, (void (TensorDesc::*)(const Shape&, Format, DataType)) & TensorDesc::Update,
py::arg("shape"), py::arg("format") = FORMAT_ND, py::arg("shape"), py::arg("format") = FORMAT_ND,
py::arg("dt") = DT_FLOAT) py::arg("dt") = DT_FLOAT)
.def("set_shape", &TensorDesc::SetShape) .def("set_shape", &TensorDesc::SetShape)
...@@ -660,8 +683,8 @@ void BindAscendGraph(py::module *m) { ...@@ -660,8 +683,8 @@ void BindAscendGraph(py::module *m) {
.def("get_origin_format", &TensorDesc::GetOriginFormat) .def("get_origin_format", &TensorDesc::GetOriginFormat)
.def("set_data_type", &TensorDesc::SetDataType) .def("set_data_type", &TensorDesc::SetDataType)
.def("get_data_type", &TensorDesc::GetDataType) .def("get_data_type", &TensorDesc::GetDataType)
.def("set_name", &TensorDesc::SetName) .def("set_name", static_cast<void (ge::TensorDesc::*)(const char*)>(&TensorDesc::SetName))
.def("get_name", &TensorDesc::GetName) .def("get_name", static_cast<ge::graphStatus (ge::TensorDesc::*)(ge::AscendString&)>(&TensorDesc::GetName))
.def("set_size", &TensorDesc::SetSize) .def("set_size", &TensorDesc::SetSize)
.def("get_size", &TensorDesc::GetSize) .def("get_size", &TensorDesc::GetSize)
.def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt) .def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt)
...@@ -679,14 +702,16 @@ void BindAscendGraph(py::module *m) { ...@@ -679,14 +702,16 @@ void BindAscendGraph(py::module *m) {
py::class_<AttrValue>(*m, "GEAttrValue").def(py::init<>()); py::class_<AttrValue>(*m, "GEAttrValue").def(py::init<>());
py::class_<OperatorFactory>(*m, "GEOperatorFactory") py::class_<OperatorFactory>(*m, "GEOperatorFactory")
.def("create_operator", &OperatorFactory::CreateOperator) .def_static("create_operator",
static_cast<ge::Operator (*)(const char*, const char*)>(&ge::OperatorFactory::CreateOperator))
.def("get_ops_type_list", .def("get_ops_type_list",
[]() -> py::tuple { []() -> py::tuple {
std::vector<std::string> all_ops; std::vector<ge::AscendString> all_ops;
graphStatus status = OperatorFactory::GetOpsTypeList(all_ops); graphStatus status = OperatorFactory::GetOpsTypeList(all_ops);
return py::make_tuple(all_ops, status); return py::make_tuple(all_ops, status);
}) })
.def("is_exist_op", &OperatorFactory::IsExistOp); .def_static("is_exist_op",
static_cast<bool (*)(const char*)>(&OperatorFactory::IsExistOp));
} }
} // end namespace pybind } // end namespace pybind
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <unistd.h>
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -23,6 +24,9 @@ ...@@ -23,6 +24,9 @@
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_ASCEND
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are // NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// determined by the OP`s proto automatically, i.e., all the inputs registered // determined by the OP`s proto automatically, i.e., all the inputs registered
...@@ -444,6 +448,11 @@ int main(int argc, char* argv[]) { ...@@ -444,6 +448,11 @@ int main(int argc, char* argv[]) {
return -1; return -1;
} }
#ifdef PADDLE_WITH_ASCEND
auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
ascend_ptr->InitGEForUT();
#endif
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""}; std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""};
std::ofstream out(argv[1], std::ios::out); std::ofstream out(argv[1], std::ios::out);
...@@ -473,5 +482,9 @@ int main(int argc, char* argv[]) { ...@@ -473,5 +482,9 @@ int main(int argc, char* argv[]) {
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
out.close(); out.close();
#ifdef PADDLE_WITH_ASCEND
ge::GEFinalize();
#endif
return 0; return 0;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册