提交 6d4c4405 编写于 作者: Y yangyaming

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-3923-c

...@@ -67,6 +67,9 @@ endif() ...@@ -67,6 +67,9 @@ endif()
if(ANDROID) if(ANDROID)
if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16") if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
message(FATAL_ERROR "Unsupport standalone toolchains with Android API level lower than 16") message(FATAL_ERROR "Unsupport standalone toolchains with Android API level lower than 16")
elseif(${CMAKE_SYSTEM_VERSION} VERSION_LESS "21")
# TODO: support glog for Android api 16 ~ 19 in the future
message(WARNING "Using the unofficial git repository <https://github.com/Xreki/glog.git> instead")
endif() endif()
set(WITH_GPU OFF CACHE STRING set(WITH_GPU OFF CACHE STRING
......
...@@ -6,13 +6,14 @@ RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ub ...@@ -6,13 +6,14 @@ RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ub
# ENV variables # ENV variables
ARG ANDROID_ABI ARG ANDROID_ABI
ARG ANDROID_API
ENV ANDROID_ABI=${ANDROID_ABI:-"armeabi-v7a"} ENV ANDROID_ABI=${ANDROID_ABI:-"armeabi-v7a"}
ENV ANDROID_API=${ANDROID_API:-21}
ENV HOME=/root \ ENV HOME=/root \
ANDROID_NDK_HOME=/opt/android-ndk-linux \ ANDROID_NDK_HOME=/opt/android-ndk-linux \
ANDROID_ARM_STANDALONE_TOOLCHAIN=/opt/arm-toolchain \ ANDROID_TOOLCHAINS_DIR=/opt/toolchains
ANDROID_ARM64_STANDALONE_TOOLCHAIN=/opt/arm64-toolchain
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y \ apt-get install -y \
...@@ -42,14 +43,12 @@ RUN pip install --upgrade pip && \ ...@@ -42,14 +43,12 @@ RUN pip install --upgrade pip && \
pip install pre-commit pip install pre-commit
# Android NDK # Android NDK
RUN mkdir /opt/android-ndk-tmp && \ RUN mkdir -p ${ANDROID_TOOLCHAINS_DIR} && \
mkdir -p /opt/android-ndk-tmp && \
cd /opt/android-ndk-tmp && \ cd /opt/android-ndk-tmp && \
wget -q https://dl.google.com/android/repository/android-ndk-r14b-linux-x86_64.zip && \ wget -q https://dl.google.com/android/repository/android-ndk-r14b-linux-x86_64.zip && \
unzip -q android-ndk-r14b-linux-x86_64.zip && \ unzip -q android-ndk-r14b-linux-x86_64.zip && \
mv android-ndk-r14b ${ANDROID_NDK_HOME} && \ mv android-ndk-r14b ${ANDROID_NDK_HOME} && \
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh --arch=arm --platform=android-23 --install-dir=${ANDROID_ARM_STANDALONE_TOOLCHAIN} && \ rm -rf /opt/android-ndk-tmp
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh --arch=arm64 --platform=android-23 --install-dir=${ANDROID_ARM64_STANDALONE_TOOLCHAIN} && \
rm -rf /opt/android-ndk-tmp && \
rm -rf ${ANDROID_NDK_HOME}
CMD ["bash", "/paddle/paddle/scripts/docker/build_android.sh"] CMD ["bash", "/paddle/paddle/scripts/docker/build_android.sh"]
...@@ -18,9 +18,9 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags) ...@@ -18,9 +18,9 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags)
SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags)
SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE) SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE)
IF(WIN32) IF(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ELSE(WIN32) ELSE(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ENDIF(WIN32) ENDIF(WIN32)
INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
...@@ -56,3 +56,12 @@ SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES}) ...@@ -56,3 +56,12 @@ SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
ADD_DEPENDENCIES(gflags extern_gflags) ADD_DEPENDENCIES(gflags extern_gflags)
LIST(APPEND external_project_dependencies gflags) LIST(APPEND external_project_dependencies gflags)
IF(WITH_C_API)
INSTALL(DIRECTORY ${GFLAGS_INCLUDE_DIR} DESTINATION third_party/gflags)
IF(ANDROID)
INSTALL(FILES ${GFLAGS_LIBRARIES} DESTINATION third_party/gflags/lib/${ANDROID_ABI})
ELSE()
INSTALL(FILES ${GFLAGS_LIBRARIES} DESTINATION third_party/gflags/lib)
ENDIF()
ENDIF()
...@@ -19,9 +19,9 @@ SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog) ...@@ -19,9 +19,9 @@ SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog)
SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE) SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE)
IF(WIN32) IF(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.lib" CACHE FILEPATH "glog library." FORCE) SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.lib" CACHE FILEPATH "glog library." FORCE)
ELSE(WIN32) ELSE(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." FORCE) SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." FORCE)
ENDIF(WIN32) ENDIF(WIN32)
INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR})
...@@ -56,3 +56,12 @@ ADD_DEPENDENCIES(glog extern_glog gflags) ...@@ -56,3 +56,12 @@ ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags) LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog) LIST(APPEND external_project_dependencies glog)
IF(WITH_C_API)
INSTALL(DIRECTORY ${GLOG_INCLUDE_DIR} DESTINATION third_party/glog)
IF(ANDROID)
INSTALL(FILES ${GLOG_LIBRARIES} DESTINATION third_party/glog/lib/${ANDROID_ABI})
ELSE()
INSTALL(FILES ${GLOG_LIBRARIES} DESTINATION third_party/glog/lib)
ENDIF()
ENDIF()
...@@ -73,6 +73,26 @@ IF(NOT ${CBLAS_FOUND}) ...@@ -73,6 +73,26 @@ IF(NOT ${CBLAS_FOUND})
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
) )
IF(WITH_C_API)
INSTALL(DIRECTORY ${CBLAS_INC_DIR} DESTINATION third_party/openblas)
# Because libopenblas.a is a symbolic link of another library, thus need to
# install the whole directory.
IF(ANDROID)
SET(TMP_INSTALL_DIR third_party/openblas/lib/${ANDROID_ABI})
ELSE()
SET(TMP_INSTALL_DIR third_party/openblas/lib)
ENDIF()
INSTALL(CODE "execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CBLAS_INSTALL_DIR}/lib
destination ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}
)"
)
INSTALL(CODE "MESSAGE(STATUS \"Installing: \"
\"${CBLAS_INSTALL_DIR}/lib -> ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}\"
)"
)
ENDIF()
ENDIF(NOT ${CBLAS_FOUND}) ENDIF(NOT ${CBLAS_FOUND})
MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}") MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}")
......
...@@ -223,6 +223,15 @@ IF(NOT PROTOBUF_FOUND) ...@@ -223,6 +223,15 @@ IF(NOT PROTOBUF_FOUND)
SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY} SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY}
CACHE FILEPATH "protoc library." FORCE) CACHE FILEPATH "protoc library." FORCE)
IF(WITH_C_API)
INSTALL(DIRECTORY ${PROTOBUF_INCLUDE_DIR} DESTINATION third_party/protobuf)
IF(ANDROID)
INSTALL(FILES ${PROTOBUF_LIBRARY} DESTINATION third_party/protobuf/lib/${ANDROID_ABI})
ELSE()
INSTALL(FILES ${PROTOBUF_LIBRARY} DESTINATION third_party/protobuf/lib)
ENDIF()
ENDIF()
IF(CMAKE_CROSSCOMPILING) IF(CMAKE_CROSSCOMPILING)
PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf) PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf)
ELSE() ELSE()
......
...@@ -49,3 +49,12 @@ ExternalProject_Add( ...@@ -49,3 +49,12 @@ ExternalProject_Add(
) )
LIST(APPEND external_project_dependencies zlib) LIST(APPEND external_project_dependencies zlib)
IF(WITH_C_API)
INSTALL(DIRECTORY ${ZLIB_INCLUDE_DIR} DESTINATION third_party/zlib)
IF(ANDROID)
INSTALL(FILES ${ZLIB_LIBRARIES} DESTINATION third_party/zlib/lib/${ANDROID_ABI})
ELSE()
INSTALL(FILES ${ZLIB_LIBRARIES} DESTINATION third_party/zlib/lib)
ENDIF()
ENDIF()
...@@ -64,9 +64,29 @@ link_paddle_exe(paddle_capi_shared) ...@@ -64,9 +64,29 @@ link_paddle_exe(paddle_capi_shared)
install(FILES ${CAPI_HEADERS} DESTINATION include/paddle) install(FILES ${CAPI_HEADERS} DESTINATION include/paddle)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION include/paddle) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION include/paddle)
if(ANDROID) if(ANDROID)
execute_process(
COMMAND ${GIT_EXECUTABLE} log --pretty=oneline -1
OUTPUT_VARIABLE GIT_COMMITS_LIST
RESULT_VARIABLE GIT_COMMITS_LIST_RESULT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
if(${GIT_COMMITS_LIST_RESULT})
set(GIT_COMMITS_LIST "No commits.")
endif()
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${capi_whole_library} install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${capi_whole_library}
DESTINATION lib/${ANDROID_ABI}) DESTINATION lib/${ANDROID_ABI})
install(TARGETS paddle_capi_shared DESTINATION lib/${ANDROID_ABI}) install(TARGETS paddle_capi_shared DESTINATION lib/${ANDROID_ABI})
install(CODE "FILE(WRITE ${CMAKE_INSTALL_PREFIX}/lib/${ANDROID_ABI}/BUILD.txt
\"Compiler:\n\"
\"\\t${CMAKE_C_COMPILER}\\n\"
\"\\t${CMAKE_CXX_COMPILER}\\n\"
\"Compiler Flags:\\n\"
\"\\t${CMAKE_F_FLAGS}\\n\"
\"\\t${CMAKE_CXX_FLAGS}\\n\"
\"Android API: ${CMAKE_SYSTEM_VERSION}\\n\"
\"Lastest commit:\\n\"
\"\\t${GIT_COMMITS_LIST}\\n\"
)"
)
else(ANDROID) else(ANDROID)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${capi_whole_library} DESTINATION lib) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${capi_whole_library} DESTINATION lib)
install(TARGETS paddle_capi_shared DESTINATION lib) install(TARGETS paddle_capi_shared DESTINATION lib)
......
...@@ -594,7 +594,7 @@ struct StridePadding { ...@@ -594,7 +594,7 @@ struct StridePadding {
float32x4_t s1 = vdupq_n_f32(0.f); float32x4_t s1 = vdupq_n_f32(0.f);
for (int s = 0; s < step; s++) { for (int s = 0; s < step; s++) {
float32x4_t s0 = vld1q_f32(input); float32x4_t s0 = vld1q_f32(input);
float32x4x2_t v = {s0, s1}; float32x4x2_t v = {{s0, s1}};
vst2q_f32(inputPadding, v); vst2q_f32(inputPadding, v);
input += 4; input += 4;
inputPadding += 8; inputPadding += 8;
......
...@@ -2,8 +2,30 @@ ...@@ -2,8 +2,30 @@
set -xe set -xe
if [ $ANDROID_ABI == "arm64-v8a" ]; then
ANDROID_ARCH=arm64
else # armeabi, armeabi-v7a
ANDROID_ARCH=arm
fi
ANDROID_STANDALONE_TOOLCHAIN=$ANDROID_TOOLCHAINS_DIR/$ANDROID_ARCH-android-$ANDROID_API
cat <<EOF
============================================
Generating the standalone toolchain ...
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh
--arch=$ANDROID_ARCH
--platform=android-$ANDROID_API
--install-dir=${ANDROID_STANDALONE_TOOLCHAIN}
============================================
EOF
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh \
--arch=$ANDROID_ARCH \
--platform=android-$ANDROID_API \
--install-dir=$ANDROID_STANDALONE_TOOLCHAIN
BUILD_ROOT=/paddle/build_android BUILD_ROOT=/paddle/build_android
DEST_ROOT=/paddle/install DEST_ROOT=/paddle/install_android
rm -rf $BUILD_ROOT 2>/dev/null || true rm -rf $BUILD_ROOT 2>/dev/null || true
mkdir -p $BUILD_ROOT mkdir -p $BUILD_ROOT
...@@ -11,7 +33,7 @@ cd $BUILD_ROOT ...@@ -11,7 +33,7 @@ cd $BUILD_ROOT
if [ $ANDROID_ABI == "armeabi-v7a" ]; then if [ $ANDROID_ABI == "armeabi-v7a" ]; then
cmake -DCMAKE_SYSTEM_NAME=Android \ cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_ARM_STANDALONE_TOOLCHAIN \ -DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=$ANDROID_ABI \ -DANDROID_ABI=$ANDROID_ABI \
-DANDROID_ARM_NEON=ON \ -DANDROID_ARM_NEON=ON \
-DANDROID_ARM_MODE=ON \ -DANDROID_ARM_MODE=ON \
...@@ -26,7 +48,7 @@ if [ $ANDROID_ABI == "armeabi-v7a" ]; then ...@@ -26,7 +48,7 @@ if [ $ANDROID_ABI == "armeabi-v7a" ]; then
.. ..
elif [ $ANDROID_ABI == "arm64-v8a" ]; then elif [ $ANDROID_ABI == "arm64-v8a" ]; then
cmake -DCMAKE_SYSTEM_NAME=Android \ cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_ARM64_STANDALONE_TOOLCHAIN \ -DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=$ANDROID_ABI \ -DANDROID_ABI=$ANDROID_ABI \
-DANDROID_ARM_MODE=ON \ -DANDROID_ARM_MODE=ON \
-DHOST_C_COMPILER=/usr/bin/gcc \ -DHOST_C_COMPILER=/usr/bin/gcc \
...@@ -40,12 +62,12 @@ elif [ $ANDROID_ABI == "arm64-v8a" ]; then ...@@ -40,12 +62,12 @@ elif [ $ANDROID_ABI == "arm64-v8a" ]; then
.. ..
elif [ $ANDROID_ABI == "armeabi" ]; then elif [ $ANDROID_ABI == "armeabi" ]; then
cmake -DCMAKE_SYSTEM_NAME=Android \ cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_ARM_STANDALONE_TOOLCHAIN \ -DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=$ANDROID_ABI \ -DANDROID_ABI=$ANDROID_ABI \
-DANDROID_ARM_MODE=ON \ -DANDROID_ARM_MODE=ON \
-DHOST_C_COMPILER=/usr/bin/gcc \ -DHOST_C_COMPILER=/usr/bin/gcc \
-DHOST_CXX_COMPILER=/usr/bin/g++ \ -DHOST_CXX_COMPILER=/usr/bin/g++ \
-DCMAKE_INSTALL_PREFIX=/paddle/install \ -DCMAKE_INSTALL_PREFIX=$DEST_ROOT \
-DCMAKE_BUILD_TYPE=Release \ -DCMAKE_BUILD_TYPE=Release \
-DWITH_C_API=ON \ -DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \ -DWITH_SWIG_PY=OFF \
...@@ -55,5 +77,10 @@ else ...@@ -55,5 +77,10 @@ else
echo "Invalid ANDROID_ABI: $ANDROID_ABI" echo "Invalid ANDROID_ABI: $ANDROID_ABI"
fi fi
cat <<EOF
============================================
Building in $BUILD_ROOT ...
============================================
EOF
make -j `nproc` make -j `nproc`
make install -j `nproc` make install -j `nproc`
...@@ -169,6 +169,7 @@ class LayerType(object): ...@@ -169,6 +169,7 @@ class LayerType(object):
EXCONV_LAYER = 'exconv' EXCONV_LAYER = 'exconv'
EXCONVTRANS_LAYER = 'exconvt' EXCONVTRANS_LAYER = 'exconvt'
CUDNNCONV_LAYER = 'cudnn_conv' CUDNNCONV_LAYER = 'cudnn_conv'
CUDNNCONVTRANS_LAYER = 'cudnn_convt'
POOL_LAYER = 'pool' POOL_LAYER = 'pool'
POOL3D_LAYER = 'pool3d' POOL3D_LAYER = 'pool3d'
BATCH_NORM_LAYER = 'batch_norm' BATCH_NORM_LAYER = 'batch_norm'
......
...@@ -85,7 +85,7 @@ def get_numeric_gradient(scope, ...@@ -85,7 +85,7 @@ def get_numeric_gradient(scope,
op, op,
inputs, inputs,
input_to_check, input_to_check,
output_name, output_names,
delta=0.005, delta=0.005,
in_place=False): in_place=False):
...@@ -100,8 +100,11 @@ def get_numeric_gradient(scope, ...@@ -100,8 +100,11 @@ def get_numeric_gradient(scope,
ctx = core.DeviceContext.create(core.CPUPlace()) ctx = core.DeviceContext.create(core.CPUPlace())
def get_output(): def get_output():
op.run(scope, ctx) sum = 0.0
return np.array(scope.find_var(output_name).get_tensor()).sum() for output_name in output_names:
op.run(scope, ctx)
sum += np.array(scope.find_var(output_name).get_tensor()).sum()
return sum
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims()) tensor_size = product(tensor_to_check.get_dims())
...@@ -225,7 +228,7 @@ class OpTest(unittest.TestCase): ...@@ -225,7 +228,7 @@ class OpTest(unittest.TestCase):
def check_grad(self, def check_grad(self,
inputs_to_check, inputs_to_check,
output_name, output_names,
no_grad_set=None, no_grad_set=None,
in_place=False, in_place=False,
max_relative_error=0.005): max_relative_error=0.005):
...@@ -237,13 +240,16 @@ class OpTest(unittest.TestCase): ...@@ -237,13 +240,16 @@ class OpTest(unittest.TestCase):
if no_grad_set is None: if no_grad_set is None:
no_grad_set = set() no_grad_set = set()
if not type(output_names) is list:
output_names = [output_names]
numeric_grads = [ numeric_grads = [
get_numeric_gradient( get_numeric_gradient(
self.scope, self.scope,
self.op, self.op,
self.inputs, self.inputs,
input_to_check, input_to_check,
output_name, output_names,
in_place=in_place) for input_to_check in inputs_to_check in_place=in_place) for input_to_check in inputs_to_check
] ]
grad_names = [ grad_names = [
......
...@@ -12,7 +12,8 @@ class GetNumericGradientTest(unittest.TestCase): ...@@ -12,7 +12,8 @@ class GetNumericGradientTest(unittest.TestCase):
z = x + y z = x + y
scope = core.Scope() scope = core.Scope()
add_op = create_op(scope, "add", {'X': x, 'Y': y}, {'Out': z}, dict()) add_op = create_op(scope, "add", {'X': x, 'Y': y}, {'Out': z}, dict())
arr = get_numeric_gradient(scope, add_op, {'X': x, 'Y': y}, 'X', 'Out') arr = get_numeric_gradient(scope, add_op, {'X': x,
'Y': y}, 'X', ['Out'])
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4) self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)
def test_softmax_op(self): def test_softmax_op(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册