提交 384fc465 编写于 作者: H Houjiang Chen 提交者: Shenghang Tsai

Update setup py (#2532)

* Update tensorflow cmakefile

* Update setup.py
上级 c2584238
......@@ -108,5 +108,4 @@ if (THIRD_PARTY)
set(THIRD_PARTY OFF CACHE BOOL "" FORCE)
else()
include(oneflow)
configure_file(${PROJECT_SOURCE_DIR}/setup.py.in ${PROJECT_BINARY_DIR}/setup.py)
endif()
......@@ -103,6 +103,7 @@ set(oneflow_third_party_dependencies
cocoapi_copy_libs_to_destination
half_copy_headers_to_destination
json_copy_headers_to_destination
tensorflow_copy_libs_to_destination
)
include_directories(
......
......@@ -28,10 +28,12 @@ set(TENSORFLOW_PROJECT tensorflow)
set(TENSORFLOW_GIT_URL https://github.com/tensorflow/tensorflow.git)
#set(TENSORFLOW_GIT_TAG master)
set(TENSORFLOW_GIT_TAG 80c04b80ad66bf95aa3f41d72a6bba5e84a99622)
set(TENSORFLOW_SOURCES_DIR ${THIRD_PARTY_DIR}/tensorflow)
set(TENSORFLOW_SOURCES_DIR ${CMAKE_CURRENT_BINARY_DIR}/third_party/tensorflow)
set(TENSORFLOW_SRCS_DIR ${TENSORFLOW_SOURCES_DIR}/src/tensorflow)
set(TENSORFLOW_INC_DIR ${TENSORFLOW_SOURCES_DIR}/src/tensorflow)
set(TENSORFLOW_INSTALL_DIR ${THIRD_PARTY_DIR}/tensorflow)
set(PATCHES_DIR ${PROJECT_SOURCE_DIR}/oneflow/xrt/patches)
set(TENSORFLOW_JIT_DIR ${TENSORFLOW_SRCS_DIR}/tensorflow/compiler/jit)
......@@ -51,13 +53,9 @@ list(APPEND TENSORFLOW_XLA_INCLUDE_DIR
${THIRD_SNAPPY_DIR}
)
include_directories(${TENSORFLOW_XLA_INCLUDE_DIR})
list(APPEND TENSORFLOW_XLA_LIBRARIES libtensorflow_framework.so.1)
list(APPEND TENSORFLOW_XLA_LIBRARIES libxla_core.so)
link_directories(
${TENSORFLOW_SRCS_DIR}/bazel-bin/tensorflow
${TENSORFLOW_SRCS_DIR}/bazel-bin/tensorflow/compiler/jit/xla_lib
)
link_directories(${TENSORFLOW_INSTALL_DIR}/lib)
if (THIRD_PARTY)
ExternalProject_Add(${TENSORFLOW_PROJECT}
......@@ -70,9 +68,21 @@ if (THIRD_PARTY)
bazel build ${TENSORFLOW_BUILD_CMD} -j 20 //tensorflow/compiler/jit/xla_lib:libxla_core.so
INSTALL_COMMAND ""
)
endif(THIRD_PARTY)
set(TENSORFLOW_XLA_FRAMEWORK_LIB ${TENSORFLOW_SRCS_DIR}/bazel-bin/tensorflow/libtensorflow_framework.so.1)
set(TENSORFLOW_XLA_CORE_LIB ${TENSORFLOW_SRCS_DIR}/bazel-bin/tensorflow/compiler/jit/xla_lib/libxla_core.so)
add_custom_target(tensorflow_create_library_dir
COMMAND ${CMAKE_COMMAND} -E make_directory ${TENSORFLOW_INSTALL_DIR}/lib
DEPENDS ${TENSORFLOW_PROJECT})
add_custom_target(tensorflow_copy_libs_to_destination
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TENSORFLOW_XLA_FRAMEWORK_LIB} ${TENSORFLOW_XLA_CORE_LIB} ${TENSORFLOW_INSTALL_DIR}/lib
COMMAND ${CMAKE_COMMAND} -E create_symlink
${TENSORFLOW_INSTALL_DIR}/lib/libtensorflow_framework.so.1
${TENSORFLOW_INSTALL_DIR}/lib/libtensorflow_framework.so
DEPENDS tensorflow_create_library_dir)
endif(THIRD_PARTY)
endif(WITH_XLA)
......@@ -3,12 +3,25 @@ from __future__ import absolute_import
import os
import re
import sys
import argparse
import shutil
from setuptools import find_packages
from setuptools import setup
from setuptools.dist import Distribution
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--with_xla",
type='bool',
default=False,
help="Package xla libraries if true, otherwise not."
)
args, remain_args = parser.parse_known_args()
sys.argv = ['setup.py'] + remain_args
REQUIRED_PACKAGES = [
'numpy',
'protobuf',
......@@ -24,19 +37,19 @@ packages = find_packages("build/python_scripts")
package_dir = {
'':'build/python_scripts',
}
package_data['oneflow'] = ['_oneflow_internal.so']
if '${WITH_XLA}' == 'ON':
packages += ['oneflow.libs']
libs_path = 'python_scripts/oneflow/libs'
package_dir['oneflow.libs'] = libs_path
package_data['oneflow.libs'] = ['libtensorflow_framework.so.1', 'libxla_core.so']
shutil.copy('${TENSORFLOW_XLA_FRAMEWORK_LIB}', libs_path)
shutil.copy('${TENSORFLOW_XLA_CORE_LIB}', libs_path)
command = "patchelf --set-rpath '$ORIGIN/' ${TENSORFLOW_XLA_CORE_LIB}"
if os.system(command) != 0:
raise Exception("patch xla failed, command: %s" % command)
package_data = {'oneflow': ['_oneflow_internal.so']}
if args.with_xla:
packages += ['oneflow.libs']
package_dir['oneflow.libs'] = 'third_party/tensorflow/lib'
package_data['oneflow.libs'] = ['libtensorflow_framework.so.1', 'libxla_core.so']
# Patchelf >= 0.9 is required.
oneflow_internal_so = "build/python_scripts/oneflow/_oneflow_internal.so"
rpath = os.popen("patchelf --print-rpath " + oneflow_internal_so).read()
command = "patchelf --set-rpath '$ORIGIN/:$ORIGIN/libs/:%s' %s" % \
(rpath.strip(), oneflow_internal_so)
if os.system(command) != 0:
raise Exception("Patchelf set rpath failed. command is: %s" % command)
setup(
name='oneflow',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册