diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2ca988c406ae2987e26ca37dbc17cc0a2af43743..bb8c88787d37faf9ce4d7d856a307c11f1085d98 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,7 +24,7 @@
description: Format files with ClangFormat.
entry: clang-format -i
language: system
- files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c7d743e193e7d32dbc0b56f3bcb05b6c61f85f1d..b174831109372cb014741d63032fa6a470e74042 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,8 +36,8 @@ include(simd)
################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
-option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF)
-option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF)
+option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
+option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
diff --git a/Dockerfile b/Dockerfile
index 5dd9b0be4f7e0a304108abfdfb089fea4faa4d38..156ad3552b2c4ff90b405c35c66d44117c2624a4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -27,13 +27,16 @@ RUN apt-get update && \
git python-pip python-dev openssh-server bison \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \
- python-numpy python-matplotlib gcc-4.8 g++-4.8 \
+ python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \
apt-get clean -y
+# Paddle uses numpy.flip, which was introduced in numpy 1.12.0
+RUN pip --no-cache-dir install 'numpy>=1.12.0'
+
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 69220e03fe8e337205f31cb1f45e3e19ae4f5d1e..2ac098954647d37e26ac2499e0675dae39910edc 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -74,8 +74,6 @@ if(WITH_MKLDNN)
set(OPENMP_FLAGS "-fopenmp")
set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
- set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
- set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
else()
diff --git a/cmake/cpplint.cmake b/cmake/cpplint.cmake
index 656e1a0803c6e389d70f37f592c3aa2e95a2bcd4..5184f0815faac005b3dff1015395235f4e19d65b 100644
--- a/cmake/cpplint.cmake
+++ b/cmake/cpplint.cmake
@@ -42,26 +42,21 @@ macro(add_style_check_target TARGET_NAME)
if(WITH_STYLE_CHECK)
set(SOURCES_LIST ${ARGN})
list(REMOVE_DUPLICATES SOURCES_LIST)
- list(SORT SOURCES_LIST)
-
foreach(filename ${SOURCES_LIST})
- set(LINT ON)
foreach(pattern ${IGNORE_PATTERN})
if(filename MATCHES ${pattern})
- message(STATUS "DROP LINT ${filename}")
- set(LINT OFF)
+ list(REMOVE_ITEM SOURCES_LIST ${filename})
endif()
endforeach()
- if(LINT MATCHES ON)
- # cpplint code style
- get_filename_component(base_filename ${filename} NAME)
- set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint)
- add_custom_command(TARGET ${TARGET_NAME} PRE_BUILD
- COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
- "--filter=${STYLE_FILTER}"
- "--write-success=${CUR_GEN}" ${filename}
- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
- endif()
endforeach()
+
+ if(SOURCES_LIST)
+ add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
+ COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
+ "--filter=${STYLE_FILTER}"
+ ${SOURCES_LIST}
+ COMMENT "cpplint: Checking source code style"
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+ endif()
endif()
endmacro()
diff --git a/cmake/external/any.cmake b/cmake/external/any.cmake
index 45e3764e8482a4cfc8ee72fe4d79f04a3c9b74fa..5d2f7219b2007493916a39e839d647a9d0046c9f 100644
--- a/cmake/external/any.cmake
+++ b/cmake/external/any.cmake
@@ -7,7 +7,7 @@ INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add(
extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS}
- GIT_REPOSITORY "https://github.com/thelink2012/any.git"
+ GIT_REPOSITORY "https://github.com/PaddlePaddle/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020"
PREFIX ${ANY_SOURCE_DIR}
UPDATE_COMMAND ""
diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake
index a0d0a892c4b3cc3743ac725f3cd90444f18abf34..16e5bef4cdb8d6513de51838e3c3c8398dbad60d 100644
--- a/cmake/external/gflags.cmake
+++ b/cmake/external/gflags.cmake
@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
ExternalProject_Add(
extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS}
- GIT_REPOSITORY "https://github.com/gflags/gflags.git"
+ # TODO(yiwang): The annoying warnings mentioned in
+ # https://github.com/PaddlePaddle/Paddle/issues/3277 are caused by
+ # gflags. I fired a PR https://github.com/gflags/gflags/pull/230
+ # to fix it. Before it gets accepted by the gflags team, we use
+    # my personal fork, which contains the above fix, temporarily. Let's
+    # change this back to the official GitHub repo once my PR is
+    # merged.
+ GIT_REPOSITORY "https://github.com/wangkuiyi/gflags.git"
PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake
index 60a1041936437775e0994157b8ffcb7c52b7ab87..db09232c0e69016bf18c1d981e4620e9e804ff7c 100644
--- a/cmake/external/openblas.cmake
+++ b/cmake/external/openblas.cmake
@@ -69,8 +69,13 @@ ENDIF(NOT ${CBLAS_FOUND})
MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}")
INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
-ADD_LIBRARY(cblas STATIC IMPORTED)
-SET_PROPERTY(TARGET cblas PROPERTY IMPORTED_LOCATION ${CBLAS_LIBRARIES})
+# FIXME(gangliao): generate cblas target to track all high performance
+# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
+SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
+FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
+ADD_LIBRARY(cblas STATIC ${dummyfile})
+TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
+
IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas)
LIST(APPEND external_project_dependencies cblas)
diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake
index 67a359d4b5f4cca8fc8e74eab4d4acb4cc12baed..490c87d67ed79a238dd506127cd4d9855fab6626 100644
--- a/cmake/external/python.cmake
+++ b/cmake/external/python.cmake
@@ -24,7 +24,6 @@ IF(WITH_PYTHON)
ENDIF(WITH_PYTHON)
SET(py_env "")
-SET(USE_VIRTUALENV_FOR_TEST 1)
IF(PYTHONINTERP_FOUND)
find_python_module(pip REQUIRED)
find_python_module(numpy REQUIRED)
diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index d00a9bb3a30cfb16623e073414088059481c3e1a..e26d8d9df386e65137aa83cc60a43bfeabf7a4a6 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -115,7 +115,7 @@ set(COMMON_FLAGS
-Wno-error=literal-suffix
-Wno-error=sign-compare
-Wno-error=unused-local-typedefs
- -Wno-error=parentheses-equality # Warnings in Pybind11
+ -Wno-error=parentheses-equality # Warnings in pybind11
)
set(GPU_COMMON_FLAGS
@@ -195,6 +195,7 @@ endif()
# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
+ list(APPEND CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
endif()
# Custom gpu architecture
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 41b9b5928958ae31799c396a8d77fd7cff557905..957c20bcf603f2f264b4658f63ac0eec438f12b1 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -403,3 +403,16 @@ function(py_proto_compile TARGET_NAME)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction()
+
+function(py_test TARGET_NAME)
+ if(WITH_TESTING)
+ set(options STATIC static SHARED shared)
+ set(oneValueArgs "")
+ set(multiValueArgs SRCS DEPS)
+ cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+ add_test(NAME ${TARGET_NAME}
+ COMMAND env PYTHONPATH=${PADDLE_PYTHON_PACKAGE_DIR}
+ python2 ${py_test_SRCS}
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+ endif()
+endfunction()
diff --git a/cmake/util.cmake b/cmake/util.cmake
index 87ad9d91d8701c56255c1e7f224764998df634a7..4a27623b7ffc0b389680baee52db440c78442f46 100644
--- a/cmake/util.cmake
+++ b/cmake/util.cmake
@@ -118,7 +118,6 @@ endfunction()
macro(add_unittest_without_exec TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN})
link_paddle_test(${TARGET_NAME})
- add_style_check_target(${TARGET_NAME} ${ARGN})
endmacro()
# add_unittest
@@ -150,9 +149,12 @@ endfunction()
# Create a python unittest using run_python_tests.sh,
# which takes care of making correct running environment
function(add_python_test TEST_NAME)
- add_test(NAME ${TEST_NAME}
- COMMAND env PADDLE_PACKAGE_DIR=${PADDLE_PYTHON_PACKAGE_DIR}
- bash ${PROJ_ROOT}/paddle/scripts/run_python_tests.sh
- ${USE_VIRTUALENV_FOR_TEST} ${PYTHON_EXECUTABLE} ${ARGN}
- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+ foreach(arg ${ARGN})
+ get_filename_component(py_fn ${arg} NAME_WE)
+ set(TRG_NAME ${TEST_NAME}_${py_fn})
+ add_test(NAME ${TRG_NAME}
+ COMMAND env PYTHONPATH=${PADDLE_PYTHON_PACKAGE_DIR}
+ python2 ${arg}
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+ endforeach()
endfunction()
diff --git a/doc/design/mkldnn/README.MD b/doc/design/mkldnn/README.MD
new file mode 100644
index 0000000000000000000000000000000000000000..e956994431fbb43438c56dcd96ad8313cf516090
--- /dev/null
+++ b/doc/design/mkldnn/README.MD
@@ -0,0 +1,110 @@
+# Intel® MKL-DNN on PaddlePaddle: Design Doc
+
+We plan to integrate the Intel deep neural network math library (**MKL-DNN**\[[1](#references)\]) into PaddlePaddle, so as to take full advantage of Intel platforms and noticeably improve PaddlePaddle's performance on Intel architectures.
+
+Our basic short-term goals are:
+
+- Provide MKL-DNN implementations of the commonly used layers.
+- Provide MKL-DNN implementations of the common deep neural networks VGG, GoogLeNet and ResNet.
+
+
+## Contents
+
+- [Overview](#overview)
+- [Actions](#actions)
+ - [CMake](#cmake)
+ - [Layers](#layers)
+ - [Activations](#activations)
+ - [Unit Tests](#unit-tests)
+ - [Protobuf Messages](#protobuf-messages)
+ - [Python API](#python-api)
+ - [Demos](#demos)
+ - [Benchmarking](#benchmarking)
+ - [Others](#others)
+- [Design Concerns](#design-concerns)
+
+## Overview
+
+We will integrate MKL-DNN into PaddlePaddle as a third-party library. The overall architecture is shown below:
+
+![PaddlePaddle on IA](image/overview.png)
+
+Figure 1. PaddlePaddle on IA.
+
+## Actions
+The integration work is roughly divided into the following aspects.
+
+### CMake
+We will add a `WITH_MKLDNN` option to `CMakeLists.txt`; setting it to `ON` enables building the MKL-DNN functionality. OpenMP is turned on automatically at the same time to improve MKL-DNN's performance.
+
+We will also introduce a `WITH_MKLML` option to choose whether to use the MKLML package shipped with MKL-DNN. This package can be used independently of MKL-DNN, but we recommend enabling MKLML whenever MKL-DNN is enabled, since that gives the best performance.
+
+Accordingly, we will add `mkldnn.cmake` and `mklml.cmake` files under the `cmake/external` directory; they download the corresponding packages while PaddlePaddle is being built and place them in PaddlePaddle's third-party directory.
+
+**Note**: when `WITH_MKLML=ON`, this package is preferred as PaddlePaddle's CBLAS and LAPACK library, so the logic in `cmake/cblas.cmake` will be adjusted slightly.
+
+### Layers
+All MKL-DNN related C++ layers will follow PaddlePaddle's directory layout and live in
+`paddle/gserver/layers`, with file names starting with *Mkldnn*.
+
+All MKL-DNN layers will inherit from a common parent class called `MkldnnLayer`, which in turn inherits from PaddlePaddle's base class `Layer`.
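+
+As a rough, non-authoritative sketch (the member functions shown and their exact signatures are assumptions of this illustration, not something this design fixes), the base class could look like:
+
+```cpp
+#include "paddle/gserver/layers/Layer.h"
+
+namespace paddle {
+
+// Hypothetical skeleton: every MKL-DNN layer derives from MkldnnLayer, which
+// itself derives from the existing Layer base class, so MKL-DNN layers plug
+// into the rest of the framework unchanged.
+class MkldnnLayer : public Layer {
+public:
+  explicit MkldnnLayer(const LayerConfig& config) : Layer(config) {}
+
+  // Shared MKL-DNN setup (engine, stream, primitive caching) would live here,
+  // so concrete layers such as a fully-connected or convolution layer only
+  // implement their own math.
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback) override;
+};
+
+}  // namespace paddle
+```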
+
+### Activations
+Since activation functions in PaddlePaddle are a concept independent of layers, we will add an `MkldnnActivation.h` file under the `paddle/gserver/activations` directory to define the interfaces needed by MKL-DNN, while the implementations remain in `ActivationFunction.cpp`.
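+
+As a loose illustration only (the class name below is from this design, but the hook names and signatures are assumptions), `MkldnnActivation.h` could declare something like:
+
+```cpp
+#include <cstddef>
+
+namespace paddle {
+
+// Hypothetical interface for MKL-DNN backed activations: it only adds the
+// hooks needed to (re)build MKL-DNN primitives; the concrete activation
+// implementations stay in ActivationFunction.cpp as described above.
+class MkldnnActivation {
+public:
+  virtual ~MkldnnActivation() {}
+
+  // Rebuild the forward/backward MKL-DNN primitives when the shape or the
+  // memory format of the activation's input changes.
+  virtual void resetFwd(size_t batchSize, size_t dim) = 0;
+  virtual void resetBwd(size_t batchSize, size_t dim) = 0;
+};
+
+}  // namespace paddle
+```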
+
+### Unit Tests
+We will add `test_Mkldnn.cpp` and `MkldnnTester.*` under the `paddle/gserver/test` directory for MKL-DNN related tests.
+
+For activations, we plan to add new test types directly to PaddlePaddle's existing test files.
+
+### Protobuf Messages
+Depending on the needs of individual layers, the necessary options may be added to `proto/ModelConfig.proto`.
+
+### Python API
+For now we only consider the **v1 API**.
+
+We plan to add a `use_mkldnn` option to `python/paddle/trainer/config_parser.py`, so that users can conveniently choose the MKL-DNN layers.
+
+For example, the implementation could look like:
+
+```python
+use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
+if use_mkldnn:
+    self.layer_type = "mkldnn_" + self.layer_type
+```
+
+All MKL-DNN layer types will start with *mkldnn_* to distinguish them from the existing layers.
+
+We may also add the necessary MKL-DNN interfaces to `activations.py` and `layers.py` under the `python/paddle/trainer_config_helper` directory.
+
+### Demos
+
+We will add an `mkldnn` folder under the `v1_api_demo` directory, containing demo scripts for testing MKL-DNN.
+
+### Benchmarking
+We will consider adding some logic to `benchmark/paddle/image/run.sh` so that the MKL-DNN path is benchmarked as well.
+
+### Others
+1. When MKL-DNN is in use, CPU buffers will be aligned to 64 bytes (a minimal alignment sketch follows this list).
+2. Dig deeper into PaddlePaddle to look for further optimization opportunities, for example using OpenMP to speed up the SGD parameter update.
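+
+A minimal illustration of the 64-byte alignment mentioned in point 1 above (this is not PaddlePaddle's actual allocator, just a sketch):
+
+```cpp
+#include <stdlib.h>
+
+// Allocate a buffer aligned to a 64-byte boundary, which matches the
+// alignment MKL-DNN prefers for its vectorized kernels.
+void* allocAligned64(size_t bytes) {
+  void* ptr = nullptr;
+  // posix_memalign needs a power-of-two alignment that is also a multiple
+  // of sizeof(void*); 64 satisfies both. Release the buffer with free().
+  if (posix_memalign(&ptr, 64, bytes) != 0) {
+    return nullptr;
+  }
+  return ptr;
+}
+```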
+
+## Design Concerns
+
+We want to follow PaddlePaddle's coding style\[[2](#references)\] as closely as possible while sacrificing as little of MKL-DNN's performance as possible\[[3](#references)\].
+
+With that in mind, we summarize the points that need particular attention:
+
+1. Use **deviceId_**. To add as few variables and functions to the parent class `Layer` as possible, we decided to reuse the existing `deviceId_` variable to distinguish layer attributes, and to define `-2` as the device ID reserved for `MkldnnLayer`.
+2. Override the parent class `Layer`'s **init** function and set `deviceId_` to `-2`, meaning the layer runs in the MKL-DNN environment.
+3. Create `MkldnnMatrix` to manage the memory functions, interfaces and format information that MKL-DNN needs.
+4. Create `MkldnnBase` to define classes and functions beyond layers and memory, including the `MkldnnStream` and `CpuEngine` that MKL-DNN needs, and possibly an `FPGAEngine` in the future.
+5. Add two `MkldnnMatrixPtr` members to **Argument**, named `mkldnnValue` and `mkldnnGrad`, to hold the memory buffers used by `MkldnnLayer`, and add a function `cvt` (to be renamed to something more suitable) that converts memory between the "CPU device" and the "MKL-DNN device" (see the sketch after this list).
+6. Add logic to the parent class `Layer`'s `getOutput` function that checks `deviceId` and, when the devices of MKL-DNN and CPU layers do not match, performs an up-front conversion, i.e. calls `Argument`'s `cvt` function to bring the output onto the required device.
+7. Add a `use_mkldnn` flag to the existing `FLAGS` to choose whether to use the MKL-DNN functionality.
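+
+To make point 5 a bit more concrete, here is a rough, non-authoritative sketch; the member layout and the `cvt` signature are assumptions, and the real `Argument` of course has many more members than shown:
+
+```cpp
+#include <memory>
+
+class Matrix;        // PaddlePaddle's existing matrix type
+class MkldnnMatrix;  // the new matrix type proposed in point 3
+using MatrixPtr = std::shared_ptr<Matrix>;
+using MkldnnMatrixPtr = std::shared_ptr<MkldnnMatrix>;
+
+// Hypothetical sketch only: Argument keeps its existing buffers and gains two
+// MKL-DNN specific buffers plus a conversion helper.
+struct Argument {
+  MatrixPtr value;              // existing value buffer (CPU or GPU)
+  MatrixPtr grad;               // existing gradient buffer (CPU or GPU)
+  MkldnnMatrixPtr mkldnnValue;  // value in MKL-DNN's preferred memory format
+  MkldnnMatrixPtr mkldnnGrad;   // gradient in MKL-DNN's preferred memory format
+
+  // Convert memory between the "CPU device" and the "MKL-DNN device"
+  // (deviceId -2); "cvt" is a provisional name, as noted in point 5.
+  void cvt(int targetDeviceId);
+};
+```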
+
+## References
+
+1. [Intel Math Kernel Library for Deep Neural Networks (Intel MKL-DNN)](https://github.com/01org/mkl-dnn "Intel MKL-DNN")
+2. [The original proposal](https://github.com/PaddlePaddle/Paddle/pull/3096) would have introduced **nextLayer** information. However, in PaddlePaddle neither the pre-refactor layers nor the post-refactor ops are meant to know anything about the next layer/op.
+3. MKL-DNN's high-performance memory formats differ from PaddlePaddle's native `NCHW` (the cuDNN part of PaddlePaddle also uses `NCHW`, so it does not have this problem). We therefore need a conversion mechanism, and the format should only be converted when strictly necessary, so that MKL-DNN's performance is not wasted.
+
diff --git a/doc/design/mkldnn/image/overview.png b/doc/design/mkldnn/image/overview.png
new file mode 100644
index 0000000000000000000000000000000000000000..84b455c28230703599a2529f014cfbb222138fef
Binary files /dev/null and b/doc/design/mkldnn/image/overview.png differ
diff --git a/paddle/.set_python_path.sh b/paddle/.set_python_path.sh
index fa7baccc86e0b56e57d52a40c95cfe1b98fececc..8fd58925ee4820269572176ff9496f42914652da 100755
--- a/paddle/.set_python_path.sh
+++ b/paddle/.set_python_path.sh
@@ -21,22 +21,15 @@
#
# It same as PYTHONPATH=${YOUR_PYTHON_PATH}:$PYTHONPATH {exec...}
#
-
-if ! python -c "import paddle" >/dev/null 2>/dev/null; then
- PYPATH=""
- set -x
- while getopts "d:" opt; do
- case $opt in
- d)
- PYPATH=$OPTARG
- ;;
- esac
- done
- shift $(($OPTIND - 1))
- export PYTHONPATH=$PYPATH:$PYTHONPATH
- $@
-else
- echo "paddle package is already in your PYTHONPATH. But unittest need a clean environment."
- echo "Please uninstall paddle package before start unittest. Try to 'pip uninstall paddle'"
- exit 1
-fi
+PYPATH=""
+set -x
+while getopts "d:" opt; do
+ case $opt in
+ d)
+ PYPATH=$OPTARG
+ ;;
+ esac
+done
+shift $(($OPTIND - 1))
+export PYTHONPATH=$PYPATH:$PYTHONPATH
+$@
diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index f8a88cf317aee6c5dd25e4cc25d588c6c50fcbce..cf61a243e9df2fd4a580e41f07cb0a22dcc72083 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -22,7 +22,5 @@ if(WITH_C_API)
endif()
if(WITH_SWIG_PY)
- configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
- ${CMAKE_CURRENT_SOURCE_DIR}/setup.py)
add_subdirectory(api)
endif()
diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt
index 84da89a1422b6095b995744cebb6a3af98a071c6..7a1e8b8b26ac6330c3799b7dfeb4447e171fe0f1 100644
--- a/paddle/api/CMakeLists.txt
+++ b/paddle/api/CMakeLists.txt
@@ -82,9 +82,7 @@ SWIG_LINK_LIBRARIES(swig_paddle
add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/py_paddle/_swig_paddle.so
COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle
COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/_swig_paddle.so ${PROJ_ROOT}/paddle/py_paddle
- COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
- COMMAND ${CMAKE_COMMAND} -E touch dist/.timestamp
- COMMAND rm -rf py_paddle.egg-info build
+ COMMAND ${CMAKE_COMMAND} -E touch .timestamp
WORKING_DIRECTORY ${PROJ_ROOT}/paddle
DEPENDS _swig_paddle
)
@@ -92,10 +90,6 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/py_paddle/_swig_paddle.so
# TODO(yuyang18) : make wheel name calculated by cmake
add_custom_target(python_api_wheel ALL DEPENDS ${PROJ_ROOT}/paddle/py_paddle/_swig_paddle.so)
-install(DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/dist/
- DESTINATION opt/paddle/share/wheels
-)
-
if(WITH_TESTING)
IF(NOT PY_PIP_FOUND)
SET(PIP_SOURCES_DIR ${PYTHON_SOURCES_DIR}/pip)
@@ -108,7 +102,7 @@ if(WITH_TESTING)
BUILD_COMMAND ""
INSTALL_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
BUILD_IN_SOURCE 1
- DEPENDS python setuptools python_api_wheel
+ #DEPENDS python setuptools python_api_wheel
)
ENDIF()
add_subdirectory(test)
diff --git a/paddle/api/test/CMakeLists.txt b/paddle/api/test/CMakeLists.txt
index f3b1c2c4d438b5d3e776ef27ce8f8b78f710f2ab..761aeb5b174105edece8880a9f5012c13a63fd11 100644
--- a/paddle/api/test/CMakeLists.txt
+++ b/paddle/api/test/CMakeLists.txt
@@ -1,2 +1,6 @@
-add_python_test(test_swig_api
- testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py testTrainer.py)
+py_test(testTrain SRCS testTrain.py)
+py_test(testMatrix SRCS testMatrix.py)
+py_test(testVector SRCS testVector.py)
+py_test(testTrainer SRCS testTrainer.py)
+py_test(testArguments SRCS testArguments.py)
+py_test(testGradientMachine SRCS testGradientMachine.py)
diff --git a/paddle/cuda/src/hl_batch_transpose.cu b/paddle/cuda/src/hl_batch_transpose.cu
index f047403da17e66960f029f2fee7312210009c952..f4c253df7b4be937f041f18587efd4c9d693fbe4 100644
--- a/paddle/cuda/src/hl_batch_transpose.cu
+++ b/paddle/cuda/src/hl_batch_transpose.cu
@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#include "hl_batch_transpose.h"
#include "hl_base.h"
+#include "hl_batch_transpose.h"
const int TILE_DIM = 64;
const int BLOCK_ROWS = 16;
// No bank-conflict transpose for a batch of data.
-__global__ void batchTransposeNoBankConflicts(real* odata,
- const real* idata,
- int numSamples, int width,
- int height) {
+__global__ void batchTransposeNoBankConflicts(
+ real* odata, const real* idata, int numSamples, int width, int height) {
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
const int x = blockIdx.x * TILE_DIM + threadIdx.x;
@@ -50,12 +48,12 @@ __global__ void batchTransposeNoBankConflicts(real* odata,
newX] = tile[threadIdx.x][j];
}
-void batchTranspose(const real* input, real* output, int width, int height,
- int batchSize) {
+void batchTranspose(
+ const real* input, real* output, int width, int height, int batchSize) {
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
dim3 dimGrid(DIVUP(width, TILE_DIM), DIVUP(height, TILE_DIM), batchSize);
- batchTransposeNoBankConflicts<<>>
- (output, input, batchSize, width, height);
+  batchTransposeNoBankConflicts<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>(
+ output, input, batchSize, width, height);
CHECK_SYNC("batchTranspose failed!");
}
diff --git a/paddle/cuda/src/hl_cuda_aggregate.cu b/paddle/cuda/src/hl_cuda_aggregate.cu
index 97034a917708487d1c5dc59e6ebbf45bad1c3227..16a54ad343fa140aa1f3bec311c4b712d0086082 100644
--- a/paddle/cuda/src/hl_cuda_aggregate.cu
+++ b/paddle/cuda/src/hl_cuda_aggregate.cu
@@ -12,27 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
+#include "hl_aggregate.h"
#include "hl_base.h"
#include "hl_cuda.h"
#include "hl_cuda.ph"
-#include "hl_aggregate.h"
-#include "hl_thread.ph"
#include "hl_matrix_base.cuh"
+#include "hl_thread.ph"
#include "paddle/utils/Logging.h"
/**
* @brief matrix row operator.
*/
-template
-__global__ void KeMatrixRowOp(Agg agg,
- real *E,
- real *Sum,
- int dimN) {
+template
+__global__ void KeMatrixRowOp(Agg agg, real *E, real *Sum, int dimN) {
__shared__ real sum_s[blockSize];
- int cnt = (dimN + blockSize -1) / blockSize;
- int rowId = blockIdx.x + blockIdx.y*gridDim.x;
- int index = rowId*dimN;
+ int cnt = (dimN + blockSize - 1) / blockSize;
+ int rowId = blockIdx.x + blockIdx.y * gridDim.x;
+ int index = rowId * dimN;
int tid = threadIdx.x;
int lmt = tid;
@@ -44,7 +40,7 @@ __global__ void KeMatrixRowOp(Agg agg,
sum_s[tid] = tmp;
__syncthreads();
- for (int stride = blockSize/2; stride > 0; stride = stride/2) {
+ for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
if (tid < stride) {
sum_s[tid] = agg(sum_s[tid], sum_s[tid + stride]);
}
@@ -58,29 +54,21 @@ __global__ void KeMatrixRowOp(Agg agg,
}
template
-void hl_matrix_row_op(Agg agg,
- real *A_d,
- real *C_d,
- int dimM,
- int dimN) {
+void hl_matrix_row_op(Agg agg, real *A_d, real *C_d, int dimM, int dimN) {
int blocksX = dimM;
int blocksY = 1;
dim3 threads(128, 1);
dim3 grid(blocksX, blocksY);
- KeMatrixRowOp<<< grid, threads, 0, STREAM_DEFAULT >>>
- (agg, A_d, C_d, dimN);
+ KeMatrixRowOp<<>>(
+ agg, A_d, C_d, dimN);
}
void hl_matrix_row_sum(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
- hl_matrix_row_op(aggregate::sum(),
- A_d,
- C_d,
- dimM,
- dimN);
+ hl_matrix_row_op(aggregate::sum(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_row_sum failed");
}
@@ -88,11 +76,7 @@ void hl_matrix_row_max(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
- hl_matrix_row_op(aggregate::max(),
- A_d,
- C_d,
- dimM,
- dimN);
+ hl_matrix_row_op(aggregate::max(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_row_max failed");
}
@@ -100,23 +84,16 @@ void hl_matrix_row_min(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
- hl_matrix_row_op(aggregate::min(),
- A_d,
- C_d,
- dimM,
- dimN);
+ hl_matrix_row_op(aggregate::min(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_row_min failed");
}
/**
* @brief matrix column operator.
*/
-template
-__global__ void KeMatrixColumnOp(Agg agg,
- real *E,
- real *Sum,
- int dimM,
- int dimN) {
+template
+__global__ void KeMatrixColumnOp(
+ Agg agg, real *E, real *Sum, int dimM, int dimN) {
int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
real tmp = agg.init();
if (rowIdx < dimN) {
@@ -127,15 +104,12 @@ __global__ void KeMatrixColumnOp(Agg agg,
}
}
-template
-__global__ void KeMatrixColumnOp_S(Agg agg,
- real *E,
- real *Sum,
- int dimM,
- int dimN) {
- __shared__ real _sum[blockDimX*blockDimY];
- int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
- int index = threadIdx.y;
+template
+__global__ void KeMatrixColumnOp_S(
+ Agg agg, real *E, real *Sum, int dimM, int dimN) {
+ __shared__ real _sum[blockDimX * blockDimY];
+ int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
+ int index = threadIdx.y;
real tmp = agg.init();
if (rowIdx < dimN) {
@@ -144,14 +118,14 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
index += blockDimY;
}
}
- _sum[threadIdx.x + threadIdx.y*blockDimX] = tmp;
+ _sum[threadIdx.x + threadIdx.y * blockDimX] = tmp;
__syncthreads();
if (rowIdx < dimN) {
- if (threadIdx.y ==0) {
+ if (threadIdx.y == 0) {
real tmp = agg.init();
- for (int i=0; i < blockDimY; i++) {
- tmp = agg(tmp, _sum[threadIdx.x + i*blockDimX]);
+ for (int i = 0; i < blockDimY; i++) {
+ tmp = agg(tmp, _sum[threadIdx.x + i * blockDimX]);
}
Sum[rowIdx] = tmp;
}
@@ -159,25 +133,21 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
}
template
-void hl_matrix_column_op(Agg agg,
- real *A_d,
- real *C_d,
- int dimM,
- int dimN) {
+void hl_matrix_column_op(Agg agg, real *A_d, real *C_d, int dimM, int dimN) {
if (dimN >= 8192) {
- int blocksX = (dimN + 128 -1) / 128;
+ int blocksX = (dimN + 128 - 1) / 128;
int blocksY = 1;
dim3 threads(128, 1);
dim3 grid(blocksX, blocksY);
- KeMatrixColumnOp<<< grid, threads, 0, STREAM_DEFAULT >>>
- (agg, A_d, C_d, dimM, dimN);
+ KeMatrixColumnOp<<>>(
+ agg, A_d, C_d, dimM, dimN);
} else {
- int blocksX = (dimN + 32 -1) / 32;
+ int blocksX = (dimN + 32 - 1) / 32;
int blocksY = 1;
dim3 threads(32, 32);
dim3 grid(blocksX, blocksY);
- KeMatrixColumnOp_S<<< grid, threads, 0, STREAM_DEFAULT>>>
- (agg, A_d, C_d, dimM, dimN);
+ KeMatrixColumnOp_S<<>>(
+ agg, A_d, C_d, dimM, dimN);
}
return;
@@ -187,11 +157,7 @@ void hl_matrix_column_sum(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
- hl_matrix_column_op(aggregate::sum(),
- A_d,
- C_d,
- dimM,
- dimN);
+ hl_matrix_column_op(aggregate::sum(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_column_sum failed");
}
@@ -200,11 +166,7 @@ void hl_matrix_column_max(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
- hl_matrix_column_op(aggregate::max(),
- A_d,
- C_d,
- dimM,
- dimN);
+ hl_matrix_column_op(aggregate::max(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_column_max failed");
}
@@ -213,11 +175,7 @@ void hl_matrix_column_min(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
- hl_matrix_column_op(aggregate::min(),
- A_d,
- C_d,
- dimM,
- dimN);
+ hl_matrix_column_op(aggregate::min(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_column_min failed");
}
@@ -226,16 +184,16 @@ template
__global__ void KeVectorSum(real *E, real *Sum, int dimM) {
__shared__ double sum_s[blockSize];
int tid = threadIdx.x;
- int index = blockIdx.y*blockDim.x+threadIdx.x;
+ int index = blockIdx.y * blockDim.x + threadIdx.x;
sum_s[tid] = 0.0f;
while (index < dimM) {
sum_s[tid] += E[index];
- index += blockDim.x*gridDim.y;
+ index += blockDim.x * gridDim.y;
}
__syncthreads();
- for (int stride = blockSize/2; stride > 0; stride = stride/2) {
+ for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
if (tid < stride) {
sum_s[tid] += sum_s[tid + stride];
}
@@ -259,38 +217,39 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) {
dim3 threads(blockSize, 1);
dim3 grid(blocksX, blocksY);
- struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
+ struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
hl_event_t hl_event = &hl_event_st;
- while (!hl_cuda_event_is_ready(hl_event)) {}
+ while (!hl_cuda_event_is_ready(hl_event)) {
+ }
- KeVectorSum<128><<< grid, threads, 0, STREAM_DEFAULT >>>
- (A_d, t_resource.gpu_mem, dimM);
- KeVectorSum<128><<< 1, threads, 0, STREAM_DEFAULT >>>
- (t_resource.gpu_mem, t_resource.cpu_mem, 128);
+  KeVectorSum<128><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d, t_resource.gpu_mem, dimM);
+ KeVectorSum<128><<<1, threads, 0, STREAM_DEFAULT>>>(
+ t_resource.gpu_mem, t_resource.cpu_mem, 128);
hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
cudaError_t err = (cudaError_t)hl_get_device_last_error();
- CHECK_EQ(cudaSuccess, err)
- << "CUDA error: " << hl_get_device_error_string((size_t)err);
+ CHECK_EQ(cudaSuccess, err) << "CUDA error: "
+ << hl_get_device_error_string((size_t)err);
}
template
__global__ void KeVectorAbsSum(real *E, real *Sum, int dimM) {
__shared__ double sum_s[blockSize];
int tid = threadIdx.x;
- int index = blockIdx.y*blockDim.x+threadIdx.x;
+ int index = blockIdx.y * blockDim.x + threadIdx.x;
sum_s[tid] = 0.0f;
while (index < dimM) {
sum_s[tid] += abs(E[index]);
- index += blockDim.x*gridDim.y;
+ index += blockDim.x * gridDim.y;
}
__syncthreads();
- for (int stride = blockSize/2; stride > 0; stride = stride/2) {
+ for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
if (tid < stride) {
sum_s[tid] += sum_s[tid + stride];
}
@@ -314,20 +273,21 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
dim3 threads(blockSize, 1);
dim3 grid(blocksX, blocksY);
- struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
+ struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
hl_event_t hl_event = &hl_event_st;
- while (!hl_cuda_event_is_ready(hl_event)) {}
+ while (!hl_cuda_event_is_ready(hl_event)) {
+ }
- KeVectorAbsSum<128><<< grid, threads, 0, STREAM_DEFAULT >>>
- (A_d, t_resource.gpu_mem, dimM);
- KeVectorAbsSum<128><<< 1, threads, 0, STREAM_DEFAULT >>>
- (t_resource.gpu_mem, t_resource.cpu_mem, 128);
+  KeVectorAbsSum<128><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d, t_resource.gpu_mem, dimM);
+ KeVectorAbsSum<128><<<1, threads, 0, STREAM_DEFAULT>>>(
+ t_resource.gpu_mem, t_resource.cpu_mem, 128);
hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
cudaError_t err = (cudaError_t)hl_get_device_last_error();
- CHECK_EQ(cudaSuccess, err)
- << "CUDA error: " << hl_get_device_error_string((size_t)err);
+ CHECK_EQ(cudaSuccess, err) << "CUDA error: "
+ << hl_get_device_error_string((size_t)err);
}
diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu
index b6e3e63a4f52261e49467bd82fdabd063e81460e..aac19b1ea566ad69f1f7374e393676c8debd9883 100644
--- a/paddle/cuda/src/hl_cuda_cnn.cu
+++ b/paddle/cuda/src/hl_cuda_cnn.cu
@@ -12,21 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
#include
#include "hl_base.h"
#include "hl_cnn.h"
#include "hl_device_functions.cuh"
-__global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
- const int channels, const int height,
+__global__ void KeMaxPoolForward(const int nthreads,
+ const real* inputData,
+ const int channels,
+ const int height,
const int width,
- const int pooledH, const int pooledW,
- const int ksizeW, const int ksizeH,
- const int strideH, const int strideW,
- const int offsetH, const int offsetW,
- real* tgtData, const int tgtStride) {
- int index = blockIdx.x * blockDim.x + threadIdx.x;
+ const int pooledH,
+ const int pooledW,
+ const int ksizeW,
+ const int ksizeH,
+ const int strideH,
+ const int strideW,
+ const int offsetH,
+ const int offsetW,
+ real* tgtData,
+ const int tgtStride) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
int ph = (index / pooledW) % pooledH;
@@ -46,44 +52,70 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
maxval = inputData[h * width + w];
}
}
- int tgtIndex = index % (pooledW * pooledH * channels) +
- frameNum * tgtStride;
+ int tgtIndex =
+ index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval;
}
}
-void hl_maxpool_forward(const int frameCnt, const real* inputData,
+void hl_maxpool_forward(const int frameCnt,
+ const real* inputData,
const int channels,
- const int height, const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int paddingH, const int paddingW,
- real* tgtData, const int tgtStride) {
-
+ const int height,
+ const int width,
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int paddingH,
+ const int paddingW,
+ real* tgtData,
+ const int tgtStride) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
- KeMaxPoolForward<<< grid, threads, 0, STREAM_DEFAULT >>>
- (num_kernels, inputData, channels, height, width,
- pooledH, pooledW, sizeX, sizeY, strideH, strideW,
- paddingH, paddingW, tgtData, tgtStride);
+  KeMaxPoolForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
+ inputData,
+ channels,
+ height,
+ width,
+ pooledH,
+ pooledW,
+ sizeX,
+ sizeY,
+ strideH,
+ strideW,
+ paddingH,
+ paddingW,
+ tgtData,
+ tgtStride);
CHECK_SYNC("hl_maxpool_forward failed");
}
-__global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
- const real* outData, const real* outGrad,
- const int channels, const int height,
+__global__ void KeMaxPoolBackward(const int nthreads,
+ const real* inputData,
+ const real* outData,
+ const real* outGrad,
+ const int channels,
+ const int height,
const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int padH, const int padW,
- real scaleA, real scaleB,
- real* targetGrad, const int outStride) {
- int index = blockIdx.x * blockDim.x + threadIdx.x;
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int padH,
+ const int padW,
+ real scaleA,
+ real scaleB,
+ real* targetGrad,
+ const int outStride) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
// find out the local index
// find out the local offset
@@ -107,43 +139,69 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
}
}
}
- targetGrad[index] =
- scaleB * targetGrad[index] + scaleA * gradient;
+ targetGrad[index] = scaleB * targetGrad[index] + scaleA * gradient;
}
}
-void hl_maxpool_backward(const int frameCnt, const real* inputData,
- const real* outData, const real* outGrad,
- const int channels, const int height,
- const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int paddingH, const int paddingW,
- real scaleA, real scaleB,
- real* targetGrad, const int outStride) {
-
+void hl_maxpool_backward(const int frameCnt,
+ const real* inputData,
+ const real* outData,
+ const real* outGrad,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int paddingH,
+ const int paddingW,
+ real scaleA,
+ real scaleB,
+ real* targetGrad,
+ const int outStride) {
int num_kernels = height * width * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
- KeMaxPoolBackward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
- (num_kernels, inputData, outData, outGrad, channels,
- height, width, pooledH, pooledW, sizeX, sizeY,
- strideH, strideW,
- paddingH, paddingW,
- scaleA, scaleB,
- targetGrad, outStride);
+  KeMaxPoolBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
+ inputData,
+ outData,
+ outGrad,
+ channels,
+ height,
+ width,
+ pooledH,
+ pooledW,
+ sizeX,
+ sizeY,
+ strideH,
+ strideW,
+ paddingH,
+ paddingW,
+ scaleA,
+ scaleB,
+ targetGrad,
+ outStride);
CHECK_SYNC("hl_maxpool_backward");
}
-__global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
+__global__ void KeAvgPoolForward(const int nthreads,
+ const real* inputData,
const int channels,
- const int height, const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int padH, const int padW,
- real* tgtData, const int tgtStride) {
+ const int height,
+ const int width,
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int padH,
+ const int padW,
+ real* tgtData,
+ const int tgtStride) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
@@ -168,39 +226,64 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
aveval += inputData[h * width + w];
}
}
- int tgtIndex = index % (pooledW * pooledH * channels) +
- frameNum * tgtStride;
+ int tgtIndex =
+ index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = aveval / pool_size;
}
}
-void hl_avgpool_forward(const int frameCnt, const real* inputData,
+void hl_avgpool_forward(const int frameCnt,
+ const real* inputData,
const int channels,
- const int height, const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int paddingH, const int paddingW,
- real* tgtData, const int tgtStride) {
+ const int height,
+ const int width,
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int paddingH,
+ const int paddingW,
+ real* tgtData,
+ const int tgtStride) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
- KeAvgPoolForward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
- (num_kernels, inputData, channels,
- height, width, pooledH, pooledW,
- sizeX, sizeY, strideH, strideW,
- paddingH, paddingW, tgtData, tgtStride);
+  KeAvgPoolForward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
+ inputData,
+ channels,
+ height,
+ width,
+ pooledH,
+ pooledW,
+ sizeX,
+ sizeY,
+ strideH,
+ strideW,
+ paddingH,
+ paddingW,
+ tgtData,
+ tgtStride);
CHECK_SYNC("hl_avgpool_forward failed");
}
-__global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
- const int channels, const int height,
+__global__ void KeAvgPoolBackward(const int nthreads,
+ const real* outGrad,
+ const int channels,
+ const int height,
const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int padH, const int padW,
- real scaleA, real scaleB,
- real* tgtGrad, const int outStride) {
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int padH,
+ const int padW,
+ real scaleA,
+ real scaleB,
+ real* tgtGrad,
+ const int outStride) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offsetW = index % width + padW;
@@ -215,7 +298,6 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
real gradient = 0;
outGrad += (frameNum * outStride + offsetC * pooledH * pooledW);
-
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
@@ -224,32 +306,50 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
int hend = min(hstart + sizeY, height + padH);
int wend = min(wstart + sizeX, width + padW);
int poolsize = (hend - hstart) * (wend - wstart);
- gradient += outGrad[ph * pooledW + pw]/poolsize;
+ gradient += outGrad[ph * pooledW + pw] / poolsize;
}
}
tgtGrad[index] = scaleB * tgtGrad[index] + scaleA * gradient;
}
}
-void hl_avgpool_backward(const int frameCnt, const real* outGrad,
+void hl_avgpool_backward(const int frameCnt,
+ const real* outGrad,
const int channels,
- const int height, const int width,
- const int pooledH, const int pooledW,
- const int sizeX, const int sizeY,
- const int strideH, const int strideW,
- const int paddingH, const int paddingW,
- real scaleA, real scaleB,
- real* backGrad, const int outStride) {
+ const int height,
+ const int width,
+ const int pooledH,
+ const int pooledW,
+ const int sizeX,
+ const int sizeY,
+ const int strideH,
+ const int strideW,
+ const int paddingH,
+ const int paddingW,
+ real scaleA,
+ real scaleB,
+ real* backGrad,
+ const int outStride) {
int num_kernels = height * width * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
- KeAvgPoolBackward <<< blocks, 1024, 0, STREAM_DEFAULT >>>
- (num_kernels, outGrad, channels, height, width,
- pooledH, pooledW, sizeX, sizeY,
- strideH, strideW,
- paddingH, paddingW,
- scaleA, scaleB,
- backGrad, outStride);
+  KeAvgPoolBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
+ outGrad,
+ channels,
+ height,
+ width,
+ pooledH,
+ pooledW,
+ sizeX,
+ sizeY,
+ strideH,
+ strideW,
+ paddingH,
+ paddingW,
+ scaleA,
+ scaleB,
+ backGrad,
+ outStride);
CHECK_SYNC("hl_avgpool_backward failed");
}
@@ -266,7 +366,7 @@ __global__ void KeBilinearInterpFw(const real* in,
const size_t numChannels,
const real ratioH,
const real ratioW) {
- int nthreads = outputH * outputW;
+ int nthreads = outputH * outputW;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int outIdH = tid / outputW;
@@ -287,13 +387,14 @@ __global__ void KeBilinearInterpFw(const real* in,
real w1lambda = ratioW * outImgIdx - inImgIdx;
real w2lambda = 1.f - w1lambda;
- const real* inPos =
- &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];
+ const real* inPos = &in[outIdH * inputW + channelId * inImgSize +
+ inImgIdy * inImgW + inImgIdx];
// bilinear interpolation
out[outIdH * outputW + outIdW] =
- h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
- h1lambda * (w2lambda * inPos[hId * inImgW] + w1lambda * inPos[hId * inImgW + wId]);
+ h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
+ h1lambda * (w2lambda * inPos[hId * inImgW] +
+ w1lambda * inPos[hId * inImgW + wId]);
}
}
@@ -313,9 +414,19 @@ void hl_bilinear_forward(const real* inData,
int threadNum = outputH * outputW;
int blocks = (threadNum + 1024 - 1) / 1024;
- KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
- inData, inImgH, inImgW, inputH, inputW, outData, outImgH,
- outImgW, outputH, outputW, numChannels, ratioH, ratioW);
+  KeBilinearInterpFw<<<blocks, 1024, 0, STREAM_DEFAULT>>>(inData,
+ inImgH,
+ inImgW,
+ inputH,
+ inputW,
+ outData,
+ outImgH,
+ outImgW,
+ outputH,
+ outputW,
+ numChannels,
+ ratioH,
+ ratioW);
CHECK_SYNC("hl_bilinear_forward failed");
}
@@ -353,13 +464,15 @@ __global__ void KeBilinearInterpBw(real* in,
real w1lambda = ratioW * outImgIdx - inImgIdx;
real w2lambda = 1.f - w1lambda;
- real* inPos =
- &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];
+ real* inPos = &in[outIdH * inputW + channelId * inImgSize +
+ inImgIdy * inImgW + inImgIdx];
const real* outPos = &out[outIdH * outputW + outIdW];
paddle::paddleAtomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]);
paddle::paddleAtomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]);
- paddle::paddleAtomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]);
- paddle::paddleAtomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]);
+ paddle::paddleAtomicAdd(&inPos[hId * inImgW],
+ h1lambda * w2lambda * outPos[0]);
+ paddle::paddleAtomicAdd(&inPos[hId * inImgW + wId],
+ h1lambda * w1lambda * outPos[0]);
}
}
@@ -379,22 +492,37 @@ void hl_bilinear_backward(real* inGrad,
int threadNum = outputH * outputW;
int blocks = (threadNum + 1024 - 1) / 1024;
- KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
- inGrad, inImgH, inImgW, inputH, inputW, outGrad, outImgH,
- outImgW, outputH, outputW, numChannels, ratioH, ratioW);
+  KeBilinearInterpBw<<<blocks, 1024, 0, STREAM_DEFAULT>>>(inGrad,
+ inImgH,
+ inImgW,
+ inputH,
+ inputW,
+ outGrad,
+ outImgH,
+ outImgW,
+ outputH,
+ outputW,
+ numChannels,
+ ratioH,
+ ratioW);
CHECK_SYNC("hl_bilinear_backward failed");
}
-__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
- real * outData, int* idData,
- size_t size, size_t featLen, size_t groups) {
+__global__ void maxoutFpCompute(size_t nthreads,
+ const real* inData,
+ real* outData,
+ int* idData,
+ size_t size,
+ size_t featLen,
+ size_t groups) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
- if(index < nthreads) {
+ if (index < nthreads) {
size_t batch_idx = index / size;
size_t i = index % size;
size_t channel_idx = i / featLen;
size_t feat_idx = i % featLen;
- size_t data_idx = (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
+ size_t data_idx =
+ (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
real max = inData[data_idx];
int maxId = 0;
for (size_t g = 1; g < groups; ++g) {
@@ -409,37 +537,50 @@ __global__ void maxoutFpCompute(size_t nthreads, const real * inData,
}
}
-void hl_maxout_forward(const real* inData, real* outData,
- int* idData, size_t batchSize, size_t size,
- size_t featLen, size_t groups) {
+void hl_maxout_forward(const real* inData,
+ real* outData,
+ int* idData,
+ size_t batchSize,
+ size_t size,
+ size_t featLen,
+ size_t groups) {
int num_kernels = size * batchSize;
int blocks = (num_kernels + 1024 - 1) / 1024;
- maxoutFpCompute<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
- num_kernels, inData, outData, idData, size, featLen, groups);
+  maxoutFpCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(
+ num_kernels, inData, outData, idData, size, featLen, groups);
CHECK_SYNC("hl_maxout_forward failed");
}
-__global__ void maxoutBpCompute(size_t nthreads, real* inGrad,
- const real* outGrad, const int* idData,
- size_t size, size_t featLen, size_t groups) {
+__global__ void maxoutBpCompute(size_t nthreads,
+ real* inGrad,
+ const real* outGrad,
+ const int* idData,
+ size_t size,
+ size_t featLen,
+ size_t groups) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
- if(index < nthreads) {
+ if (index < nthreads) {
size_t batch_idx = index / size;
size_t i = index % size;
size_t channel_idx = i / featLen;
size_t feat_idx = i % featLen;
size_t newIndex = batch_idx * size;
- size_t gradIdx = (channel_idx * groups + (idData + newIndex)[i]) * featLen + feat_idx;
+ size_t gradIdx =
+ (channel_idx * groups + (idData + newIndex)[i]) * featLen + feat_idx;
(inGrad + newIndex * groups)[gradIdx] += (outGrad + newIndex)[i];
}
}
-void hl_maxout_backward(real* inGrad, const real* outGrad,
- const int* idData, size_t batchSize, size_t size,
- size_t featLen, size_t groups) {
+void hl_maxout_backward(real* inGrad,
+ const real* outGrad,
+ const int* idData,
+ size_t batchSize,
+ size_t size,
+ size_t featLen,
+ size_t groups) {
int num_kernels = size * batchSize;
int blocks = (num_kernels + 1024 - 1) / 1024;
- maxoutBpCompute<<< blocks, 1024, 0, STREAM_DEFAULT >>>(
- num_kernels, inGrad, outGrad, idData, size, featLen, groups);
+  maxoutBpCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(
+ num_kernels, inGrad, outGrad, idData, size, featLen, groups);
CHECK_SYNC("hl_maxout_backward failed");
}
diff --git a/paddle/cuda/src/hl_cuda_lstm.cu b/paddle/cuda/src/hl_cuda_lstm.cu
index b869d903ba3cfb188f823518ba8ee7d17f9b2440..a5ce81a904ebbd655a16ef68660b81d442478575 100644
--- a/paddle/cuda/src/hl_cuda_lstm.cu
+++ b/paddle/cuda/src/hl_cuda_lstm.cu
@@ -12,14 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
+#include "hl_activation_functions.h"
#include "hl_base.h"
#include "hl_cuda_cublas.h"
#include "hl_device_functions.cuh"
-#include "hl_activation_functions.h"
#include "paddle/utils/Logging.h"
-typedef hppl::Active::forward t_forward;
+typedef hppl::Active::forward t_forward;
typedef hppl::Active::backward t_backward;
bool hl_lstm_sequence_parallel(int frameSize) {
@@ -42,9 +41,9 @@ public:
value_ += (start + length - 1) * frameSize + idx;
}
}
- __device__ inline real *getPtr() const {return value_;}
- __device__ inline real getValue() {return *value_;}
- __device__ inline void setValue(real value) {*value_ = value;}
+ __device__ inline real *getPtr() const { return value_; }
+ __device__ inline real getValue() { return *value_; }
+ __device__ inline void setValue(real value) { *value_ = value; }
template
__device__ inline void nextFrame() {
if (reversed == 0) {
@@ -55,28 +54,25 @@ public:
}
};
-__device__ __forceinline__
-void ptx_sync(const int id, const int barriers) {
+__device__ __forceinline__ void ptx_sync(const int id, const int barriers) {
asm volatile("bar.sync %0, %1;" : : "r"(id), "r"(barriers) : "memory");
}
-__device__ __forceinline__
-void ptx_arrive(const int id, const int barriers) {
+__device__ __forceinline__ void ptx_arrive(const int id, const int barriers) {
asm volatile("bar.arrive %0, %1;" : : "r"(id), "r"(barriers) : "memory");
}
-template
-__device__ __forceinline__ real
-forward_sequence(real value,
- real *shValue,
- real *state,
- real *preOutput,
- real *output,
- real check,
- int index,
- t_forward activeNode,
- t_forward activeGate,
- t_forward activeState) {
+template
+__device__ __forceinline__ real forward_sequence(real value,
+ real *shValue,
+ real *state,
+ real *preOutput,
+ real *output,
+ real check,
+ int index,
+ t_forward activeNode,
+ t_forward activeGate,
+ t_forward activeState) {
real out;
real prevOut;
real state_r;
@@ -112,17 +108,20 @@ forward_sequence(real value,
if (idy == 0) {
ptx_sync(2, frameSize * 2);
prevOut = state[idx];
- prevOut = activeState(prevOut);
+ prevOut = activeState(prevOut);
preOutput[idx] = prevOut;
ptx_arrive(3, frameSize * 2);
}
return value;
}
-#define OUTPUT_BARRIER_ID 10
-#define OUTPUT_BARRIER_ID2 11
-template
+#define OUTPUT_BARRIER_ID 10
+#define OUTPUT_BARRIER_ID2 11
+template
__global__ void KeLstmForward(real *gateValue,
real *state,
real *output,
@@ -184,10 +183,16 @@ __global__ void KeLstmForward(real *gateValue,
}
}
value = forward_sequence(
- value, shValue, shState, shPrevOutput, shOutput, check, index,
- hppl::gpu::forward[active_node],
- hppl::gpu::forward[active_gate],
- hppl::gpu::forward[active_state]);
+ value,
+ shValue,
+ shState,
+ shPrevOutput,
+ shOutput,
+ check,
+ index,
+ hppl::gpu::forward[active_node],
+ hppl::gpu::forward[active_gate],
+ hppl::gpu::forward[active_state]);
const int idx = index % frameSize;
const int idy = index / frameSize;
if (valueSize == 128) {
@@ -218,7 +223,7 @@ __global__ void KeLstmForward(real *gateValue,
real B_r[frameSize];
const int computeIdx = index - valueSize;
if (i == 0) {
- #pragma unroll
+#pragma unroll
for (int n = 0; n < frameSize; n++) {
B_r[n] = weight[n * valueSize + computeIdx];
}
@@ -230,7 +235,7 @@ __global__ void KeLstmForward(real *gateValue,
}
real sum = 0.0f;
for (int n = 0; n < frameSize; n++) {
- sum += A_r[n]*B_r[n];
+ sum += A_r[n] * B_r[n];
}
shValue[computeIdx] = sum;
ptx_arrive(OUTPUT_BARRIER_ID2, blockSize);
@@ -239,14 +244,14 @@ __global__ void KeLstmForward(real *gateValue,
if (valueSize == 256) {
real B_r[frameSize];
if (i == 0) {
- #pragma unroll
+#pragma unroll
for (int n = 0; n < frameSize; n++) {
B_r[n] = weight[n * valueSize + index];
}
}
real sum = 0.0f;
for (int n = 0; n < frameSize; n++) {
- sum += shOutput[n]*B_r[n];
+ sum += shOutput[n] * B_r[n];
}
value += sum;
}
@@ -273,50 +278,81 @@ void hl_lstm_parallel_forward(real *gateValue,
dim3 grid(numSequences, 1);
if (!reversed) {
if (frameSize == 32) {
- KeLstmForward<128, 32, 0, 128, 256>
- <<>>
- (gateValue, stateValue, outputValue, preOutputValue,
- checkIg, checkFg, checkOg, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmForward<128, 32, 0, 128, 256><<>>(
+ gateValue,
+ stateValue,
+ outputValue,
+ preOutputValue,
+ checkIg,
+ checkFg,
+ checkOg,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 64) {
- KeLstmForward<256, 64, 0, 256, 256>
- <<>>
- (gateValue, stateValue, outputValue, preOutputValue,
- checkIg, checkFg, checkOg, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmForward<256, 64, 0, 256, 256><<>>(
+ gateValue,
+ stateValue,
+ outputValue,
+ preOutputValue,
+ checkIg,
+ checkFg,
+ checkOg,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
}
} else {
if (frameSize == 32) {
- KeLstmForward<128, 32, 1, 128, 256>
- <<>>
- (gateValue, stateValue, outputValue, preOutputValue,
- checkIg, checkFg, checkOg, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmForward<128, 32, 1, 128, 256><<>>(
+ gateValue,
+ stateValue,
+ outputValue,
+ preOutputValue,
+ checkIg,
+ checkFg,
+ checkOg,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 64) {
- KeLstmForward<256, 64, 1, 256, 256>
- <<>>
- (gateValue, stateValue, outputValue, preOutputValue,
- checkIg, checkFg, checkOg, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmForward<256, 64, 1, 256, 256><<>>(
+ gateValue,
+ stateValue,
+ outputValue,
+ preOutputValue,
+ checkIg,
+ checkFg,
+ checkOg,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
}
}
CHECK_SYNC("hl_lstm_parallel_forward failed");
}
-__device__ __forceinline__
-void transpose_32x32(real a[], const int idx) {
+__device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
int addr = idx % 32;
- #pragma unroll
+#pragma unroll
for (int k = 1; k < 32; k++) {
// rSrc[k] = __shfl(rSrc[k], (threadIdx.x + k) % 32, 32);
addr = __shfl(addr, (idx + 1) % 32, 32);
a[k] = __shfl(a[k], addr, 32);
}
- #pragma unroll
+#pragma unroll
for (int tid = 0; tid < 31; tid++) {
real tmp = (idx > tid) ? a[0] : a[1];
- #pragma unroll
+#pragma unroll
for (int k = 31; k > 0; k--) {
a[(k + 1) % 32] = (idx > tid) ? a[k] : a[(k + 1) % 32];
}
@@ -324,29 +360,28 @@ void transpose_32x32(real a[], const int idx) {
}
addr = (32 - idx) % 32;
- #pragma unroll
+#pragma unroll
for (int k = 0; k < 32; k++) {
a[k] = __shfl(a[k], addr, 32);
addr = __shfl(addr, (idx + 31) % 32, 32);
}
}
-template
-__device__ void
-backward_sequence(real rGateValue,
- real rOutputGrad,
- real rPreOutputValue,
- real &rGateGrad,
- real &rStateGrad,
- real *shStateGrad,
- real *shStateValue,
- real *shGateValue,
- real rCheck,
- real &rGateValuePrev,
- int index,
- t_backward activeNode,
- t_backward activeGate,
- t_backward activeState) {
+template
+__device__ void backward_sequence(real rGateValue,
+ real rOutputGrad,
+ real rPreOutputValue,
+ real &rGateGrad,
+ real &rStateGrad,
+ real *shStateGrad,
+ real *shStateValue,
+ real *shGateValue,
+ real rCheck,
+ real &rGateValuePrev,
+ int index,
+ t_backward activeNode,
+ t_backward activeGate,
+ t_backward activeState) {
const int frameIdx = index % frameSize;
const int frameIdy = index / frameSize;
if (frameIdy == 3) {
@@ -363,8 +398,8 @@ backward_sequence(real rGateValue,
rStateGrad = rGateGrad * rCheck;
shStateGrad[index] = rStateGrad;
ptx_sync(3, valueSize);
- rStateGrad += shStateGrad[frameIdx + frameSize *2];
- rStateGrad += shStateGrad[frameIdx + frameSize *3];
+ rStateGrad += shStateGrad[frameIdx + frameSize * 2];
+ rStateGrad += shStateGrad[frameIdx + frameSize * 3];
rGateGrad = rStateGrad * shGateValue[frameIdx];
rGateGrad = activeGate(rGateGrad, rGateValue);
} else if (frameIdy == 2) {
@@ -373,7 +408,7 @@ backward_sequence(real rGateValue,
shStateGrad[index] = rStateGrad;
ptx_sync(3, valueSize);
rStateGrad += shStateGrad[frameIdx + frameSize];
- rStateGrad += shStateGrad[frameIdx + frameSize *3];
+ rStateGrad += shStateGrad[frameIdx + frameSize * 3];
rGateValuePrev = rGateValue;
rGateGrad = rStateGrad * shStateValue[frameIdx];
rGateGrad = activeGate(rGateGrad, rGateValue);
@@ -381,43 +416,43 @@ backward_sequence(real rGateValue,
shGateValue[frameIdx] = rGateValue;
ptx_sync(3, valueSize);
rStateGrad = shStateGrad[frameIdx + frameSize];
- rStateGrad += shStateGrad[frameIdx + frameSize *2];
- rStateGrad += shStateGrad[frameIdx + frameSize *3];
+ rStateGrad += shStateGrad[frameIdx + frameSize * 2];
+ rStateGrad += shStateGrad[frameIdx + frameSize * 3];
rGateGrad = rStateGrad * shGateValue[frameIdx + frameSize];
rGateGrad = activeNode(rGateGrad, rGateValue);
}
}
-template
+template
__device__ void load_weight(real rWeight[], real *weight, const int index) {
if (valueSize == 128) {
weight += index;
- #pragma unroll
+#pragma unroll
for (int n = 0; n < frameSize; n++) {
- rWeight[n] = weight[n*valueSize];
+ rWeight[n] = weight[n * valueSize];
}
transpose_32x32(rWeight, index % 32);
}
if (valueSize == 256) {
int id = (index / 32) % 2;
weight += index - id * 32 + id * 32 * valueSize;
- #pragma unroll
+#pragma unroll
for (int n = 0; n < 32; n++) {
- rWeight[n] = weight[n*valueSize];
- rWeight[n + 32] = weight[n*valueSize + 32];
+ rWeight[n] = weight[n * valueSize];
+ rWeight[n + 32] = weight[n * valueSize + 32];
}
transpose_32x32(rWeight, index % 32);
transpose_32x32(&rWeight[32], index % 32);
}
}
-template
+template
__global__ void KeLstmBackward(real *gateValue,
real *gateGrad,
real *stateValue,
- real *stateGrad, /* do not need save */
+ real *stateGrad, /* do not need save */
real *preOutputValue,
- real *preOutputGrad, /* do not need save */
+ real *preOutputGrad, /* do not need save */
real *checkIg,
real *checkIgGrad,
real *checkFg,
@@ -484,20 +519,27 @@ __global__ void KeLstmBackward(real *gateValue,
for (int i = 0; i < length; ++i) {
if (frameIdy == 3) {
- if (i != length -1) {
+ if (i != length - 1) {
frameStateValue.nextFrame();
shStateValue[frameIdx] = frameStateValue.getValue();
} else {
shStateValue[frameIdx] = 0.0;
}
}
- backward_sequence(
- rGateValue, rOutputGrad, rPreOutputValue, rGateGrad,
- rStateGrad, shStateGrad, shStateValue, shGateValue,
- rCheck, rGateValuePrev, index,
- hppl::gpu::backward[active_node],
- hppl::gpu::backward[active_gate],
- hppl::gpu::backward[active_state]);
+ backward_sequence(rGateValue,
+ rOutputGrad,
+ rPreOutputValue,
+ rGateGrad,
+ rStateGrad,
+ shStateGrad,
+ shStateValue,
+ shGateValue,
+ rCheck,
+ rGateValuePrev,
+ index,
+ hppl::gpu::backward[active_node],
+ hppl::gpu::backward[active_gate],
+ hppl::gpu::backward[active_state]);
if (frameIdy == 3) {
rCheckGrad += rGateGrad * rStateValue;
rStateValue = shStateValue[frameIdx];
@@ -523,9 +565,9 @@ __global__ void KeLstmBackward(real *gateValue,
shGateGrad[frameIdy][frameIdx] = rGateGrad;
if (valueSize == 128) {
real sum = 0.0f;
- #pragma unroll
+#pragma unroll
for (int n = 0; n < frameSize; n++) {
- sum += shGateGrad[frameIdy][n]*B_r[n];
+ sum += shGateGrad[frameIdy][n] * B_r[n];
}
if (frameIdy == 3) {
rOutputGrad += sum;
@@ -541,7 +583,7 @@ __global__ void KeLstmBackward(real *gateValue,
}
real sum = 0.0f;
for (int n = 0; n < frameSize; n++) {
- sum += A_r[n]*B_r[n];
+ sum += A_r[n] * B_r[n];
}
if (frameIdy == 3) {
rOutputGrad += sum;
@@ -552,8 +594,8 @@ __global__ void KeLstmBackward(real *gateValue,
if (frameIdy == 3) {
ptx_sync(6, valueSize);
- #pragma unroll
- for (int i = 0; i < 3; i ++) {
+#pragma unroll
+ for (int i = 0; i < 3; i++) {
rOutputGrad += shOutputGrad[i][frameIdx];
}
} else {
@@ -564,11 +606,14 @@ __global__ void KeLstmBackward(real *gateValue,
/* TODO: Temporary save & merger in another kernel */
if (frameIdy == 1) {
- if (checkIgGrad) paddle::paddleAtomicAdd(checkIgGrad+frameIdx, rCheckGrad);
+ if (checkIgGrad)
+ paddle::paddleAtomicAdd(checkIgGrad + frameIdx, rCheckGrad);
} else if (frameIdy == 2) {
- if (checkFgGrad) paddle::paddleAtomicAdd(checkFgGrad+frameIdx, rCheckGrad);
+ if (checkFgGrad)
+ paddle::paddleAtomicAdd(checkFgGrad + frameIdx, rCheckGrad);
} else if (frameIdy == 3) {
- if (checkOgGrad) paddle::paddleAtomicAdd(checkOgGrad+frameIdx, rCheckGrad);
+ if (checkOgGrad)
+ paddle::paddleAtomicAdd(checkOgGrad + frameIdx, rCheckGrad);
}
}
@@ -593,68 +638,183 @@ void hl_lstm_parallel_backward_data(real *gateValue,
hl_activation_mode_t active_node,
hl_activation_mode_t active_gate,
hl_activation_mode_t active_state) {
- CHECK(frameSize == 32 || frameSize == 64 ||
- frameSize == 128 || frameSize == 256);
+ CHECK(frameSize == 32 || frameSize == 64 || frameSize == 128 ||
+ frameSize == 256);
dim3 grid(numSequences, 1);
if (!reversed) {
if (frameSize == 32) {
- KeLstmBackward<128, 32, 0><<<grid, 128, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<128, 32, 0><<<grid, 128, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 64) {
- KeLstmBackward<256, 64, 0><<<grid, 256, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<256, 64, 0><<<grid, 256, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 128) {
- KeLstmBackward<512, 128, 0><<<grid, 512, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<512, 128, 0><<<grid, 512, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 256) {
- KeLstmBackward<1024, 256, 0><<<grid, 1024, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<1024, 256, 0><<<grid, 1024, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
}
} else {
if (frameSize == 32) {
- KeLstmBackward<128, 32, 1><<<grid, 128, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<128, 32, 1><<<grid, 128, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 64) {
- KeLstmBackward<256, 64, 1><<<grid, 256, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<256, 64, 1><<<grid, 256, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 128) {
- KeLstmBackward<512, 128, 1><<<grid, 512, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<512, 128, 1><<<grid, 512, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
} else if (frameSize == 256) {
- KeLstmBackward<1024, 256, 1><<<grid, 1024, 0, STREAM_DEFAULT>>>
- (gateValue, gateGrad, stateValue, stateGrad, preOutputValue,
- preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg,
- checkOgGrad, outputGrad, weight, sequence,
- active_node, active_gate, active_state);
+ KeLstmBackward<1024, 256, 1><<<grid, 1024, 0, STREAM_DEFAULT>>>(
+ gateValue,
+ gateGrad,
+ stateValue,
+ stateGrad,
+ preOutputValue,
+ preOutputGrad,
+ checkIg,
+ checkIgGrad,
+ checkFg,
+ checkFgGrad,
+ checkOg,
+ checkOgGrad,
+ outputGrad,
+ weight,
+ sequence,
+ active_node,
+ active_gate,
+ active_state);
}
}
CHECK_SYNC("hl_lstm_parallel_backward_data");
}
-template<int B_X, int B_Y>
+template <int B_X, int B_Y>
__global__ void KeSetGradZero(real *gateGrad,
- const int *starts, int valueSize, int numSequences, bool reversed) {
+ const int *starts,
+ int valueSize,
+ int numSequences,
+ bool reversed) {
// const int tid = threadIdx.x;
const int frameIdx = blockIdx.x * B_X + threadIdx.x;
@@ -682,19 +842,31 @@ void hl_lstm_parallel_backward_weight(real *weightGrad,
int valueSize = 4 * frameSize;
dim3 threads(32, 32);
dim3 grid((valueSize + 32 - 1) / 32, (numSequences + 32 - 1) / 32);
- KeSetGradZero<32, 32><<<grid, threads, 0, STREAM_DEFAULT>>>
- (gateGrad, sequence, valueSize, numSequences, reversed);
+ KeSetGradZero<32, 32><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ gateGrad, sequence, valueSize, numSequences, reversed);
if (!reversed) {
hl_matrix_mul(outputValue,
- HPPL_OP_T, gateGrad + valueSize, HPPL_OP_N, weightGrad,
- frameSize, valueSize, batchSize - 1,
- 1.0, 1.0);
+ HPPL_OP_T,
+ gateGrad + valueSize,
+ HPPL_OP_N,
+ weightGrad,
+ frameSize,
+ valueSize,
+ batchSize - 1,
+ 1.0,
+ 1.0);
} else {
hl_matrix_mul(outputValue + frameSize,
- HPPL_OP_T, gateGrad, HPPL_OP_N, weightGrad,
- frameSize, valueSize, batchSize - 1,
- 1.0, 1.0);
+ HPPL_OP_T,
+ gateGrad,
+ HPPL_OP_N,
+ weightGrad,
+ frameSize,
+ valueSize,
+ batchSize - 1,
+ 1.0,
+ 1.0);
}
CHECK_SYNC("hl_lstm_parallel_backward_weight");
}
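The KeLstmBackward launches above all share one configuration: a grid of numSequences blocks (one block per sequence), by all appearances one thread per gate value (valueSize = 4 * frameSize), no dynamic shared memory, and the library's default stream. A minimal, self-contained sketch of that <<<grid, block, sharedMem, stream>>> convention follows; the kernel and its arguments are hypothetical and are not part of this patch.

#include <cuda_runtime.h>

// Hypothetical kernel: one block per sequence, one thread per gate value.
__global__ void ToyGateKernel(const float* gate, float* out, int frameSize) {
  const int frameIdx = threadIdx.x % frameSize;  // element within the frame
  const int frameIdy = threadIdx.x / frameSize;  // which of the four gates
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = gate[i] * (frameIdy + 1) + frameIdx;
}

void launchToyGateKernel(const float* gate, float* out,
                         int numSequences, int frameSize) {
  const int valueSize = 4 * frameSize;  // 128 threads when frameSize == 32
  dim3 grid(numSequences, 1);           // one block per sequence
  // <<<grid, blockDim, dynamicSharedMemBytes, stream>>>
  ToyGateKernel<<<grid, valueSize, 0, /*stream=*/0>>>(gate, out, frameSize);
}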
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 9bcc7fb7de44b2211db450fb164655f7947dcad9..39272456c394adc0509e60cf5972df832f7b3424 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -12,22 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
#include "hl_base.h"
+#include "hl_device_functions.cuh"
+#include "hl_gpu_matrix_kernel.cuh"
#include "hl_matrix.h"
-#include "hl_matrix_ops.cuh"
#include "hl_matrix_apply.cuh"
+#include "hl_matrix_ops.cuh"
#include "hl_sequence.h"
#include "hl_sparse.ph"
#include "paddle/utils/Logging.h"
-#include "hl_device_functions.cuh"
-#include "hl_gpu_matrix_kernel.cuh"
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);
-DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1*a + p2*b);
-void hl_matrix_add(real *A_d,
- real *B_d,
- real *C_d,
+DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1 * a + p2 * b);
+void hl_matrix_add(real* A_d,
+ real* B_d,
+ real* C_d,
int dimM,
int dimN,
real alpha,
@@ -36,33 +35,32 @@ void hl_matrix_add(real *A_d,
CHECK_NOTNULL(B_d);
CHECK_NOTNULL(C_d);
- hl_gpu_apply_ternary_op
- <real, ternary::_add<real>, 0, 0>(ternary::_add<real>(alpha, beta),
- A_d,
- B_d,
- C_d,
- dimM,
- dimN,
- dimN,
- dimN,
- dimN);
+ hl_gpu_apply_ternary_op<real, ternary::_add<real>, 0, 0>(
+ ternary::_add<real>(alpha, beta),
+ A_d,
+ B_d,
+ C_d,
+ dimM,
+ dimN,
+ dimN,
+ dimN,
+ dimN);
CHECK_SYNC("hl_matrix_add failed");
}
#ifdef PADDLE_TYPE_DOUBLE
- #define THRESHOLD 128
+#define THRESHOLD 128
#else
- #define THRESHOLD 64
+#define THRESHOLD 64
#endif
-__device__ __forceinline__
-void findMax(real* I,
- real* dfMax_s,
- int blockSize,
- int base,
- int curIdx,
- int nextIdx,
- int dimN,
- real* max) {
+__device__ __forceinline__ void findMax(real* I,
+ real* dfMax_s,
+ int blockSize,
+ int base,
+ int curIdx,
+ int nextIdx,
+ int dimN,
+ real* max) {
dfMax_s[base] = -1.0e20;
while (curIdx < dimN) {
if (dfMax_s[base] < I[nextIdx]) {
@@ -78,25 +76,24 @@ void findMax(real* I,
if (base < stride) {
nextIdx = base + stride;
if (dfMax_s[base] < dfMax_s[nextIdx]) {
- dfMax_s[base] = dfMax_s[nextIdx];
+ dfMax_s[base] = dfMax_s[nextIdx];
}
}
}
- if (0 == base) {
+ if (0 == base) {
max[0] = dfMax_s[0];
}
__syncthreads();
}
-__device__ __forceinline__
-void subMaxAndExp(real* I,
- real* O,
- int curIdx,
- int nextIdx,
- int blockSize,
- int dimN,
- real max) {
+__device__ __forceinline__ void subMaxAndExp(real* I,
+ real* O,
+ int curIdx,
+ int nextIdx,
+ int blockSize,
+ int dimN,
+ real max) {
real val;
while (curIdx < dimN) {
val = I[nextIdx] - max;
@@ -115,14 +112,13 @@ void subMaxAndExp(real* I,
__syncthreads();
}
-__device__ __forceinline__
-void valueSum(real* O,
- real* dfMax_s,
- int blockSize,
- int base,
- int curIdx,
- int nextIdx,
- int dimN) {
+__device__ __forceinline__ void valueSum(real* O,
+ real* dfMax_s,
+ int blockSize,
+ int base,
+ int curIdx,
+ int nextIdx,
+ int dimN) {
dfMax_s[base] = 0;
while (curIdx < dimN) {
dfMax_s[base] += O[nextIdx];
@@ -141,13 +137,8 @@ void valueSum(real* O,
__syncthreads();
}
-__device__ __forceinline__
-void divSum(real* O,
- real sum,
- int curIdx,
- int nextIdx,
- int blockSize,
- int dimN) {
+__device__ __forceinline__ void divSum(
+ real* O, real sum, int curIdx, int nextIdx, int blockSize, int dimN) {
while (curIdx < dimN) {
O[nextIdx] /= sum;
nextIdx += blockSize;
@@ -155,20 +146,18 @@ void divSum(real* O,
}
}
-__device__ __forceinline__
-void softmax(real* I,
- real* O,
- real* dfMax_s,
- int blockSize,
- int base,
- int curIdx,
- int nextIdx,
- int dimN) {
+__device__ __forceinline__ void softmax(real* I,
+ real* O,
+ real* dfMax_s,
+ int blockSize,
+ int base,
+ int curIdx,
+ int nextIdx,
+ int dimN) {
__shared__ real max;
// find the max number
- findMax(I, dfMax_s, blockSize, base, curIdx,
- nextIdx, dimN, &max);
+ findMax(I, dfMax_s, blockSize, base, curIdx, nextIdx, dimN, &max);
// sub max Value and do Exp operation
subMaxAndExp(I, O, base, nextIdx, blockSize, dimN, max);
@@ -181,8 +170,8 @@ void softmax(real* I,
divSum(O, dfMax_s[0], curIdx, nextIdx, blockSize, dimN);
}
-template<int blockSize>
-__global__ void KeMatrixSoftMax(real *O, real *I, int dimN) {
+template <int blockSize>
+__global__ void KeMatrixSoftMax(real* O, real* I, int dimN) {
int base = threadIdx.x;
__shared__ real dfMax_s[blockSize];
int nextIdx = blockIdx.x * dimN + base;
@@ -191,19 +180,18 @@ __global__ void KeMatrixSoftMax(real *O, real *I, int dimN) {
softmax(I, O, dfMax_s, blockSize, base, curIdx, nextIdx, dimN);
}
-void hl_matrix_softmax(real *A_d, real *C_d, int dimM, int dimN) {
+void hl_matrix_softmax(real* A_d, real* C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
dim3 block(512, 1);
dim3 grid(dimM, 1);
- KeMatrixSoftMax<512>
- <<<grid, block, 0, STREAM_DEFAULT>>>(C_d, A_d, dimN);
+ KeMatrixSoftMax<512><<<grid, block, 0, STREAM_DEFAULT>>>(C_d, A_d, dimN);
CHECK_SYNC("hl_matrix_softmax failed");
}
-template<int blockSize>
-__global__ void KeSequenceSoftMax(real *O, real *I, const int* index) {
+template <int blockSize>
+__global__ void KeSequenceSoftMax(real* O, real* I, const int* index) {
int base = threadIdx.x;
int bid = blockIdx.x;
__shared__ real dfMax_s[blockSize];
@@ -217,8 +205,8 @@ __global__ void KeSequenceSoftMax(real *O, real *I, const int* index) {
softmax(I, O, dfMax_s, blockSize, base, curIdx, nextIdx, dimN);
}
-void hl_sequence_softmax_forward(real *A_d,
- real *C_d,
+void hl_sequence_softmax_forward(real* A_d,
+ real* C_d,
const int* index,
int numSequence) {
CHECK_NOTNULL(A_d);
@@ -226,59 +214,48 @@ void hl_sequence_softmax_forward(real *A_d,
dim3 block(512, 1);
dim3 grid(numSequence, 1);
- KeSequenceSoftMax<512>
- <<<grid, block, 0, STREAM_DEFAULT>>>(C_d, A_d, index);
+ KeSequenceSoftMax<512><<<grid, block, 0, STREAM_DEFAULT>>>(C_d, A_d, index);
CHECK_SYNC("hl_sequence_softmax_forward failed");
}
-__global__ void KeMatrixDerivative(real *grad_d,
- real *output_d,
- real *sftmaxSum_d,
- int dimM,
- int dimN) {
- int rowIdx = blockIdx.x*blockDim.x + threadIdx.x;
- int colIdx = blockIdx.y*blockDim.y + threadIdx.y;
+__global__ void KeMatrixDerivative(
+ real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {
+ int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
+ int colIdx = blockIdx.y * blockDim.y + threadIdx.y;
int index;
if (rowIdx < dimM && colIdx < dimN) {
- index = rowIdx*dimN + colIdx;
+ index = rowIdx * dimN + colIdx;
grad_d[index] = output_d[index] * (grad_d[index] - sftmaxSum_d[rowIdx]);
}
}
-void hl_matrix_softmax_derivative(real *grad_d,
- real *output_d,
- real *sftmaxSum_d,
- int dimM,
- int dimN) {
+void hl_matrix_softmax_derivative(
+ real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {
CHECK_NOTNULL(grad_d);
CHECK_NOTNULL(output_d);
CHECK_NOTNULL(sftmaxSum_d);
int blocksX = (dimM + 0) / 1;
- int blocksY = (dimN + 1024 -1) / 1024;
+ int blocksY = (dimN + 1024 - 1) / 1024;
dim3 threads(1, 1024);
dim3 grid(blocksX, blocksY);
- KeMatrixDerivative<<< grid, threads, 0, STREAM_DEFAULT >>>
- (grad_d, output_d, sftmaxSum_d, dimM, dimN);
+ KeMatrixDerivative<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ grad_d, output_d, sftmaxSum_d, dimM, dimN);
CHECK_SYNC("hl_matrix_softmax_derivative failed");
}
-__global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
- real* entropy,
- int* row,
- int* col,
- int dimM,
- int dimN) {
+__global__ void KeMatrixMultiBinaryCrossEntropy(
+ real* output, real* entropy, int* row, int* col, int dimM, int dimN) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < dimM) {
- for (int i = 0; i < dimN; i ++) {
+ for (int i = 0; i < dimN; i++) {
entropy[index] -= log(1 - output[index * dimN + i]);
}
- int *row_col = col + row[index];
+ int* row_col = col + row[index];
int col_num = row[index + 1] - row[index];
- for (int i = 0; i < col_num; i ++) {
+ for (int i = 0; i < col_num; i++) {
real o = output[index * dimN + row_col[i]];
entropy[index] -= log(o / (1 - o));
}
@@ -299,37 +276,30 @@ void hl_matrix_multi_binary_cross_entropy(real* output,
dim3 threads(n_threads);
dim3 grid(blocks);
hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
- KeMatrixMultiBinaryCrossEntropy<<< grid, threads, 0, STREAM_DEFAULT >>>
- (output, entropy, mat->csr_row, mat->csr_col, dimM, dimN);
+ KeMatrixMultiBinaryCrossEntropy<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ output, entropy, mat->csr_row, mat->csr_col, dimM, dimN);
CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed");
}
-__global__ void KeMatrixMultiBinaryCrossEntropyBp(real* output,
- real* grad,
- int* row,
- int* col,
- int dimM,
- int dimN) {
+__global__ void KeMatrixMultiBinaryCrossEntropyBp(
+ real* output, real* grad, int* row, int* col, int dimM, int dimN) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < dimM) {
- for (int i = 0; i < dimN; i ++) {
+ for (int i = 0; i < dimN; i++) {
int index = row_idx * dimN + i;
grad[index] += 1.0 / (1 - output[index]);
}
int col_num = row[row_idx + 1] - row[row_idx];
- int *row_col = col + row[row_idx];
- for (int i = 0; i < col_num; i ++) {
+ int* row_col = col + row[row_idx];
+ for (int i = 0; i < col_num; i++) {
int index = row_idx * dimN + row_col[i];
grad[index] -= 1.0 / (output[index] * (1 - output[index]));
}
}
}
-void hl_matrix_multi_binary_cross_entropy_bp(real* output,
- real* grad,
- hl_sparse_matrix_s csr_mat,
- int dimM,
- int dimN) {
+void hl_matrix_multi_binary_cross_entropy_bp(
+ real* output, real* grad, hl_sparse_matrix_s csr_mat, int dimM, int dimN) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(grad);
CHECK_NOTNULL(csr_mat);
@@ -339,16 +309,13 @@ void hl_matrix_multi_binary_cross_entropy_bp(real* output,
dim3 threads(n_threads);
dim3 grid(blocks);
hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
- KeMatrixMultiBinaryCrossEntropyBp<<< grid, threads, 0, STREAM_DEFAULT >>>
- (output, grad, mat->csr_row, mat->csr_col, dimM, dimN);
+ KeMatrixMultiBinaryCrossEntropyBp<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ output, grad, mat->csr_row, mat->csr_col, dimM, dimN);
CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed");
}
-__global__ void KeMatrixCrossEntropy(real* O,
- real* E,
- int* label,
- int dimM,
- int dimN) {
+__global__ void KeMatrixCrossEntropy(
+ real* O, real* E, int* label, int dimM, int dimN) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int newBase;
if (index < dimM) {
@@ -358,59 +325,49 @@ __global__ void KeMatrixCrossEntropy(real* O,
}
}
-void hl_matrix_cross_entropy(real* A_d,
- real* C_d,
- int* label_d,
- int dimM,
- int dimN) {
+void hl_matrix_cross_entropy(
+ real* A_d, real* C_d, int* label_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
int blocks = (dimM + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
- KeMatrixCrossEntropy<<< grid, threads, 0, STREAM_DEFAULT >>>
- (A_d, C_d, label_d, dimM, dimN);
+ KeMatrixCrossEntropy<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d, C_d, label_d, dimM, dimN);
CHECK_SYNC("hl_matrix_cross_entropy failed");
}
-__global__ void KeMatrixCrossEntropyBp(real* grad_d,
- real* output_d,
- int* label_d,
- int dimM,
- int dimN) {
- int rowIdx = blockIdx.x*blockDim.x + threadIdx.x;
- int colIdx = blockIdx.y*blockDim.y + threadIdx.y;
+__global__ void KeMatrixCrossEntropyBp(
+ real* grad_d, real* output_d, int* label_d, int dimM, int dimN) {
+ int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
+ int colIdx = blockIdx.y * blockDim.y + threadIdx.y;
int index;
if (rowIdx < dimM && colIdx < dimN) {
- index = rowIdx*dimN + colIdx;
+ index = rowIdx * dimN + colIdx;
if (label_d[rowIdx] == colIdx) {
grad_d[index] -= 1.0f / output_d[index];
}
}
}
-void hl_matrix_cross_entropy_bp(real* grad_d,
- real* output_d,
- int* label_d,
- int dimM,
- int dimN) {
+void hl_matrix_cross_entropy_bp(
+ real* grad_d, real* output_d, int* label_d, int dimM, int dimN) {
CHECK_NOTNULL(grad_d);
CHECK_NOTNULL(output_d);
CHECK_NOTNULL(label_d);
- int blocksX = (dimM + 0)/1;
- int blocksY = (dimN + 1024 -1) / 1024;
+ int blocksX = (dimM + 0) / 1;
+ int blocksY = (dimN + 1024 - 1) / 1024;
dim3 threads(1, 1024);
dim3 grid(blocksX, blocksY);
- KeMatrixCrossEntropyBp<<< grid, threads, 0, STREAM_DEFAULT >>>
- (grad_d, output_d, label_d, dimM, dimN);
+ KeMatrixCrossEntropyBp<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ grad_d, output_d, label_d, dimM, dimN);
CHECK_SYNC("hl_matrix_cross_entropy_bp failed");
}
void hl_matrix_zero_mem(real* data, int num) {
- hl_gpu_apply_unary_op(
- unary::Zero<real>(), data, 1, num, num);
+ hl_gpu_apply_unary_op(unary::Zero<real>(), data, 1, num, num);
}
__global__ void KeParamReluForward(real* output,
@@ -423,8 +380,8 @@ __global__ void KeParamReluForward(real* output,
int ty = blockIdx.y * blockDim.y + threadIdx.y;
if (tx < width && ty < height) {
int index = ty * width + tx;
- output[index] = input[index] > 0 ? input[index] :
- input[index] * w[tx / partial_sum];
+ output[index] =
+ input[index] > 0 ? input[index] : input[index] * w[tx / partial_sum];
}
}
@@ -439,14 +396,14 @@ void hl_param_relu_forward(real* output,
CHECK_NOTNULL(w);
dim3 threads(16, 16);
int blockX = (width + 16 - 1) / 16;
- int blockY = (height + 16 -1) / 16;
+ int blockY = (height + 16 - 1) / 16;
dim3 grid(blockX, blockY);
- KeParamReluForward<<<grid, threads, 0, STREAM_DEFAULT>>>
- (output, input, w, width, height, partial_sum);
+ KeParamReluForward<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ output, input, w, width, height, partial_sum);
CHECK_SYNC("hl_param_relu_forward failed");
}
-template<int blockSize>
+template <int blockSize>
__global__ void KeParamReluBackWardW(real* grad_w,
real* grad_o,
real* input,
@@ -491,8 +448,8 @@ void hl_param_relu_backward_w(real* grad_w,
int grid_num = width / partial_sum;
dim3 threads(blockSize, 1);
dim3 grid(grid_num, 1);
- KeParamReluBackWardW<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
- (grad_w, grad_o, input, width, height, partial_sum);
+ KeParamReluBackWardW<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ grad_w, grad_o, input, width, height, partial_sum);
CHECK_SYNC("hl_param_relu_backward_w failed");
}
@@ -524,19 +481,15 @@ void hl_param_relu_backward_diff(real* grad_o,
CHECK_NOTNULL(diff);
dim3 threads(16, 16);
int blockX = (width + 16 - 1) / 16;
- int blockY = (height + 16 -1) / 16;
+ int blockY = (height + 16 - 1) / 16;
dim3 grid(blockX, blockY);
- KeParamReluBackwardDiff<<<grid, threads, 0, STREAM_DEFAULT>>>
- (grad_o, data, w, diff, width, height, partial_sum);
+ KeParamReluBackwardDiff<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ grad_o, data, w, diff, width, height, partial_sum);
CHECK_SYNC("hl_param_relu_backward_diff failed");
}
-__global__ void KeMatrixAddSharedBias(real* A,
- real* B,
- const int channel,
- const int M,
- const int N,
- real scale) {
+__global__ void KeMatrixAddSharedBias(
+ real* A, real* B, const int channel, const int M, const int N, real scale) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int dim = N / channel;
if (index < M * N) {
@@ -554,15 +507,14 @@ void hl_matrix_add_shared_bias(real* A_d,
real scale) {
const int blocks = 512;
const int grids = DIVUP(dimM * dimN, blocks);
- KeMatrixAddSharedBias<<<grids, blocks, 0, STREAM_DEFAULT>>>
- (A_d, B_d, channel, dimM, dimN, scale);
+ KeMatrixAddSharedBias<<<grids, blocks, 0, STREAM_DEFAULT>>>(
+ A_d, B_d, channel, dimM, dimN, scale);
CHECK_SYNC("hl_matrix_add_shared_bias failed");
}
-
template <int blockSize>
-__global__ void KeMatrixCollectSharedBias(real *B,
- real *A,
+__global__ void KeMatrixCollectSharedBias(real* B,
+ real* A,
const int channel,
const int M,
const int N,
@@ -589,7 +541,7 @@ __global__ void KeMatrixCollectSharedBias(real *B,
int n = j * blockSize + tid;
int m = n / dim;
int w = n % dim;
- smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0;
+ smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0;
__syncthreads();
simpleReduce(smem, tid, blockSize);
sum += smem[0];
@@ -611,33 +563,32 @@ void hl_matrix_collect_shared_bias(real* B_d,
const int limit = 64;
int grids = (dimM * dim) < limit ? DIVUP(channel, blocks) : channel;
- KeMatrixCollectSharedBias<blocks>
- <<< grids, blocks, 0, STREAM_DEFAULT>>>
- (B_d, A_d, channel, dimM, dimN, dim, limit, scale);
+ KeMatrixCollectSharedBias<blocks><<<grids, blocks, 0, STREAM_DEFAULT>>>(
+ B_d, A_d, channel, dimM, dimN, dim, limit, scale);
CHECK_SYNC("hl_matrix_collect_shared_bias failed");
}
-__global__ void keMatrixRotate(real* mat, real* matRot,
- int dimM, int dimN, bool clockWise) {
- int idx = blockIdx.x * blockDim.x + threadIdx.x;
- if (idx < dimM * dimN) {
- int i = idx / dimN;
- int j = idx % dimN;
- if (clockWise) {
- matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
- } else {
- matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
- }
+__global__ void keMatrixRotate(
+ real* mat, real* matRot, int dimM, int dimN, bool clockWise) {
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < dimM * dimN) {
+ int i = idx / dimN;
+ int j = idx % dimN;
+ if (clockWise) {
+ matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
+ } else {
+ matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
}
+ }
}
-void hl_matrix_rotate(real *mat, real* matRot,
- int dimM, int dimN, bool clockWise) {
- CHECK_NOTNULL(mat);
- CHECK_NOTNULL(matRot);
- const int threads = 512;
- const int blocks = DIVUP(dimM * dimN, threads);
- keMatrixRotate<<< blocks, threads, 0, STREAM_DEFAULT >>>
- (mat, matRot, dimM, dimN, clockWise);
- CHECK_SYNC("hl_matrix_rotate failed");
+void hl_matrix_rotate(
+ real* mat, real* matRot, int dimM, int dimN, bool clockWise) {
+ CHECK_NOTNULL(mat);
+ CHECK_NOTNULL(matRot);
+ const int threads = 512;
+ const int blocks = DIVUP(dimM * dimN, threads);
+ keMatrixRotate<<<blocks, threads, 0, STREAM_DEFAULT>>>(
+ mat, matRot, dimM, dimN, clockWise);
+ CHECK_SYNC("hl_matrix_rotate failed");
}
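KeMatrixSoftMax above computes a row softmax with one thread block per row in four phases: a shared-memory max reduction (findMax), exp(x - max) (subMaxAndExp), a shared-memory sum reduction (valueSum), and a final division (divSum). A compact sketch of the same pattern follows; the kernel name and the fixed 512-thread block are assumptions for illustration, not the patched code. It would be launched as RowSoftmaxSketch<<<numRows, kBlock, 0, stream>>>(in, out, dimN).

#include <cuda_runtime.h>
#include <math.h>

constexpr int kBlock = 512;  // must equal blockDim.x at launch

// One block normalizes one row of length dimN.
__global__ void RowSoftmaxSketch(const float* in, float* out, int dimN) {
  __shared__ float red[kBlock];
  const float* row = in + blockIdx.x * dimN;
  float* orow = out + blockIdx.x * dimN;

  // Phase 1: block-wide maximum (cf. findMax).
  float m = -1.0e20f;
  for (int i = threadIdx.x; i < dimN; i += kBlock) m = fmaxf(m, row[i]);
  red[threadIdx.x] = m;
  __syncthreads();
  for (int s = kBlock / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) red[threadIdx.x] = fmaxf(red[threadIdx.x], red[threadIdx.x + s]);
    __syncthreads();
  }
  m = red[0];
  __syncthreads();

  // Phase 2: exp(x - max), then block-wide sum (cf. subMaxAndExp / valueSum).
  float sum = 0.0f;
  for (int i = threadIdx.x; i < dimN; i += kBlock) {
    const float e = expf(row[i] - m);
    orow[i] = e;
    sum += e;
  }
  red[threadIdx.x] = sum;
  __syncthreads();
  for (int s = kBlock / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) red[threadIdx.x] += red[threadIdx.x + s];
    __syncthreads();
  }

  // Phase 3: divide every element by the row sum (cf. divSum).
  for (int i = threadIdx.x; i < dimN; i += kBlock) orow[i] /= red[0];
}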
diff --git a/paddle/cuda/src/hl_cuda_sequence.cu b/paddle/cuda/src/hl_cuda_sequence.cu
index eeee921db54e20ea6a017d2b83f2d7ca9e5e037e..c52780dfcaff6e5b94d3568fac4ca011b76a1442 100644
--- a/paddle/cuda/src/hl_cuda_sequence.cu
+++ b/paddle/cuda/src/hl_cuda_sequence.cu
@@ -16,36 +16,36 @@ limitations under the License. */
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"
-__global__ void KeMaxSequenceForward(real *input,
- const int *sequence,
+__global__ void KeMaxSequenceForward(real* input,
+ const int* sequence,
real* output,
- int *index,
+ int* index,
int numSequences,
int dim) {
int dimIdx = threadIdx.x;
int sequenceId = blockIdx.x;
if (sequenceId >= numSequences) return;
int start = sequence[sequenceId];
- int end = sequence[sequenceId+1];
+ int end = sequence[sequenceId + 1];
for (int i = dimIdx; i < dim; i += blockDim.x) {
real tmp = -HL_FLOAT_MAX;
int tmpId = -1;
for (int insId = start; insId < end; insId++) {
- if (tmp < input[insId*dim + i]) {
- tmp = input[insId*dim + i];
+ if (tmp < input[insId * dim + i]) {
+ tmp = input[insId * dim + i];
tmpId = insId;
}
}
- output[sequenceId*dim + i] = tmp;
- index[sequenceId*dim + i] = tmpId;
+ output[sequenceId * dim + i] = tmp;
+ index[sequenceId * dim + i] = tmpId;
}
}
void hl_max_sequence_forward(real* input,
const int* sequence,
real* output,
- int *index,
+ int* index,
int numSequences,
int dim) {
CHECK_NOTNULL(input);
@@ -55,29 +55,23 @@ void hl_max_sequence_forward(real* input,
dim3 threads(256, 1);
dim3 grid(numSequences, 1);
- KeMaxSequenceForward<<< grid, threads, 0, STREAM_DEFAULT >>>
- (input, sequence, output, index, numSequences, dim);
+ KeMaxSequenceForward<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ input, sequence, output, index, numSequences, dim);
CHECK_SYNC("hl_max_sequence_forward failed");
}
-__global__ void KeMaxSequenceBackward(real *outputGrad,
- int *index,
- real* inputGrad,
- int numSequences,
- int dim) {
+__global__ void KeMaxSequenceBackward(
+ real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int colIdx = idx % dim;
- if (idx < numSequences*dim) {
+ if (idx < numSequences * dim) {
int insId = index[idx];
inputGrad[insId * dim + colIdx] += outputGrad[idx];
}
}
-void hl_max_sequence_backward(real* outputGrad,
- int *index,
- real* inputGrad,
- int numSequences,
- int dim) {
+void hl_max_sequence_backward(
+ real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {
CHECK_NOTNULL(outputGrad);
CHECK_NOTNULL(index);
CHECK_NOTNULL(inputGrad);
@@ -85,12 +79,12 @@ void hl_max_sequence_backward(real* outputGrad,
unsigned int blocks = (numSequences * dim + 128 - 1) / 128;
dim3 threads(128, 1);
dim3 grid(blocks, 1);
- KeMaxSequenceBackward<<< grid, threads, 0, STREAM_DEFAULT >>>
- (outputGrad, index, inputGrad, numSequences, dim);
+ KeMaxSequenceBackward<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ outputGrad, index, inputGrad, numSequences, dim);
CHECK_SYNC("hl_max_sequence_backward failed");
}
-template<int blockDimX, int blockDimY, int gridDimX, int AddRow>
+template <int blockDimX, int blockDimY, int gridDimX, int AddRow>
__global__ void KeMatrixAddRows(real* output,
real* table,
int* ids,
@@ -104,8 +98,8 @@ __global__ void KeMatrixAddRows(real* output,
while (sampleId < numSamples) {
int tableId = ids[sampleId];
if ((0 <= tableId) && (tableId < tableSize)) {
- real *outputData = output + sampleId * dim;
- real *tableData = table + tableId * dim;
+ real* outputData = output + sampleId * dim;
+ real* tableData = table + tableId * dim;
for (int i = idx; i < dim; i += blockDimX) {
if (AddRow == 0) {
outputData[i] += tableData[i];
@@ -114,24 +108,27 @@ __global__ void KeMatrixAddRows(real* output,
}
}
}
- sampleId += blockDimY*gridDimX;
+ sampleId += blockDimY * gridDimX;
}
}
-template<int blockDimX, int blockDimY, int gridDimX, int seq2batch, int isAdd>
-__global__
-void KeSequence2Batch(real *batch,
- real *sequence,
- const int *batchIndex,
- int seqWidth,
- int batchCount) {
+template <int blockDimX, int blockDimY, int gridDimX, int seq2batch, int isAdd>
+__global__ void KeSequence2Batch(real* batch,
+ real* sequence,
+ const int* batchIndex,
+ int seqWidth,
+ int batchCount) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int id = blockIdx.x + idy * gridDimX;
while (id < batchCount) {
int seqId = batchIndex[id];
- real* batchData = batch + id*seqWidth;
- real* seqData = sequence + seqId*seqWidth;
+ real* batchData = batch + id * seqWidth;
+ real* seqData = sequence + seqId * seqWidth;
for (int i = idx; i < seqWidth; i += blockDimX) {
if (seq2batch) {
if (isAdd) {
@@ -147,13 +144,13 @@ void KeSequence2Batch(real *batch,
}
}
}
- id += blockDimY*gridDimX;
+ id += blockDimY * gridDimX;
}
}
-void hl_sequence2batch_copy(real *batch,
- real *sequence,
- const int *batchIndex,
+void hl_sequence2batch_copy(real* batch,
+ real* sequence,
+ const int* batchIndex,
int seqWidth,
int batchCount,
bool seq2batch) {
@@ -164,18 +161,18 @@ void hl_sequence2batch_copy(real *batch,
dim3 threads(128, 8);
dim3 grid(8, 1);
if (seq2batch) {
- KeSequence2Batch<128, 8, 8, 1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
- (batch, sequence, batchIndex, seqWidth, batchCount);
+ KeSequence2Batch<128, 8, 8, 1, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch, sequence, batchIndex, seqWidth, batchCount);
} else {
- KeSequence2Batch<128, 8, 8, 0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
- (batch, sequence, batchIndex, seqWidth, batchCount);
+ KeSequence2Batch<128, 8, 8, 0, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch, sequence, batchIndex, seqWidth, batchCount);
}
CHECK_SYNC("hl_sequence2batch_copy failed");
}
-void hl_sequence2batch_add(real *batch,
- real *sequence,
- int *batchIndex,
+void hl_sequence2batch_add(real* batch,
+ real* sequence,
+ int* batchIndex,
int seqWidth,
int batchCount,
bool seq2batch) {
@@ -186,23 +183,22 @@ void hl_sequence2batch_add(real *batch,
dim3 threads(128, 8);
dim3 grid(8, 1);
if (seq2batch) {
- KeSequence2Batch<128, 8, 8, 1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
- (batch, sequence, batchIndex, seqWidth, batchCount);
+ KeSequence2Batch<128, 8, 8, 1, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch, sequence, batchIndex, seqWidth, batchCount);
} else {
- KeSequence2Batch<128, 8, 8, 0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
- (batch, sequence, batchIndex, seqWidth, batchCount);
+ KeSequence2Batch<128, 8, 8, 0, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch, sequence, batchIndex, seqWidth, batchCount);
}
CHECK_SYNC("hl_sequence2batch_add failed");
}
-template<int normByTimes, int seq2batch>
-__global__
-void KeSequence2BatchPadding(real* batch,
- real* sequence,
- const int* sequenceStartPositions,
- const size_t sequenceWidth,
- const size_t maxSequenceLength,
- const size_t numSequences) {
+template <int normByTimes, int seq2batch>
+__global__ void KeSequence2BatchPadding(real* batch,
+ real* sequence,
+ const int* sequenceStartPositions,
+ const size_t sequenceWidth,
+ const size_t maxSequenceLength,
+ const size_t numSequences) {
int batchIdx = blockIdx.y;
int sequenceStart = sequenceStartPositions[batchIdx];
int sequenceLength = sequenceStartPositions[batchIdx + 1] - sequenceStart;
@@ -276,37 +272,49 @@ void hl_sequence2batch_copy_padding(real* batch,
if (seq2batch) {
/* sequence -> batch */
if (normByTimes) {
- KeSequence2BatchPadding<1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
- batch, sequence, sequenceStartPositions,
- sequenceWidth, maxSequenceLength, numSequences);
+ KeSequence2BatchPadding<1, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch,
+ sequence,
+ sequenceStartPositions,
+ sequenceWidth,
+ maxSequenceLength,
+ numSequences);
} else {
- KeSequence2BatchPadding<0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
- batch, sequence, sequenceStartPositions,
- sequenceWidth, maxSequenceLength, numSequences);
+ KeSequence2BatchPadding<0, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch,
+ sequence,
+ sequenceStartPositions,
+ sequenceWidth,
+ maxSequenceLength,
+ numSequences);
}
} else {
/* batch -> sequence */
if (normByTimes) {
- KeSequence2BatchPadding<1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
- batch, sequence, sequenceStartPositions,
- sequenceWidth, maxSequenceLength, numSequences);
+ KeSequence2BatchPadding<1, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch,
+ sequence,
+ sequenceStartPositions,
+ sequenceWidth,
+ maxSequenceLength,
+ numSequences);
} else {
- KeSequence2BatchPadding<0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
- batch, sequence, sequenceStartPositions,
- sequenceWidth, maxSequenceLength, numSequences);
+ KeSequence2BatchPadding<0, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ batch,
+ sequence,
+ sequenceStartPositions,
+ sequenceWidth,
+ maxSequenceLength,
+ numSequences);
}
}
CHECK_SYNC("hl_sequence2batch_copy_padding failed");
}
-__device__ inline float my_rsqrt(float x) {
- return rsqrtf(x);
-}
+__device__ inline float my_rsqrt(float x) { return rsqrtf(x); }
-__device__ inline double my_rsqrt(double x) {
- return rsqrt(x);
-}
+__device__ inline double my_rsqrt(double x) { return rsqrt(x); }
__global__ void KeSequenceAvgForward(real* dst,
real* src,
@@ -327,8 +335,8 @@ __global__ void KeSequenceAvgForward(real* dst,
for (int i = start; i < end; i++) {
sum += src[i * width + col];
}
- sum = mode == 1 ? sum :
- (mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength));
+ sum = mode == 1 ? sum : (mode == 0 ? sum / seqLength
+ : sum * my_rsqrt((real)seqLength));
dst[gid] += sum;
}
}
@@ -347,10 +355,10 @@ void hl_sequence_avg_forward(real* dst,
int grid = DIVUP(width * height, 512);
CHECK(mode == 0 || mode == 1 || mode == 2)
- << "mode error in hl_sequence_avg_forward!";
+ << "mode error in hl_sequence_avg_forward!";
- KeSequenceAvgForward<<< grid, block, 0, STREAM_DEFAULT >>>
- (dst, src, starts, height, width, mode);
+ KeSequenceAvgForward<<<grid, block, 0, STREAM_DEFAULT>>>(
+ dst, src, starts, height, width, mode);
CHECK_SYNC("hl_sequence_avg_forward failed");
}
@@ -370,8 +378,8 @@ __global__ void KeSequenceAvgBackward(real* dst,
int seqLength = end - start;
if (seqLength == 0) return;
real grad = src[gid];
- grad = mode == 1 ? grad :
- (mode == 0 ? grad / seqLength : grad * my_rsqrt((real)seqLength));
+ grad = mode == 1 ? grad : (mode == 0 ? grad / seqLength
+ : grad * my_rsqrt((real)seqLength));
for (int i = start; i < end; i++) {
dst[i * width + col] += grad;
}
@@ -392,9 +400,9 @@ void hl_sequence_avg_backward(real* dst,
int grid = DIVUP(width * height, 512);
CHECK(mode == 0 || mode == 1 || mode == 2)
- << "mode error in hl_sequence_avg_backward!";
+ << "mode error in hl_sequence_avg_backward!";
- KeSequenceAvgBackward<<< grid, block, 0, STREAM_DEFAULT >>>
- (dst, src, starts, height, width, mode);
+ KeSequenceAvgBackward<<<grid, block, 0, STREAM_DEFAULT>>>(
+ dst, src, starts, height, width, mode);
CHECK_SYNC("hl_sequence_avg_backward failed");
}
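The CSR constructors in the file below slice csr_val, csr_row and csr_col out of a single caller-provided device buffer; for the float-valued case the layout is nnz values, then dimM + 1 row offsets, then nnz column indices. A small host-side sketch of that offset arithmetic follows; the helper struct and function names are hypothetical and only mirror the pointer arithmetic used by hl_construct_sparse_matrix.

#include <cstddef>

typedef float real;  // single-precision build; double when PADDLE_TYPE_DOUBLE is set

struct CsrView {
  real* val;  // nnz values
  int* row;   // dimM + 1 row offsets
  int* col;   // nnz column indices
};

// Total bytes needed for a float-valued CSR matrix packed into one buffer.
inline size_t csrBufferBytes(int dimM, int nnz) {
  return nnz * sizeof(real) + (dimM + 1) * sizeof(int) + nnz * sizeof(int);
}

// Slice one contiguous buffer into the three CSR arrays.
inline CsrView csrViewFromBuffer(void* dest, int dimM, int nnz) {
  char* base = static_cast<char*>(dest);
  CsrView v;
  v.val = reinterpret_cast<real*>(base);
  v.row = reinterpret_cast<int*>(base + nnz * sizeof(real));
  v.col = reinterpret_cast<int*>(base + nnz * sizeof(real) + (dimM + 1) * sizeof(int));
  return v;
}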
diff --git a/paddle/cuda/src/hl_cuda_sparse.cu b/paddle/cuda/src/hl_cuda_sparse.cu
index ab9ab57c884137f117c25c2752b5603b2e8b7135..6351e7e01ee55b6303a6e48bc9ebf9834a83130e 100644
--- a/paddle/cuda/src/hl_cuda_sparse.cu
+++ b/paddle/cuda/src/hl_cuda_sparse.cu
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
#include "hl_cuda.h"
+#include "hl_cuda_sparse.cuh"
+#include "hl_matrix_apply.cuh"
+#include "hl_matrix_ops.cuh"
#include "hl_sparse.h"
#include "hl_sparse.ph"
-#include "hl_matrix_ops.cuh"
-#include "hl_matrix_apply.cuh"
-#include "hl_cuda_sparse.cuh"
#include "paddle/utils/Logging.h"
DEFINE_MATRIX_UNARY_PARAMETER_OP(mul_scalar, ONE_PARAMETER, a = a * p);
@@ -34,15 +33,15 @@ void hl_matrix_csr2dense(hl_sparse_matrix_s A_d,
CHECK(A_d->format == HL_SPARSE_CSR) << "matrix format error!";
if (A_d->nnz == 0) {
- hl_gpu_apply_unary_op(
- unary::Zero<real>(), C_d, dimM, dimN, dimN);
+ hl_gpu_apply_unary_op(unary::Zero<real>(), C_d, dimM, dimN, dimN);
return;
}
/* nnz != 0 */
hl_csr_matrix A_d2 = (hl_csr_matrix)(A_d->matrix);
- CHECK((A_d2->csr_val || A_d->type == HL_NO_VALUE) &&
- A_d2->csr_row && A_d2->csr_col) << "parameter transa error!";
+ CHECK((A_d2->csr_val || A_d->type == HL_NO_VALUE) && A_d2->csr_row &&
+ A_d2->csr_col)
+ << "parameter transa error!";
int blocksX = (dimN + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
int blocksY = (dimM + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
@@ -50,21 +49,11 @@ void hl_matrix_csr2dense(hl_sparse_matrix_s A_d,
dim3 grid(blocksX, blocksY);
if (A_d->type == HL_NO_VALUE) {
- KeSMatrixCsr2Dense<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(A_d2->csr_val,
- A_d2->csr_row,
- A_d2->csr_col,
- C_d,
- dimM,
- dimN);
+ KeSMatrixCsr2Dense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d2->csr_val, A_d2->csr_row, A_d2->csr_col, C_d, dimM, dimN);
} else if (A_d->type == HL_FLOAT_VALUE) {
- KeSMatrixCsr2Dense<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(A_d2->csr_val,
- A_d2->csr_row,
- A_d2->csr_col,
- C_d,
- dimM,
- dimN);
+ KeSMatrixCsr2Dense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d2->csr_val, A_d2->csr_row, A_d2->csr_col, C_d, dimM, dimN);
} else {
}
CHECK_SYNC("hl_matrix_csr2dense failed");
@@ -80,15 +69,15 @@ void hl_matrix_csc2dense(hl_sparse_matrix_s A_d,
CHECK(A_d->format == HL_SPARSE_CSC) << "matrix format error!";
if (A_d->nnz == 0) {
- hl_gpu_apply_unary_op(
- unary::Zero<real>(), C_d, dimM, dimN, dimN);
+ hl_gpu_apply_unary_op(unary::Zero<real>(), C_d, dimM, dimN, dimN);
return;
}
/* nnz != 0 */
hl_csc_matrix A_d2 = (hl_csc_matrix)(A_d->matrix);
- CHECK((A_d2->csc_val || A_d->type == HL_NO_VALUE) &&
- A_d2->csc_row && A_d2->csc_col) << "parameter transa error!";
+ CHECK((A_d2->csc_val || A_d->type == HL_NO_VALUE) && A_d2->csc_row &&
+ A_d2->csc_col)
+ << "parameter transa error!";
int blocksX = (dimN + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
int blocksY = (dimM + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
@@ -96,21 +85,11 @@ void hl_matrix_csc2dense(hl_sparse_matrix_s A_d,
dim3 grid(blocksX, blocksY);
if (A_d->type == HL_NO_VALUE) {
- KeSMatrixCsc2Dense<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(A_d2->csc_val,
- A_d2->csc_row,
- A_d2->csc_col,
- C_d,
- dimM,
- dimN);
+ KeSMatrixCsc2Dense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d2->csc_val, A_d2->csc_row, A_d2->csc_col, C_d, dimM, dimN);
} else if (A_d->type == HL_FLOAT_VALUE) {
- KeSMatrixCsc2Dense<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(A_d2->csc_val,
- A_d2->csc_row,
- A_d2->csc_col,
- C_d,
- dimM,
- dimN);
+ KeSMatrixCsc2Dense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ A_d2->csc_val, A_d2->csc_row, A_d2->csc_col, C_d, dimM, dimN);
} else {
}
CHECK_SYNC("hl_matrix_csc2dense failed");
@@ -118,43 +97,43 @@ void hl_matrix_csc2dense(hl_sparse_matrix_s A_d,
void hl_malloc_sparse_matrix(hl_sparse_matrix_s *A_d,
hl_matrix_format_t format,
- hl_matrix_value_t value_type,
+ hl_matrix_value_t value_type,
int dimM,
int dimN,
int nnz) {
CHECK_NOTNULL(A_d);
CHECK(format == HL_SPARSE_CSR || format == HL_SPARSE_CSC)
- << "sparse matrix format error!";
+ << "sparse matrix format error!";
CHECK(value_type == HL_FLOAT_VALUE || value_type == HL_NO_VALUE)
- << "sparse matrix value type error!";
+ << "sparse matrix value type error!";
/* avoid malloc 0 bytes */
int nnz_s = (nnz == 0 ? 1 : nnz);
if (format == HL_SPARSE_CSR) {
CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";
- char* tmp = (char*)malloc(sizeof(_hl_sparse_matrix_s)
- + sizeof(_hl_csr_matrix));
+ char *tmp =
+ (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csr_matrix));
CHECK_NOTNULL(tmp);
- hl_csr_matrix csr = (hl_csr_matrix)(tmp+sizeof(_hl_sparse_matrix_s));
+ hl_csr_matrix csr = (hl_csr_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
csr->sparsity = -1.0;
if (value_type == HL_NO_VALUE) {
csr->csr_val = NULL;
csr->nnz_s = nnz_s;
- csr->row_s = dimM+1;
- csr->csr_row = (int*)hl_malloc_device((dimM+1)*sizeof(int));
- csr->csr_col = (int*)hl_malloc_device((nnz_s)*sizeof(int));
+ csr->row_s = dimM + 1;
+ csr->csr_row = (int *)hl_malloc_device((dimM + 1) * sizeof(int));
+ csr->csr_col = (int *)hl_malloc_device((nnz_s) * sizeof(int));
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csr;
} else if (value_type == HL_FLOAT_VALUE) {
csr->nnz_s = nnz_s;
- csr->row_s = dimM+1;
- csr->csr_val = (real*)hl_malloc_device((nnz_s)*sizeof(real));
- csr->csr_row = (int*)hl_malloc_device((dimM+1)*sizeof(int));
- csr->csr_col = (int*)hl_malloc_device((nnz_s)*sizeof(int));
+ csr->row_s = dimM + 1;
+ csr->csr_val = (real *)hl_malloc_device((nnz_s) * sizeof(real));
+ csr->csr_row = (int *)hl_malloc_device((dimM + 1) * sizeof(int));
+ csr->csr_col = (int *)hl_malloc_device((nnz_s) * sizeof(int));
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csr;
@@ -162,28 +141,28 @@ void hl_malloc_sparse_matrix(hl_sparse_matrix_s *A_d,
} else if (format == HL_SPARSE_CSC) {
CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";
- char* tmp = (char*)malloc(sizeof(_hl_sparse_matrix_s)
- + sizeof(_hl_csc_matrix));
+ char *tmp =
+ (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csc_matrix));
CHECK_NOTNULL(tmp);
- hl_csc_matrix csc = (hl_csc_matrix)(tmp+sizeof(_hl_sparse_matrix_s));
+ hl_csc_matrix csc = (hl_csc_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
csc->sparsity = -1.0f;
if (value_type == HL_NO_VALUE) {
csc->csc_val = NULL;
csc->nnz_s = nnz_s;
- csc->col_s = dimN+1;
- csc->csc_row = (int*)hl_malloc_device((nnz_s)*sizeof(int));
- csc->csc_col = (int*)hl_malloc_device((dimN+1)*sizeof(int));
+ csc->col_s = dimN + 1;
+ csc->csc_row = (int *)hl_malloc_device((nnz_s) * sizeof(int));
+ csc->csc_col = (int *)hl_malloc_device((dimN + 1) * sizeof(int));
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csc;
} else if (value_type == HL_FLOAT_VALUE) {
csc->nnz_s = nnz_s;
- csc->col_s = dimN+1;
- csc->csc_val = (real*)hl_malloc_device((nnz_s)*sizeof(real));
- csc->csc_row = (int*)hl_malloc_device((nnz_s)*sizeof(int));
- csc->csc_col = (int*)hl_malloc_device((dimN+1)*sizeof(int));
+ csc->col_s = dimN + 1;
+ csc->csc_val = (real *)hl_malloc_device((nnz_s) * sizeof(real));
+ csc->csc_row = (int *)hl_malloc_device((nnz_s) * sizeof(int));
+ csc->csc_col = (int *)hl_malloc_device((dimN + 1) * sizeof(int));
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csc;
@@ -200,7 +179,7 @@ void hl_malloc_sparse_matrix(hl_sparse_matrix_s *A_d,
void hl_free_sparse_matrix(hl_sparse_matrix_s A_d) {
CHECK_NOTNULL(A_d);
CHECK(A_d->format == HL_SPARSE_CSR || A_d->format == HL_SPARSE_CSC)
- << "sparse matrix format error!";
+ << "sparse matrix format error!";
if (A_d->matrix == NULL) {
free(A_d);
@@ -249,77 +228,77 @@ void hl_free_sparse_matrix(hl_sparse_matrix_s A_d) {
}
void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
- void * dest_d,
+ void *dest_d,
size_t size,
hl_matrix_format_t format,
- hl_matrix_value_t value_type,
+ hl_matrix_value_t value_type,
int dimM,
int dimN,
int nnz) {
CHECK_NOTNULL(A_d);
CHECK(format == HL_SPARSE_CSR || format == HL_SPARSE_CSC)
- << "sparse matrix format error!";
+ << "sparse matrix format error!";
if (format == HL_SPARSE_CSR) {
CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";
- size_t size_ = (dimM+1)*sizeof(int) + nnz*sizeof(int);
+ size_t size_ = (dimM + 1) * sizeof(int) + nnz * sizeof(int);
if (value_type != HL_NO_VALUE) {
- size_ += nnz*sizeof(real);
+ size_ += nnz * sizeof(real);
}
CHECK_LE(size_, size) << "dest_d size(" << size
- << ") too small, should bigger than(" << size_ << ")!";
+ << ") too small, should bigger than(" << size_
+ << ")!";
- char* tmp = (char*)malloc(sizeof(_hl_sparse_matrix_s)
- + sizeof(_hl_csr_matrix));
+ char *tmp =
+ (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csr_matrix));
CHECK_NOTNULL(tmp);
- hl_csr_matrix csr = (hl_csr_matrix)(tmp+sizeof(_hl_sparse_matrix_s));
+ hl_csr_matrix csr = (hl_csr_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
if (value_type == HL_NO_VALUE) {
csr->csr_val = NULL;
- csr->csr_row = (int*)dest_d;
- csr->csr_col = (int*)((char*)dest_d + (dimM+1)*sizeof(int));
+ csr->csr_row = (int *)dest_d;
+ csr->csr_col = (int *)((char *)dest_d + (dimM + 1) * sizeof(int));
} else {
- csr->csr_val = (real*)dest_d;
- csr->csr_row = (int*)((char*)dest_d + nnz*sizeof(real));
- csr->csr_col = (int*)((char*)dest_d +
- nnz*sizeof(real) +
- (dimM+1)*sizeof(int));
+ csr->csr_val = (real *)dest_d;
+ csr->csr_row = (int *)((char *)dest_d + nnz * sizeof(real));
+ csr->csr_col = (int *)((char *)dest_d + nnz * sizeof(real) +
+ (dimM + 1) * sizeof(int));
}
csr->nnz_s = nnz;
- csr->row_s = dimM+1;
+ csr->row_s = dimM + 1;
csr->sparsity = -1.0;
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csr;
} else if (format == HL_SPARSE_CSC) {
CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";
- size_t size_ = (dimN+1)*sizeof(int) + nnz*sizeof(int);
+ size_t size_ = (dimN + 1) * sizeof(int) + nnz * sizeof(int);
if (value_type != HL_NO_VALUE) {
- size_ += nnz*sizeof(real);
+ size_ += nnz * sizeof(real);
}
CHECK_LE(size_, size) << "dest_d size(" << size
- << ") too small, should bigger than(" << size_ << ")!";
+ << ") too small, should bigger than(" << size_
+ << ")!";
- char* tmp = (char*)malloc(sizeof(_hl_sparse_matrix_s)
- + sizeof(_hl_csc_matrix));
+ char *tmp =
+ (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csc_matrix));
CHECK_NOTNULL(tmp);
- hl_csc_matrix csc = (hl_csc_matrix)(tmp+sizeof(_hl_sparse_matrix_s));
+ hl_csc_matrix csc = (hl_csc_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
if (value_type == HL_NO_VALUE) {
csc->csc_val = NULL;
- csc->csc_col = (int*)dest_d;
- csc->csc_row = (int*)((char*)dest_d + (dimN+1)*sizeof(int));
+ csc->csc_col = (int *)dest_d;
+ csc->csc_row = (int *)((char *)dest_d + (dimN + 1) * sizeof(int));
} else {
- csc->csc_val = (real*)dest_d;
- csc->csc_col = (int*)((char*)dest_d + nnz*sizeof(real));
- csc->csc_row = (int*)((char*)dest_d +
- nnz*sizeof(real) +
- (dimN+1)*sizeof(int));
+ csc->csc_val = (real *)dest_d;
+ csc->csc_col = (int *)((char *)dest_d + nnz * sizeof(real));
+ csc->csc_row = (int *)((char *)dest_d + nnz * sizeof(real) +
+ (dimN + 1) * sizeof(int));
}
csc->nnz_s = nnz;
- csc->col_s = dimN+1;
+ csc->col_s = dimN + 1;
csc->sparsity = -1.0f;
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csc;
@@ -333,11 +312,11 @@ void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
}
void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
- real* value_d,
- int* rows_d,
- int* cols_d,
+ real *value_d,
+ int *rows_d,
+ int *cols_d,
hl_matrix_format_t format,
- hl_matrix_value_t value_type,
+ hl_matrix_value_t value_type,
int dimM,
int dimN,
int nnz) {
@@ -345,11 +324,11 @@ void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";
CHECK(format == HL_SPARSE_CSR || format == HL_SPARSE_CSC)
- << "sparse matrix format error!";
+ << "sparse matrix format error!";
if (format == HL_SPARSE_CSR) {
- char* tmp = (char*)malloc(sizeof(_hl_sparse_matrix_s)
- + sizeof(_hl_csr_matrix));
+ char *tmp =
+ (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csr_matrix));
CHECK_NOTNULL(tmp);
hl_csr_matrix csr = (hl_csr_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
@@ -362,8 +341,8 @@ void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
*A_d = (hl_sparse_matrix_s)tmp;
(*A_d)->matrix = (hl_matrix_s)csr;
} else if (format == HL_SPARSE_CSC) {
- char* tmp = (char*)malloc(sizeof(_hl_sparse_matrix_s)
- + sizeof(_hl_csc_matrix));
+ char *tmp =
+ (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csc_matrix));
CHECK_NOTNULL(tmp);
hl_csc_matrix csc = (hl_csc_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
@@ -396,35 +375,30 @@ void hl_memcpy_csr_matrix(hl_sparse_matrix_s csr_matrix,
hl_stream_t stream) {
CHECK_NOTNULL(csr_matrix);
CHECK_EQ(csr_matrix->format, HL_SPARSE_CSR)
- << "csr_matrix is not csr format!";
+ << "csr_matrix is not csr format!";
CHECK_NOTNULL(csr_matrix->matrix);
hl_csr_matrix csr = (hl_csr_matrix)(csr_matrix->matrix);
- CHECK_LE(csr_matrix->nnz, csr->nnz_s)
- << "copy size " << csr_matrix->nnz
- << " is big than alloc size " << csr->nnz_s;
+ CHECK_LE(csr_matrix->nnz, csr->nnz_s) << "copy size " << csr_matrix->nnz
+ << " is big than alloc size "
+ << csr->nnz_s;
- CHECK_LE((csr_matrix->rows+1), csr->row_s)
- << "copy size " << (csr_matrix->rows + 1)
- << " is big than alloc size " << csr->row_s;
+ CHECK_LE((csr_matrix->rows + 1), csr->row_s)
+ << "copy size " << (csr_matrix->rows + 1) << " is big than alloc size "
+ << csr->row_s;
- CHECK(csr_matrix->type == HL_FLOAT_VALUE ||
- csr_matrix->type == HL_NO_VALUE)
- << "sparse matrix value type error!";
+ CHECK(csr_matrix->type == HL_FLOAT_VALUE || csr_matrix->type == HL_NO_VALUE)
+ << "sparse matrix value type error!";
if (csr_matrix->type == HL_NO_VALUE) {
if (csr_row == NULL && csr_col == NULL) {
return;
} else if (csr_row != NULL && csr_col != NULL) {
- hl_memcpy_async(csr->csr_row,
- csr_row,
- (csr_matrix->rows+1)*sizeof(int),
- stream);
+ hl_memcpy_async(
+ csr->csr_row, csr_row, (csr_matrix->rows + 1) * sizeof(int), stream);
- hl_memcpy_async(csr->csr_col,
- csr_col,
- (csr_matrix->nnz)*sizeof(int),
- stream);
+ hl_memcpy_async(
+ csr->csr_col, csr_col, (csr_matrix->nnz) * sizeof(int), stream);
} else {
LOG(FATAL) << "parameter csr_row or csr_col is null pointer!";
}
@@ -432,30 +406,21 @@ void hl_memcpy_csr_matrix(hl_sparse_matrix_s csr_matrix,
if (csr_val == NULL && csr_row == NULL && csr_col == NULL) {
return;
} else if (csr_val != NULL && csr_row == NULL && csr_col == NULL) {
- hl_memcpy_async(csr->csr_val,
- csr_val,
- (csr_matrix->nnz)*sizeof(real),
- stream);
+ hl_memcpy_async(
+ csr->csr_val, csr_val, (csr_matrix->nnz) * sizeof(real), stream);
} else if (csr_val != NULL && csr_row != NULL && csr_col != NULL) {
- hl_memcpy_async(csr->csr_val,
- csr_val,
- (csr_matrix->nnz)*sizeof(real),
- stream);
- hl_memcpy_async(csr->csr_row,
- csr_row,
- (csr_matrix->rows+1)*sizeof(int),
- stream);
- hl_memcpy_async(csr->csr_col,
- csr_col,
- (csr_matrix->nnz)*sizeof(int),
- stream);
+ hl_memcpy_async(
+ csr->csr_val, csr_val, (csr_matrix->nnz) * sizeof(real), stream);
+ hl_memcpy_async(
+ csr->csr_row, csr_row, (csr_matrix->rows + 1) * sizeof(int), stream);
+ hl_memcpy_async(
+ csr->csr_col, csr_col, (csr_matrix->nnz) * sizeof(int), stream);
} else {
LOG(FATAL) << "parameter csr_row or csr_col is null pointer!";
}
}
- csr->sparsity = ((float)csr_matrix->nnz) /
- ((float)csr_matrix->rows) /
+ csr->sparsity = ((float)csr_matrix->nnz) / ((float)csr_matrix->rows) /
((float)csr_matrix->cols);
}
@@ -466,33 +431,28 @@ void hl_memcpy_csc_matrix(hl_sparse_matrix_s csc_matrix,
hl_stream_t stream) {
CHECK_NOTNULL(csc_matrix);
CHECK_EQ(csc_matrix->format, HL_SPARSE_CSC)
- << "csc_matrix is not csc format error!";
+ << "csc_matrix is not csc format error!";
hl_csc_matrix csc = (hl_csc_matrix)(csc_matrix->matrix);
- CHECK_LE(csc_matrix->nnz, csc->nnz_s)
- << "copy size " << csc_matrix->nnz
- << " is big than alloc size " << csc->nnz_s;
+ CHECK_LE(csc_matrix->nnz, csc->nnz_s) << "copy size " << csc_matrix->nnz
+ << " is big than alloc size "
+ << csc->nnz_s;
- CHECK_LE((csc_matrix->cols+1), csc->col_s)
- << "copy size " <<(csc_matrix->cols + 1)
- << " is big than alloc size " << csc->col_s;
+ CHECK_LE((csc_matrix->cols + 1), csc->col_s)
+ << "copy size " << (csc_matrix->cols + 1) << " is big than alloc size "
+ << csc->col_s;
- CHECK(csc_matrix->type == HL_FLOAT_VALUE ||
- csc_matrix->type == HL_NO_VALUE)
- << "sparse matrix value type error!";
+ CHECK(csc_matrix->type == HL_FLOAT_VALUE || csc_matrix->type == HL_NO_VALUE)
+ << "sparse matrix value type error!";
if (csc_matrix->type == HL_NO_VALUE) {
if (csc_row == NULL && csc_col == NULL) {
return;
} else if (csc_row != NULL && csc_col != NULL) {
- hl_memcpy_async(csc->csc_row,
- csc_row,
- (csc_matrix->nnz)*sizeof(int),
- stream);
- hl_memcpy_async(csc->csc_col,
- csc_col,
- (csc_matrix->cols+1)*sizeof(int),
- stream);
+ hl_memcpy_async(
+ csc->csc_row, csc_row, (csc_matrix->nnz) * sizeof(int), stream);
+ hl_memcpy_async(
+ csc->csc_col, csc_col, (csc_matrix->cols + 1) * sizeof(int), stream);
} else {
LOG(FATAL) << "parameter csc_row or csc_col is null pointer!";
}
@@ -500,30 +460,21 @@ void hl_memcpy_csc_matrix(hl_sparse_matrix_s csc_matrix,
if (csc_val == NULL && csc_row == NULL && csc_col == NULL) {
return;
} else if (csc_val != NULL && csc_row == NULL && csc_col == NULL) {
- hl_memcpy_async(csc->csc_val,
- csc_val,
- (csc_matrix->nnz)*sizeof(real),
- stream);
+ hl_memcpy_async(
+ csc->csc_val, csc_val, (csc_matrix->nnz) * sizeof(real), stream);
} else if (csc_val != NULL && csc_row != NULL && csc_col != NULL) {
- hl_memcpy_async(csc->csc_val,
- csc_val,
- (csc_matrix->nnz)*sizeof(real),
- stream);
- hl_memcpy_async(csc->csc_row,
- csc_row,
- (csc_matrix->nnz)*sizeof(int),
- stream);
- hl_memcpy_async(csc->csc_col,
- csc_col,
- (csc_matrix->cols+1)*sizeof(int),
- stream);
+ hl_memcpy_async(
+ csc->csc_val, csc_val, (csc_matrix->nnz) * sizeof(real), stream);
+ hl_memcpy_async(
+ csc->csc_row, csc_row, (csc_matrix->nnz) * sizeof(int), stream);
+ hl_memcpy_async(
+ csc->csc_col, csc_col, (csc_matrix->cols + 1) * sizeof(int), stream);
} else {
LOG(FATAL) << "parameter csc_row or csc_col is null pointer!";
}
}
- csc->sparsity = ((float)csc_matrix->nnz) /
- ((float)csc_matrix->rows) /
+ csc->sparsity = ((float)csc_matrix->nnz) / ((float)csc_matrix->rows) /
((float)csc_matrix->cols);
}
@@ -531,32 +482,23 @@ void hl_memcpy_sparse_matrix(hl_sparse_matrix_s dst,
hl_sparse_matrix_s src,
hl_stream_t stream) {
CHECK(dst && src && dst->matrix && src->matrix)
- << "parameter dst or src is null pointer!";
- CHECK_EQ(dst->format, src->format)
- << "sparse matrix format does not match!";
+ << "parameter dst or src is null pointer!";
+ CHECK_EQ(dst->format, src->format) << "sparse matrix format does not match!";
CHECK(dst->type != HL_FLOAT_VALUE || src->type != HL_NO_VALUE)
- << "src sparse matrix is no value, dst sparse matrix has value!";
+ << "src sparse matrix is no value, dst sparse matrix has value!";
if (dst->format == HL_SPARSE_CSR) {
dst->rows = src->rows;
dst->cols = src->cols;
- dst->nnz = src->nnz;
+ dst->nnz = src->nnz;
hl_csr_matrix csr = (hl_csr_matrix)src->matrix;
- hl_memcpy_csr_matrix(dst,
- csr->csr_val,
- csr->csr_row,
- csr->csr_col,
- stream);
+ hl_memcpy_csr_matrix(dst, csr->csr_val, csr->csr_row, csr->csr_col, stream);
} else if (dst->format == HL_SPARSE_CSC) {
dst->rows = src->rows;
dst->cols = src->cols;
- dst->nnz = src->nnz;
+ dst->nnz = src->nnz;
hl_csc_matrix csc = (hl_csc_matrix)src->matrix;
- hl_memcpy_csc_matrix(dst,
- csc->csc_val,
- csc->csc_row,
- csc->csc_col,
- stream);
+ hl_memcpy_csc_matrix(dst, csc->csc_val, csc->csc_row, csc->csc_col, stream);
} else {
LOG(FATAL) << "sparse matrix format error!";
}
@@ -569,20 +511,24 @@ static void _beta_mul_c(real *c, int dimM, int dimN, real beta) {
if (beta == 0.0) {
hl_gpu_apply_unary_op(unary::Zero<real>(), c, dimM, dimN, dimN);
} else {
- if (beta != 1.0){
- hl_gpu_apply_unary_op(
- unary::mul_scalar<real>(beta), c, dimM, dimN, dimN);
+ if (beta != 1.0) {
+ hl_gpu_apply_unary_op(unary::mul_scalar<real>(beta), c, dimM, dimN, dimN);
}
}
return;
}
-void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
- real *B_d, hl_trans_op_t transb,
+void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d,
+ hl_trans_op_t transa,
+ real *B_d,
+ hl_trans_op_t transb,
real *C_d,
- int dimM, int dimN, int dimK,
- real alpha, real beta) {
+ int dimM,
+ int dimN,
+ int dimK,
+ real alpha,
+ real beta) {
CHECK_EQ(transb, HPPL_OP_N);
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
@@ -592,7 +538,7 @@ void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
if ((HPPL_OP_N == transa && (A_d->rows != dimM || A_d->cols != dimK)) ||
(HPPL_OP_T == transa && (A_d->rows != dimK || A_d->cols != dimM))) {
- LOG(FATAL) << "parameter error!";
+ LOG(FATAL) << "parameter error!";
}
if (A_d->nnz == 0) {
@@ -603,8 +549,7 @@ void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
/* nnz != 0 */
hl_csr_matrix A_d2 = (hl_csr_matrix)(A_d->matrix);
if ((A_d2->csr_val == NULL && A_d->type != HL_NO_VALUE) ||
- A_d2->csr_row == NULL ||
- A_d2->csr_col == NULL) {
+ A_d2->csr_row == NULL || A_d2->csr_col == NULL) {
LOG(FATAL) << "parameter error!";
}
@@ -617,63 +562,63 @@ void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
/* sparsity pattern */
// A_d->sparsity;
if (A_d->type == HL_NO_VALUE) {
- KeSMatrixCsrMulDense<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csr_val,
- A_d2->csr_col,
- A_d2->csr_row,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCsrMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csr_val,
+ A_d2->csr_col,
+ A_d2->csr_row,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixCsrMulDense<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csr_val,
- A_d2->csr_col,
- A_d2->csr_row,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCsrMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csr_val,
+ A_d2->csr_col,
+ A_d2->csr_row,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else if (HPPL_OP_T == transa) {
_beta_mul_c(C_d, dimM, dimN, beta);
- int blocksX = (dimN + CU_CSC_MUL_DENSE_BLOCK_N - 1) /
- CU_CSC_MUL_DENSE_BLOCK_N;
- int blocksY = (dimK + CU_CSC_MUL_DENSE_BLOCK_K - 1) /
- CU_CSC_MUL_DENSE_BLOCK_K;
+ int blocksX =
+ (dimN + CU_CSC_MUL_DENSE_BLOCK_N - 1) / CU_CSC_MUL_DENSE_BLOCK_N;
+ int blocksY =
+ (dimK + CU_CSC_MUL_DENSE_BLOCK_K - 1) / CU_CSC_MUL_DENSE_BLOCK_K;
dim3 threads(CU_CSC_MUL_DENSE_THREAD_X, CU_CSC_MUL_DENSE_THREAD_Y);
dim3 grid(blocksX, blocksY);
if (A_d->type == HL_NO_VALUE) {
- KeSMatrixCscMulDense<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csr_val,
- A_d2->csr_col,
- A_d2->csr_row,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCscMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csr_val,
+ A_d2->csr_col,
+ A_d2->csr_row,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixCscMulDense<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csr_val,
- A_d2->csr_col,
- A_d2->csr_row,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCscMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csr_val,
+ A_d2->csr_col,
+ A_d2->csr_row,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else {
LOG(FATAL) << "parameter transa error!";
@@ -682,11 +627,16 @@ void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
CHECK_SYNC("hl_matrix_csr_mul_dense failed");
}
-void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
- hl_sparse_matrix_s B_d, hl_trans_op_t transb,
+void hl_matrix_dense_mul_csc(real *A_d,
+ hl_trans_op_t transa,
+ hl_sparse_matrix_s B_d,
+ hl_trans_op_t transb,
real *C_d,
- int dimM, int dimN, int dimK,
- real alpha, real beta) {
+ int dimM,
+ int dimN,
+ int dimK,
+ real alpha,
+ real beta) {
CHECK_EQ(transa, HPPL_OP_N);
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
@@ -698,8 +648,7 @@ void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
LOG(FATAL) << "parameter dims error!";
}
- CHECK_EQ(B_d->format, HL_SPARSE_CSC)
- << "matrix format error!";
+ CHECK_EQ(B_d->format, HL_SPARSE_CSC) << "matrix format error!";
if (B_d->nnz == 0) {
_beta_mul_c(C_d, dimM, dimN, beta);
@@ -709,8 +658,7 @@ void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
/* nnz != 0 */
hl_csc_matrix B_d2 = (hl_csc_matrix)(B_d->matrix);
if ((B_d2->csc_val == NULL && B_d->type != HL_NO_VALUE) ||
- B_d2->csc_row == NULL ||
- B_d2->csc_col == NULL) {
+ B_d2->csc_row == NULL || B_d2->csc_col == NULL) {
LOG(FATAL) << "parameter B is null!";
}
@@ -721,60 +669,60 @@ void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
dim3 grid(blocksX, blocksY);
if (B_d->type == HL_NO_VALUE) {
- KeSMatrixDenseMulCsc<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csc_val,
- B_d2->csc_row,
- B_d2->csc_col,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsc<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csc_val,
+ B_d2->csc_row,
+ B_d2->csc_col,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixDenseMulCsc<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csc_val,
- B_d2->csc_row,
- B_d2->csc_col,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsc<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csc_val,
+ B_d2->csc_row,
+ B_d2->csc_col,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else if (transb == HPPL_OP_T) {
_beta_mul_c(C_d, dimM, dimN, beta);
- int blocksX = 1 + (dimK-1)/CU_DM_CSR_THREAD_X;
- int blocksY = 1 + (dimM-1)/CU_DM_CSR_BLOCK_M;
+ int blocksX = 1 + (dimK - 1) / CU_DM_CSR_THREAD_X;
+ int blocksY = 1 + (dimM - 1) / CU_DM_CSR_BLOCK_M;
dim3 threads(CU_DM_CSR_THREAD_X, CU_DM_CSR_THREAD_Y);
dim3 grid(blocksX, blocksY);
if (B_d->type == HL_NO_VALUE) {
- KeSMatrixDenseMulCsr<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csc_val,
- B_d2->csc_col,
- B_d2->csc_row,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsr<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csc_val,
+ B_d2->csc_col,
+ B_d2->csc_row,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixDenseMulCsr<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csc_val,
- B_d2->csc_col,
- B_d2->csc_row,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsr<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csc_val,
+ B_d2->csc_col,
+ B_d2->csc_row,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else {
LOG(FATAL) << "parameter transb error!";
@@ -783,24 +731,28 @@ void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
CHECK_SYNC("hl_matrix_dense_mul_csc failed");
}
-void hl_matrix_dense_mul_csr(real *A_d, hl_trans_op_t transa,
- hl_sparse_matrix_s B_d, hl_trans_op_t transb,
+void hl_matrix_dense_mul_csr(real *A_d,
+ hl_trans_op_t transa,
+ hl_sparse_matrix_s B_d,
+ hl_trans_op_t transb,
real *C_d,
- int dimM, int dimN, int dimK,
- real alpha, real beta) {
+ int dimM,
+ int dimN,
+ int dimK,
+ real alpha,
+ real beta) {
CHECK_EQ(transa, HPPL_OP_N);
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
CHECK_NOTNULL(C_d);
- if (dimM <= 0 || dimN <= 0 || dimK <= 0
- || (transb == HPPL_OP_N && (B_d->rows != dimK || B_d->cols != dimN))
- || (transb == HPPL_OP_T && (B_d->rows != dimN || B_d->cols != dimK))) {
+ if (dimM <= 0 || dimN <= 0 || dimK <= 0 ||
+ (transb == HPPL_OP_N && (B_d->rows != dimK || B_d->cols != dimN)) ||
+ (transb == HPPL_OP_T && (B_d->rows != dimN || B_d->cols != dimK))) {
LOG(FATAL) << "parameter dims error!";
}
- CHECK_EQ(B_d->format, HL_SPARSE_CSR)
- << "matrix format error!";
+ CHECK_EQ(B_d->format, HL_SPARSE_CSR) << "matrix format error!";
if (B_d->nnz == 0) {
_beta_mul_c(C_d, dimM, dimN, beta);
@@ -810,41 +762,40 @@ void hl_matrix_dense_mul_csr(real *A_d, hl_trans_op_t transa,
/* nnz != 0 */
hl_csr_matrix B_d2 = (hl_csr_matrix)(B_d->matrix);
if ((B_d2->csr_val == NULL && B_d->type != HL_NO_VALUE) ||
- B_d2->csr_row == NULL ||
- B_d2->csr_col == NULL) {
+ B_d2->csr_row == NULL || B_d2->csr_col == NULL) {
LOG(FATAL) << "parameter transa error!";
}
if (transb == HPPL_OP_N) {
_beta_mul_c(C_d, dimM, dimN, beta);
- int blocksX = 1 + (dimK-1)/CU_DM_CSR_THREAD_X;
- int blocksY = 1 + (dimM-1)/CU_DM_CSR_BLOCK_M;
+ int blocksX = 1 + (dimK - 1) / CU_DM_CSR_THREAD_X;
+ int blocksY = 1 + (dimM - 1) / CU_DM_CSR_BLOCK_M;
dim3 threads(CU_DM_CSR_THREAD_X, CU_DM_CSR_THREAD_Y);
dim3 grid(blocksX, blocksY);
if (B_d->type == HL_NO_VALUE) {
- KeSMatrixDenseMulCsr<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csr_val,
- B_d2->csr_row,
- B_d2->csr_col,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsr<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csr_val,
+ B_d2->csr_row,
+ B_d2->csr_col,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixDenseMulCsr<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csr_val,
- B_d2->csr_row,
- B_d2->csr_col,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsr<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csr_val,
+ B_d2->csr_row,
+ B_d2->csr_col,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else if (transb == HPPL_OP_T) {
int blocksX = (dimM + CU_CSCMM_BLOCK_M_BEST - 1) / CU_CSCMM_BLOCK_M_BEST;
@@ -852,29 +803,29 @@ void hl_matrix_dense_mul_csr(real *A_d, hl_trans_op_t transa,
dim3 threads(CU_CSCMM_THREAD_X_BEST, CU_CSCMM_THREAD_Y_BEST);
dim3 grid(blocksX, blocksY);
if (B_d->type == HL_NO_VALUE) {
- KeSMatrixDenseMulCsc<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csr_val,
- B_d2->csr_col,
- B_d2->csr_row,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsc<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csr_val,
+ B_d2->csr_col,
+ B_d2->csr_row,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixDenseMulCsc<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d,
- B_d2->csr_val,
- B_d2->csr_col,
- B_d2->csr_row,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulCsc<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d,
+ B_d2->csr_val,
+ B_d2->csr_col,
+ B_d2->csr_row,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else {
LOG(FATAL) << "parameter transb error!";
@@ -883,11 +834,16 @@ void hl_matrix_dense_mul_csr(real *A_d, hl_trans_op_t transa,
CHECK_SYNC("hl_matrix_dense_mul_csr failed");
}
-void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
- real *B_d, hl_trans_op_t transb,
+void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d,
+ hl_trans_op_t transa,
+ real *B_d,
+ hl_trans_op_t transb,
real *C_d,
- int dimM, int dimN, int dimK,
- real alpha, real beta) {
+ int dimM,
+ int dimN,
+ int dimK,
+ real alpha,
+ real beta) {
CHECK_EQ(transb, HPPL_OP_N);
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
@@ -908,42 +864,43 @@ void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
/* nnz != 0 */
hl_csc_matrix A_d2 = (hl_csc_matrix)(A_d->matrix);
if ((A_d2->csc_val == NULL && A_d->type != HL_NO_VALUE) ||
- A_d2->csc_row == NULL ||
- A_d2->csc_col == NULL) {
+ A_d2->csc_row == NULL || A_d2->csc_col == NULL) {
LOG(FATAL) << "parameter error!";
}
if (HPPL_OP_N == transa) {
_beta_mul_c(C_d, dimM, dimN, beta);
- int blocksX = (dimN + CU_CSC_MUL_DENSE_BLOCK_N -1)/CU_CSC_MUL_DENSE_BLOCK_N;
- int blocksY = (dimK + CU_CSC_MUL_DENSE_BLOCK_K -1)/CU_CSC_MUL_DENSE_BLOCK_K;
+ int blocksX =
+ (dimN + CU_CSC_MUL_DENSE_BLOCK_N - 1) / CU_CSC_MUL_DENSE_BLOCK_N;
+ int blocksY =
+ (dimK + CU_CSC_MUL_DENSE_BLOCK_K - 1) / CU_CSC_MUL_DENSE_BLOCK_K;
dim3 threads(CU_CSC_MUL_DENSE_THREAD_X, CU_CSC_MUL_DENSE_THREAD_Y);
dim3 grid(blocksX, blocksY);
if (A_d->type == HL_NO_VALUE) {
- KeSMatrixCscMulDense<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csc_val,
- A_d2->csc_row,
- A_d2->csc_col,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCscMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csc_val,
+ A_d2->csc_row,
+ A_d2->csc_col,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixCscMulDense<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csc_val,
- A_d2->csc_row,
- A_d2->csc_col,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCscMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csc_val,
+ A_d2->csc_row,
+ A_d2->csc_col,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else if (HPPL_OP_T == transa) {
int blocksX = (dimN + CU_CSRMM_BLOCK_N - 1) / CU_CSRMM_BLOCK_N;
@@ -954,29 +911,29 @@ void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
/* sparsity pattern */
// A_d->sparsity;
if (A_d->type == HL_NO_VALUE) {
- KeSMatrixCsrMulDense<0>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csc_val,
- A_d2->csc_row,
- A_d2->csc_col,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCsrMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csc_val,
+ A_d2->csc_row,
+ A_d2->csc_col,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
} else {
- KeSMatrixCsrMulDense<1>
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d,
- A_d2->csc_val,
- A_d2->csc_row,
- A_d2->csc_col,
- B_d,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixCsrMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d,
+ A_d2->csc_val,
+ A_d2->csc_row,
+ A_d2->csc_col,
+ B_d,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
}
} else {
LOG(FATAL) << "parameter transa error!";
@@ -985,11 +942,16 @@ void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
CHECK_SYNC("hl_matrix_csc_mul_dense failed");
}
-void hl_sparse_matrix_mul(real *A_d, hl_trans_op_t transa,
- real *B_d, hl_trans_op_t transb,
- hl_sparse_matrix_s C_d,
- int dimM, int dimN, int dimK,
- real alpha, real beta) {
+void hl_sparse_matrix_mul(real *A_d,
+ hl_trans_op_t transa,
+ real *B_d,
+ hl_trans_op_t transb,
+ hl_sparse_matrix_s C_d,
+ int dimM,
+ int dimN,
+ int dimK,
+ real alpha,
+ real beta) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
CHECK_NOTNULL(C_d);
@@ -1000,18 +962,14 @@ void hl_sparse_matrix_mul(real *A_d, hl_trans_op_t transa,
if (C_d->format == HL_SPARSE_CSC) {
hl_csc_matrix C_d2 = (hl_csc_matrix)(C_d->matrix);
- if (C_d2->csc_val == NULL ||
- C_d2->csc_row == NULL ||
+ if (C_d2->csc_val == NULL || C_d2->csc_row == NULL ||
C_d2->csc_col == NULL) {
LOG(FATAL) << "parameter error!";
}
if (beta != 1.0) {
- hl_gpu_apply_unary_op(unary::mul_scalar(beta),
- C_d2->csc_val,
- 1,
- C_d->nnz,
- C_d->nnz);
+ hl_gpu_apply_unary_op(
+ unary::mul_scalar(beta), C_d2->csc_val, 1, C_d->nnz, C_d->nnz);
}
int blocksX = dimN;
@@ -1020,34 +978,30 @@ void hl_sparse_matrix_mul(real *A_d, hl_trans_op_t transa,
dim3 grid(blocksX, blocksY);
bool transA = transa == HPPL_OP_T ? 1 : 0;
bool transB = transb == HPPL_OP_T ? 1 : 0;
- KeSMatrixDenseMulDense2CSC
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d2->csc_val,
- C_d2->csc_row,
- C_d2->csc_col,
- A_d,
- B_d,
- transA,
- transB,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
+ KeSMatrixDenseMulDense2CSC<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d2->csc_val,
+ C_d2->csc_row,
+ C_d2->csc_col,
+ A_d,
+ B_d,
+ transA,
+ transB,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
CHECK_SYNC("hl_sparse_matrix_mul failed");
} else {
hl_csr_matrix C_d2 = (hl_csr_matrix)(C_d->matrix);
if ((C_d2->csr_val == NULL && C_d->type != HL_NO_VALUE) ||
- C_d2->csr_row == NULL ||
- C_d2->csr_col == NULL) {
+ C_d2->csr_row == NULL || C_d2->csr_col == NULL) {
LOG(FATAL) << "parameter error!";
}
if (beta != 1.0) {
- hl_gpu_apply_unary_op(unary::mul_scalar(beta),
- C_d2->csr_val,
- 1,
- C_d->nnz,
- C_d->nnz);
+ hl_gpu_apply_unary_op(
+ unary::mul_scalar(beta), C_d2->csr_val, 1, C_d->nnz, C_d->nnz);
}
bool transA = transa == HPPL_OP_T ? 1 : 0;
@@ -1058,20 +1012,20 @@ void hl_sparse_matrix_mul(real *A_d, hl_trans_op_t transa,
dim3 threads(CU_CSCMM_DMD2CSR_THREAD_X, 1);
dim3 grid(blocksX, blocksY);
- KeSMatrixDenseMulDense2CSR
- <<<grid, threads, 0, STREAM_DEFAULT>>>(C_d2->csr_val,
- C_d2->csr_row,
- C_d2->csr_col,
- A_d,
- B_d,
- transA,
- transB,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
- CHECK_SYNC("hl_sparse_matrix_mul failed");
+ KeSMatrixDenseMulDense2CSR<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ C_d2->csr_val,
+ C_d2->csr_row,
+ C_d2->csr_col,
+ A_d,
+ B_d,
+ transA,
+ transB,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
+ CHECK_SYNC("hl_sparse_matrix_mul failed");
} else {
CHECK(!transA) << "Not supported A is trans and B is not trans!";
@@ -1080,21 +1034,21 @@ void hl_sparse_matrix_mul(real *A_d, hl_trans_op_t transa,
avgNnzPerRow = avgNnzPerRow > 0 ? avgNnzPerRow : 1;
int gridx = DIVUP(avgNnzPerRow, CU_BLOCK_SIZE);
dim3 grid(gridx, dimM);
- KeSMatrixDenseMulDenseTrans2CSR
- <<<grid, block, 0, STREAM_DEFAULT>>>(C_d2->csr_val,
- C_d2->csr_row,
- C_d2->csr_col,
- A_d,
- B_d,
- transA,
- transB,
- dimM,
- dimN,
- dimK,
- alpha,
- beta);
- CHECK_SYNC("hl_sparse_matrix_mul failed");
- }
+ KeSMatrixDenseMulDenseTrans2CSR<<<grid, block, 0, STREAM_DEFAULT>>>(
+ C_d2->csr_val,
+ C_d2->csr_row,
+ C_d2->csr_col,
+ A_d,
+ B_d,
+ transA,
+ transB,
+ dimM,
+ dimN,
+ dimK,
+ alpha,
+ beta);
+ CHECK_SYNC("hl_sparse_matrix_mul failed");
+ }
}
}
@@ -1111,7 +1065,7 @@ void hl_memcpy_from_csc_matrix(real *csc_val,
CHECK_NOTNULL(csc_col);
CHECK_EQ(csc_matrix->format, HL_SPARSE_CSC)
- << "csc_matrix is not csc format error!";
+ << "csc_matrix is not csc format error!";
if (csc_matrix->nnz > row_size ||
csc_matrix->cols + 1 > static_cast<int>(col_size)) {
@@ -1119,20 +1073,20 @@ void hl_memcpy_from_csc_matrix(real *csc_val,
}
hl_csc_matrix csc = (hl_csc_matrix)(csc_matrix->matrix);
- hl_memcpy_async((void*)csc_row,
- (void*)csc->csc_row,
+ hl_memcpy_async((void *)csc_row,
+ (void *)csc->csc_row,
(csc_matrix->nnz) * sizeof(int),
stream);
- hl_memcpy_async((void*)csc_col,
- (void*)csc->csc_col,
+ hl_memcpy_async((void *)csc_col,
+ (void *)csc->csc_col,
(csc_matrix->cols + 1) * sizeof(int),
stream);
if (csc_matrix->type == HL_FLOAT_VALUE) {
if (csc_val != NULL) {
CHECK_LE(csc_matrix->nnz, val_size) << "size not match!";
- hl_memcpy_async((void*)csc_val,
- (void*)csc->csc_val,
- (csc_matrix->nnz)*sizeof(real),
+ hl_memcpy_async((void *)csc_val,
+ (void *)csc->csc_val,
+ (csc_matrix->nnz) * sizeof(real),
stream);
} else {
LOG(FATAL) << "parameter csr_val is null pointer!";
@@ -1152,7 +1106,7 @@ void hl_memcpy_from_csr_matrix(real *csr_val,
CHECK_NOTNULL(csr_row);
CHECK_NOTNULL(csr_col);
CHECK_EQ(csr_matrix->format, HL_SPARSE_CSR)
- << "csr_matrix is not csr format error!";
+ << "csr_matrix is not csr format error!";
if (csr_matrix->nnz > col_size ||
csr_matrix->rows + 1 > static_cast<int>(row_size)) {
@@ -1160,20 +1114,20 @@ void hl_memcpy_from_csr_matrix(real *csr_val,
}
hl_csr_matrix csr = (hl_csr_matrix)(csr_matrix->matrix);
- hl_memcpy_async((void*)csr_row,
- (void*)csr->csr_row,
- (csr_matrix->rows+1)*sizeof(int),
+ hl_memcpy_async((void *)csr_row,
+ (void *)csr->csr_row,
+ (csr_matrix->rows + 1) * sizeof(int),
stream);
- hl_memcpy_async((void*)csr_col,
- (void*)csr->csr_col,
- (csr_matrix->nnz)*sizeof(int),
+ hl_memcpy_async((void *)csr_col,
+ (void *)csr->csr_col,
+ (csr_matrix->nnz) * sizeof(int),
stream);
if (csr_matrix->type == HL_FLOAT_VALUE) {
if (csr_val != NULL) {
CHECK_LE(csr_matrix->nnz, val_size) << "size not match!";
- hl_memcpy_async((void*)csr_val,
- (void*)csr->csr_val,
- (csr_matrix->nnz)*sizeof(real),
+ hl_memcpy_async((void *)csr_val,
+ (void *)csr->csr_val,
+ (csr_matrix->nnz) * sizeof(real),
stream);
} else {
LOG(FATAL) << "parameter csr_val is null pointer!";
@@ -1181,8 +1135,8 @@ void hl_memcpy_from_csr_matrix(real *csr_val,
}
}
-void hl_sparse_matrix_column_sum(real* A_d, hl_sparse_matrix_s B_d, int dimM,
- int dimN, real scale) {
+void hl_sparse_matrix_column_sum(
+ real *A_d, hl_sparse_matrix_s B_d, int dimM, int dimN, real scale) {
if (B_d->format == HL_SPARSE_CSR) {
hl_matrix_csr_column_sum(A_d, B_d, dimM, dimN, scale);
} else {
@@ -1190,8 +1144,8 @@ void hl_sparse_matrix_column_sum(real* A_d, hl_sparse_matrix_s B_d, int dimM,
}
}
-void hl_matrix_csr_column_sum(real* A_d, hl_sparse_matrix_s B_d,
- int dimM, int dimN, real scale) {
+void hl_matrix_csr_column_sum(
+ real *A_d, hl_sparse_matrix_s B_d, int dimM, int dimN, real scale) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
@@ -1216,8 +1170,7 @@ void hl_matrix_csr_column_sum(real* A_d, hl_sparse_matrix_s B_d,
CHECK_SYNC("hl_matrix_csr_column_sum failed");
}
-void hl_sparse_matrix_add_bias(hl_sparse_matrix_s A_d,
- real* B_d, real scale) {
+void hl_sparse_matrix_add_bias(hl_sparse_matrix_s A_d, real *B_d, real scale) {
if (A_d->format == HL_SPARSE_CSR) {
hl_matrix_csr_add_bias(A_d, B_d, scale);
} else {
@@ -1225,8 +1178,7 @@ void hl_sparse_matrix_add_bias(hl_sparse_matrix_s A_d,
}
}
-void hl_matrix_csr_add_bias(hl_sparse_matrix_s A_d, real* B_d,
- real scale) {
+void hl_matrix_csr_add_bias(hl_sparse_matrix_s A_d, real *B_d, real scale) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
@@ -1247,8 +1199,12 @@ void hl_matrix_csr_add_bias(hl_sparse_matrix_s A_d, real* B_d,
CHECK_SYNC("hl_sparse_matrix_add_bias failed");
}
-void hl_sparse_matrix_add_dense(hl_sparse_matrix_s A_d, real *B_d, int dimM,
- int dimN, real alpha, real beta) {
+void hl_sparse_matrix_add_dense(hl_sparse_matrix_s A_d,
+ real *B_d,
+ int dimM,
+ int dimN,
+ real alpha,
+ real beta) {
if (A_d->format == HL_SPARSE_CSR) {
hl_matrix_csr_add_dense(A_d, B_d, dimM, dimN, alpha, beta);
} else {
@@ -1256,8 +1212,12 @@ void hl_sparse_matrix_add_dense(hl_sparse_matrix_s A_d, real *B_d, int dimM,
}
}
-void hl_matrix_csr_add_dense(hl_sparse_matrix_s A_d, real* B_d, int dimM,
- int dimN, real alpha, real beta) {
+void hl_matrix_csr_add_dense(hl_sparse_matrix_s A_d,
+ real *B_d,
+ int dimM,
+ int dimN,
+ real alpha,
+ real beta) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
@@ -1277,20 +1237,26 @@ void hl_matrix_csr_add_dense(hl_sparse_matrix_s A_d, real* B_d, int dimM,
gridX = gridX > 0 ? gridX : 1;
dim3 block(512, 1);
dim3 grid(gridX, dimM);
- KeSMatrixCsrAddDense<<<grid, block, 0, STREAM_DEFAULT>>>(
- A_d2->csr_val, A_d2->csr_row, A_d2->csr_col, B_d, alpha, beta, dimM, dimN);
+ KeSMatrixCsrAddDense<<<grid, block, 0, STREAM_DEFAULT>>>(A_d2->csr_val,
+ A_d2->csr_row,
+ A_d2->csr_col,
+ B_d,
+ alpha,
+ beta,
+ dimM,
+ dimN);
CHECK_SYNC("hl_sparse_matrix_add_dense failed");
}
-int* hl_sparse_matrix_get_rows(hl_sparse_matrix_s sMat) {
+int *hl_sparse_matrix_get_rows(hl_sparse_matrix_s sMat) {
__sparse_get_return__(sMat, row);
}
-int* hl_sparse_matrix_get_cols(hl_sparse_matrix_s sMat) {
+int *hl_sparse_matrix_get_cols(hl_sparse_matrix_s sMat) {
__sparse_get_return__(sMat, col);
}
-real* hl_sparse_matrix_get_value(hl_sparse_matrix_s sMat) {
+real *hl_sparse_matrix_get_value(hl_sparse_matrix_s sMat) {
__sparse_get_return__(sMat, val);
}
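
Note on the launches reformatted above: hl_sparse.cu sizes its CUDA grids with the ceiling-division idiom (n + BLOCK - 1) / BLOCK, also spelled DIVUP for the transposed dense-dense-to-CSR case. A minimal standalone C++ sketch of that idiom (illustrative names only, not part of this patch):

#include <cstdio>

// Smallest number of fixed-size blocks whose threads cover n elements;
// equals ceil(n / block) for positive n and block.
static int div_up(int n, int block) { return (n + block - 1) / block; }

int main() {
  std::printf("%d\n", div_up(1000, 256));  // 4: three full blocks plus one partial block
  std::printf("%d\n", div_up(1024, 256));  // 4: exact fit, no extra block
  return 0;
}

The extra block covers the tail when n is not a multiple of the block size; threads past the end are masked off inside the kernels with an index check.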
diff --git a/paddle/cuda/src/hl_perturbation_util.cu b/paddle/cuda/src/hl_perturbation_util.cu
index 2a945bcdb87fe49c121890128ef77b084ebe8e60..d01a91561efa2ebe8e0cabc2b4e8885f2c02ab48 100644
--- a/paddle/cuda/src/hl_perturbation_util.cu
+++ b/paddle/cuda/src/hl_perturbation_util.cu
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
-#include
#include
-#include "hl_cuda.h"
-#include "hl_time.h"
+#include
#include "hl_base.h"
+#include "hl_cuda.h"
#include "hl_perturbation_util.cuh"
+#include "hl_time.h"
#define _USE_MATH_DEFINES
@@ -30,10 +29,16 @@ limitations under the License. */
* centerX, centerY: translation.
* sourceX, sourceY: output coordinates in the original image.
*/
-__device__ void getTranformCoord(int x, int y, real theta, real scale,
- real tgtCenter, real imgCenter,
- real centerR, real centerC,
- int* sourceX, int* sourceY) {
+__device__ void getTranformCoord(int x,
+ int y,
+ real theta,
+ real scale,
+ real tgtCenter,
+ real imgCenter,
+ real centerR,
+ real centerC,
+ int* sourceX,
+ int* sourceY) {
real H[4] = {cosf(-theta), -sinf(-theta), sinf(-theta), cosf(-theta)};
// compute coordinates in the rotated and scaled image
@@ -57,11 +62,17 @@ __device__ void getTranformCoord(int x, int y, real theta, real scale,
* created by Wei Xu (genome), converted by Jiang Wang
*/
-__global__ void kSamplingPatches(const real* imgs, real* targets,
- int imgSize, int tgtSize, const int channels,
- int samplingRate, const real* thetas,
- const real* scales, const int* centerRs,
- const int* centerCs, const real padValue,
+__global__ void kSamplingPatches(const real* imgs,
+ real* targets,
+ int imgSize,
+ int tgtSize,
+ const int channels,
+ int samplingRate,
+ const real* thetas,
+ const real* scales,
+ const int* centerRs,
+ const int* centerCs,
+ const real padValue,
const int numImages) {
const int caseIdx = blockIdx.x * 4 + threadIdx.x;
const int pxIdx = blockIdx.y * 128 + threadIdx.y;
@@ -80,8 +91,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
const int pxY = pxIdx / tgtSize;
int srcPxX, srcPxY;
- getTranformCoord(pxX, pxY, thetas[imgIdx], scales[imgIdx], tgtCenter,
- imgCenter, centerCs[caseIdx], centerRs[caseIdx], &srcPxX,
+ getTranformCoord(pxX,
+ pxY,
+ thetas[imgIdx],
+ scales[imgIdx],
+ tgtCenter,
+ imgCenter,
+ centerCs[caseIdx],
+ centerRs[caseIdx],
+ &srcPxX,
&srcPxY);
imgs += (imgIdx * imgPixels + srcPxY * imgSize + srcPxX) * channels;
@@ -100,10 +118,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
*
* created by Wei Xu
*/
-void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
- int*& gpuCenterR, int*& gpuCenterC,
- int numImages, int imgSize, real rotateAngle,
- real scaleRatio, int samplingRate,
+void hl_generate_disturb_params(real*& gpuAngle,
+ real*& gpuScaleRatio,
+ int*& gpuCenterR,
+ int*& gpuCenterC,
+ int numImages,
+ int imgSize,
+ real rotateAngle,
+ real scaleRatio,
+ int samplingRate,
bool isTrain) {
// The number of output samples.
int numPatches = numImages * samplingRate;
@@ -123,7 +146,8 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
for (int i = 0; i < numImages; i++) {
r_angle[i] =
(rotateAngle * M_PI / 180.0) * (rand() / (RAND_MAX + 1.0) // NOLINT
- - 0.5);
+ -
+ 0.5);
s_ratio[i] =
1 + (rand() / (RAND_MAX + 1.0) - 0.5) * scaleRatio; // NOLINT
}
@@ -140,8 +164,10 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
int pxY =
(int)(real(imgSize - 1) * rand() / (RAND_MAX + 1.0)); // NOLINT
- const real H[4] = {cos(-r_angle[i]), -sin(-r_angle[i]),
- sin(-r_angle[i]), cos(-r_angle[i])};
+ const real H[4] = {cos(-r_angle[i]),
+ -sin(-r_angle[i]),
+ sin(-r_angle[i]),
+ cos(-r_angle[i])};
real x = pxX - imgCenter;
real y = pxY - imgCenter;
real xx = H[0] * x + H[1] * y;
@@ -185,9 +211,12 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
delete[] center_c;
}
-void hl_conv_random_disturb_with_params(const real* images, int imgSize,
- int tgtSize, int channels,
- int numImages, int samplingRate,
+void hl_conv_random_disturb_with_params(const real* images,
+ int imgSize,
+ int tgtSize,
+ int channels,
+ int numImages,
+ int samplingRate,
const real* gpuRotationAngle,
const real* gpuScaleRatio,
const int* gpuCenterR,
@@ -202,29 +231,59 @@ void hl_conv_random_disturb_with_params(const real* images, int imgSize,
dim3 threadsPerBlock(4, 128);
dim3 numBlocks(DIVUP(numPatches, 4), DIVUP(targetSize, 128));
- kSamplingPatches <<<numBlocks, threadsPerBlock>>>
- (images, target, imgSize, tgtSize, channels, samplingRate,
- gpuRotationAngle, gpuScaleRatio, gpuCenterR, gpuCenterC,
- paddingValue, numImages);
+ kSamplingPatches<<<numBlocks, threadsPerBlock>>>(images,
+ target,
+ imgSize,
+ tgtSize,
+ channels,
+ samplingRate,
+ gpuRotationAngle,
+ gpuScaleRatio,
+ gpuCenterR,
+ gpuCenterC,
+ paddingValue,
+ numImages);
hl_device_synchronize();
}
-void hl_conv_random_disturb(const real* images, int imgSize,
- int tgtSize, int channels, int numImages,
- real scaleRatio, real rotateAngle,
- int samplingRate, real* gpu_r_angle,
- real* gpu_s_ratio, int* gpu_center_r,
- int* gpu_center_c, int paddingValue,
- bool isTrain, real* targets) {
+void hl_conv_random_disturb(const real* images,
+ int imgSize,
+ int tgtSize,
+ int channels,
+ int numImages,
+ real scaleRatio,
+ real rotateAngle,
+ int samplingRate,
+ real* gpu_r_angle,
+ real* gpu_s_ratio,
+ int* gpu_center_r,
+ int* gpu_center_c,
+ int paddingValue,
+ bool isTrain,
+ real* targets) {
// generate the random disturbance sequence and the sampling locations
- hl_generate_disturb_params(gpu_r_angle, gpu_s_ratio, gpu_center_r,
- gpu_center_c, numImages, imgSize, rotateAngle,
- scaleRatio, samplingRate, isTrain);
-
- hl_conv_random_disturb_with_params(
- images, imgSize, tgtSize, channels, numImages,
- samplingRate, gpu_r_angle, gpu_s_ratio,
- gpu_center_r, gpu_center_r, paddingValue,
- targets);
+ hl_generate_disturb_params(gpu_r_angle,
+ gpu_s_ratio,
+ gpu_center_r,
+ gpu_center_c,
+ numImages,
+ imgSize,
+ rotateAngle,
+ scaleRatio,
+ samplingRate,
+ isTrain);
+
+ hl_conv_random_disturb_with_params(images,
+ imgSize,
+ tgtSize,
+ channels,
+ numImages,
+ samplingRate,
+ gpu_r_angle,
+ gpu_s_ratio,
+ gpu_center_r,
+ gpu_center_r,
+ paddingValue,
+ targets);
}
diff --git a/paddle/cuda/src/hl_table_apply.cu b/paddle/cuda/src/hl_table_apply.cu
index 61edbe3ccc7028fd8779c4119f33c4cb5afe0564..d3b71c75e6e69d48c8d98041e3d6075aa8d53610 100644
--- a/paddle/cuda/src/hl_table_apply.cu
+++ b/paddle/cuda/src/hl_table_apply.cu
@@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
#include "hl_base.h"
-#include "hl_device_functions.cuh"
#include "hl_cuda.h"
+#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"
-template <int blockDimX, int blockDimY, int gridDimX, bool AddRow>
-__global__ void KeMatrixAddRows(real* output, int ldo,
- real* table, int ldt,
+template <int blockDimX, int blockDimY, int gridDimX, bool AddRow>
+__global__ void KeMatrixAddRows(real* output,
+ int ldo,
+ real* table,
+ int ldt,
int* ids,
int numSamples,
int tableSize,
@@ -31,8 +32,8 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
while (idy < numSamples) {
int tableId = ids[idy];
if ((0 <= tableId) && (tableId < tableSize)) {
- real *out = output + idy * ldo;
- real *tab = table + tableId * ldt;
+ real* out = output + idy * ldo;
+ real* tab = table + tableId * ldt;
for (int i = idx; i < dim; i += blockDimX) {
if (AddRow) {
paddle::paddleAtomicAdd(&tab[i], out[i]);
@@ -45,8 +46,10 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
}
}
-void hl_matrix_select_rows(real* output, int ldo,
- real* table, int ldt,
+void hl_matrix_select_rows(real* output,
+ int ldo,
+ real* table,
+ int ldt,
int* ids,
int numSamples,
int tableSize,
@@ -57,14 +60,16 @@ void hl_matrix_select_rows(real* output, int ldo,
dim3 threads(128, 8);
dim3 grid(8, 1);
- KeMatrixAddRows<128, 8, 8, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
- (output, ldo, table, ldt, ids, numSamples, tableSize, dim);
+ KeMatrixAddRows<128, 8, 8, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ output, ldo, table, ldt, ids, numSamples, tableSize, dim);
CHECK_SYNC("hl_matrix_select_rows failed");
}
-void hl_matrix_add_to_rows(real* table, int ldt,
- real* input, int ldi,
+void hl_matrix_add_to_rows(real* table,
+ int ldt,
+ real* input,
+ int ldi,
int* ids,
int numSamples,
int tableSize,
@@ -75,16 +80,15 @@ void hl_matrix_add_to_rows(real* table, int ldt,
dim3 threads(128, 8);
dim3 grid(8, 1);
- KeMatrixAddRows<128, 8, 8, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
- (input, ldi, table, ldt, ids, numSamples, tableSize, dim);
+ KeMatrixAddRows<128, 8, 8, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ input, ldi, table, ldt, ids, numSamples, tableSize, dim);
CHECK_SYNC("hl_matrix_add_to_rows failed");
}
-template <class T, int blockDimX, int gridDimX>
-__global__ void KeVectorSelect(T* dst, int sized,
- const T* src, int sizes,
- const int* ids, int sizei) {
+template <class T, int blockDimX, int gridDimX>
+__global__ void KeVectorSelect(
+ T* dst, int sized, const T* src, int sizes, const int* ids, int sizei) {
int idx = threadIdx.x + blockDimX * blockIdx.x;
while (idx < sizei) {
int index = ids[idx];
@@ -95,9 +99,8 @@ __global__ void KeVectorSelect(T* dst, int sized,
}
template <class T>
-void hl_vector_select_from(T* dst, int sized,
- const T* src, int sizes,
- const int* ids, int sizei) {
+void hl_vector_select_from(
+ T* dst, int sized, const T* src, int sizes, const int* ids, int sizei) {
CHECK_NOTNULL(dst);
CHECK_NOTNULL(src);
CHECK_NOTNULL(ids);
@@ -105,18 +108,17 @@ void hl_vector_select_from(T* dst, int sized,
dim3 threads(512, 1);
dim3 grid(8, 1);
- KeVectorSelect<<< grid, threads, 0, STREAM_DEFAULT >>>
- (dst, sized, src, sizes, ids, sizei);
+ KeVectorSelect<<<grid, threads, 0, STREAM_DEFAULT>>>(
+ dst, sized, src, sizes, ids, sizei);
CHECK_SYNC("hl_vector_select_from failed");
}
-template
-void hl_vector_select_from(real* dst, int sized,
- const real* src, int sizes,
- const int* ids, int sizei);
-template
-void hl_vector_select_from(int* dst, int sized,
- const int* src, int sizes,
- const int* ids, int sizei);
-
+template void hl_vector_select_from(real* dst,
+ int sized,
+ const real* src,
+ int sizes,
+ const int* ids,
+ int sizei);
+template void hl_vector_select_from(
+ int* dst, int sized, const int* src, int sizes, const int* ids, int sizei);
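
Note on hl_table_apply.cu: KeMatrixAddRows switches on its AddRow flag between gathering table rows into the output (hl_matrix_select_rows) and scattering rows of the input back into the table with atomic adds (hl_matrix_add_to_rows). A serial host-side C++ sketch of the two semantics, with hypothetical names and plain loops standing in for the CUDA grid (not part of this patch):

#include <vector>

// Gather: for each sample r, add table row ids[r] into output row r.
void select_rows(std::vector<float>& output, const std::vector<float>& table,
                 const std::vector<int>& ids, int dim) {
  for (size_t r = 0; r < ids.size(); ++r)
    for (int i = 0; i < dim; ++i)
      output[r * dim + i] += table[ids[r] * dim + i];
}

// Scatter: for each sample r, accumulate input row r into table row ids[r].
// On the GPU this is the AddRow path and uses atomic adds, because several
// samples may reference the same table row.
void add_to_rows(std::vector<float>& table, const std::vector<float>& input,
                 const std::vector<int>& ids, int dim) {
  for (size_t r = 0; r < ids.size(); ++r)
    for (int i = 0; i < dim; ++i)
      table[ids[r] * dim + i] += input[r * dim + i];
}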
diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu
index 4f0bbfcf4e3aa51dd06acf254af65c62098a1df7..1896a56634c3a75e5a2a1e08661088b263f8ee10 100644
--- a/paddle/cuda/src/hl_top_k.cu
+++ b/paddle/cuda/src/hl_top_k.cu
@@ -12,45 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-
#include "hl_base.h"
-#include "hl_top_k.h"
#include "hl_sparse.ph"
+#include "hl_top_k.h"
#include "paddle/utils/Logging.h"
// using namespace hppl;
struct Pair {
- __device__ __forceinline__
- Pair() {}
+ __device__ __forceinline__ Pair() {}
- __device__ __forceinline__
- Pair(real value, int id) : v_(value), id_(id) {}
+ __device__ __forceinline__ Pair(real value, int id) : v_(value), id_(id) {}
- __device__ __forceinline__
- void set(real value, int id) {
+ __device__ __forceinline__ void set(real value, int id) {
v_ = value;
id_ = id;
}
- __device__ __forceinline__
- void operator=(const Pair& in) {
+ __device__ __forceinline__ void operator=(const Pair& in) {
v_ = in.v_;
id_ = in.id_;
}
- __device__ __forceinline__
- bool operator<(const real value) const {
+ __device__ __forceinline__ bool operator<(const real value) const {
return (v_ < value);
}
- __device__ __forceinline__
- bool operator<(const Pair& in) const {
+ __device__ __forceinline__ bool operator<(const Pair& in) const {
return (v_ < in.v_) || ((v_ == in.v_) && (id_ > in.id_));
}
- __device__ __forceinline__
- bool operator>(const Pair& in) const {
+ __device__ __forceinline__ bool operator>(const Pair& in) const {
return (v_ > in.v_) || ((v_ == in.v_) && (id_ < in.id_));
}
@@ -58,8 +50,9 @@ struct Pair {
int id_;
};
-__device__ __forceinline__
-void addTo(Pair topK[], const Pair &p, int beamSize) {
+__device__ __forceinline__ void addTo(Pair topK[],
+ const Pair& p,
+ int beamSize) {
for (int k = beamSize - 2; k >= 0; k--) {
if (topK[k] < p) {
topK[k + 1] = topK[k];
@@ -71,9 +64,8 @@ void addTo(Pair topK[], const Pair &p, int beamSize) {
topK[0] = p;
}
-template <int beamSize>
-__device__ __forceinline__
-void addTo(Pair topK[], const Pair &p) {
+template <int beamSize>
+__device__ __forceinline__ void addTo(Pair topK[], const Pair& p) {
for (int k = beamSize - 2; k >= 0; k--) {
if (topK[k] < p) {
topK[k + 1] = topK[k];
@@ -85,9 +77,9 @@ void addTo(Pair topK[], const Pair &p) {
topK[0] = p;
}
-template <int blockSize>
-__device__ __forceinline__
-void getTopK(Pair topK[], real *src, int idx, int dim, int beamSize) {
+template <int blockSize>
+__device__ __forceinline__ void getTopK(
+ Pair topK[], real* src, int idx, int dim, int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < src[idx]) {
Pair tmp(src[idx], idx);
@@ -97,10 +89,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim, int beamSize) {
}
}
-template <int blockSize>
-__device__ __forceinline__
-void getTopK(Pair topK[], real *src, int idx, int dim,
- const Pair& max, int beamSize) {
+template <int blockSize>
+__device__ __forceinline__ void getTopK(
+ Pair topK[], real* src, int idx, int dim, const Pair& max, int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < src[idx]) {
Pair tmp(src[idx], idx);
@@ -112,10 +103,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim,
}
}
-template <int blockSize>
-__device__ __forceinline__
-void getTopK(Pair topK[], real *val, int *col,
- int idx, int dim, int beamSize) {
+template <int blockSize>
+__device__ __forceinline__ void getTopK(
+ Pair topK[], real* val, int* col, int idx, int dim, int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < val[idx]) {
Pair tmp(val[idx], col[idx]);
@@ -125,10 +115,14 @@ void getTopK(Pair topK[], real *val, int *col,
}
}
-template <int blockSize>
-__device__ __forceinline__
-void getTopK(Pair topK[], real *val, int *col, int idx, int dim,
- const Pair& max, int beamSize) {
+template <int blockSize>
+__device__ __forceinline__ void getTopK(Pair topK[],
+ real* val,
+ int* col,
+ int idx,
+ int dim,
+ const Pair& max,
+ int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < val[idx]) {
Pair tmp(val[idx], col[idx]);
@@ -140,12 +134,16 @@ void getTopK(Pair topK[], real *val, int *col, int idx, int dim,
}
}
-template <int maxLength, int blockSize>
-__device__ __forceinline__
-void threadGetTopK(Pair topK[], int& beam, int beamSize,
- real* src,
- bool& firstStep, bool& isEmpty, Pair& max,
- int dim, const int tid) {
+template <int maxLength, int blockSize>
+__device__ __forceinline__ void threadGetTopK(Pair topK[],
+ int& beam,
+ int beamSize,
+ real* src,
+ bool& firstStep,
+ bool& isEmpty,
+ Pair& max,
+ int dim,
+ const int tid) {
if (beam > 0) {
int length = beam < beamSize ? beam : beamSize;
if (firstStep) {
@@ -160,8 +158,7 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
if (!isEmpty) {
- getTopK<blockSize>(topK + maxLength - beam, src, tid, dim,
- max, length);
+ getTopK<blockSize>(topK + maxLength - beam, src, tid, dim, max, length);
}
}
@@ -171,12 +168,17 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
-template <int maxLength, int blockSize>
-__device__ __forceinline__
-void threadGetTopK(Pair topK[], int& beam, int beamSize,
- real* val, int* col,
- bool& firstStep, bool& isEmpty, Pair& max,
- int dim, const int tid) {
+template <int maxLength, int blockSize>
+__device__ __forceinline__ void threadGetTopK(Pair topK[],
+ int& beam,
+ int beamSize,
+ real* val,
+ int* col,
+ bool& firstStep,
+ bool& isEmpty,
+ Pair& max,
+ int dim,
+ const int tid) {
if (beam > 0) {
int length = beam < beamSize ? beam : beamSize;
if (firstStep) {
@@ -191,8 +193,8 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
if (!isEmpty) {
- getTopK<blockSize>(topK + maxLength - beam, val, col, tid, dim,
- max, length);
+ getTopK<blockSize>(
+ topK + maxLength - beam, val, col, tid, dim, max, length);
}
}
@@ -202,12 +204,16 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
-template <int maxLength, int blockSize>
-__device__ __forceinline__
-void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
- real** topVal, int** topIds,
- int& beam, int& beamSize,
- const int tid, const int warp) {
+template <int maxLength, int blockSize>
+__device__ __forceinline__ void blockReduce(Pair* shTopK,
+ int* maxId,
+ Pair topK[],
+ real** topVal,
+ int** topIds,
+ int& beam,
+ int& beamSize,
+ const int tid,
+ const int warp) {
while (true) {
__syncthreads();
if (tid < blockSize / 2) {
@@ -218,7 +224,7 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
}
}
__syncthreads();
- for (int stride = blockSize / 4; stride > 0; stride = stride/2) {
+ for (int stride = blockSize / 4; stride > 0; stride = stride / 2) {
if (tid < stride) {
if (shTopK[maxId[tid]] < shTopK[maxId[tid + stride]]) {
maxId[tid] = maxId[tid + stride];
@@ -257,10 +263,12 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
* 3. go to the second step, until one thread's topK value is null;
* 4. go to the first step, until the topK values are obtained.
*/
-template <int maxLength, int blockSize>
-__global__ void KeMatrixTopK(real* topVal, int ldv,
- int * topIds,
- real* src, int lds,
+template <int maxLength, int blockSize>
+__global__ void KeMatrixTopK(real* topVal,
+ int ldv,
+ int* topIds,
+ real* src,
+ int lds,
int dim,
int beamSize) {
__shared__ Pair shTopK[blockSize];
@@ -271,7 +279,7 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
- Pair topK[maxLength]; // NOLINT
+ Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
@@ -281,18 +289,19 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
topK[k].set(-HL_FLOAT_MAX, -1);
}
while (beamSize) {
- threadGetTopK<maxLength, blockSize>
- (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
+ threadGetTopK<maxLength, blockSize>(
+ topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
- blockReduce<maxLength, blockSize>
- (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
+ blockReduce<maxLength, blockSize>(
+ shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
}
-template <int maxLength, int blockSize>
-__global__ void KeSMatrixTopK(real* topVal, int ldv,
- int * topIds,
+template <int maxLength, int blockSize>
+__global__ void KeSMatrixTopK(real* topVal,
+ int ldv,
+ int* topIds,
real* val,
int* row,
int* col,
@@ -304,7 +313,7 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
- Pair topK[maxLength]; // NOLINT
+ Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
@@ -330,18 +339,20 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
topK[k].set(-HL_FLOAT_MAX, -1);
}
while (beamSize) {
- threadGetTopK<maxLength, blockSize>
- (topK, beam, beamSize, val, col, firstStep, isEmpty, max, dim, tid);
+ threadGetTopK<maxLength, blockSize>(
+ topK, beam, beamSize, val, col, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
- blockReduce<maxLength, blockSize>
- (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
+ blockReduce<maxLength, blockSize>(
+ shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
}
-void hl_matrix_top_k(real* topVal, int ldv,
- int * topIds,
- real* src, int lds,
+void hl_matrix_top_k(real* topVal,
+ int ldv,
+ int* topIds,
+ real* src,
+ int lds,
int dim,
int beamSize,
int numSamples) {
@@ -353,33 +364,32 @@ void hl_matrix_top_k(real* topVal, int ldv,
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
- KeMatrixTopK<5, 256><<< grid, threads, 0, STREAM_DEFAULT >>>
- (topVal, ldv, topIds, src, lds, dim, beamSize);
+ KeMatrixTopK<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ topVal, ldv, topIds, src, lds, dim, beamSize);
CHECK_SYNC("hl_matrix_top_k failed");
}
-void hl_sparse_matrix_top_k(real* topVal, int ldv,
- int * topIds,
+void hl_sparse_matrix_top_k(real* topVal,
+ int ldv,
+ int* topIds,
hl_sparse_matrix_s src,
int beamSize,
int numSamples) {
CHECK_NOTNULL(topVal);
CHECK_NOTNULL(topIds);
CHECK_NOTNULL(src);
- CHECK_EQ(src->format, HL_SPARSE_CSR)
- <<"sparse matrix format error!";
+ CHECK_EQ(src->format, HL_SPARSE_CSR) << "sparse matrix format error!";
hl_csr_matrix csr = (hl_csr_matrix)src->matrix;
- if (csr->csr_val == NULL || csr->csr_row == NULL ||
- csr->csr_col == NULL) {
+ if (csr->csr_val == NULL || csr->csr_row == NULL || csr->csr_col == NULL) {
LOG(FATAL) << "parameter src is null!";
}
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
- KeSMatrixTopK<5, 256><<< grid, threads, 0, STREAM_DEFAULT >>>
- (topVal, ldv, topIds, csr->csr_val, csr->csr_row, csr->csr_col, beamSize);
+ KeSMatrixTopK<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ topVal, ldv, topIds, csr->csr_val, csr->csr_row, csr->csr_col, beamSize);
CHECK_SYNC("hl_sparse_matrix_top_k failed");
}
@@ -392,10 +402,12 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
* 3. go to the second step, until one thread's topK value is null;
* 4. go to the first step, until the topK values are obtained.
*/
-template <int maxLength, int blockSize>
-__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
- int * topIds,
- real* src, int lds,
+template <int maxLength, int blockSize>
+__global__ void KeMatrixTopKClassificationError(real* topVal,
+ int ldv,
+ int* topIds,
+ real* src,
+ int lds,
int dim,
int beamSize,
int* label,
@@ -408,7 +420,7 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
- Pair topK[maxLength]; // NOLINT
+ Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
@@ -420,34 +432,36 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
}
while (beamSize) {
- threadGetTopK<maxLength, blockSize>
- (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
+ threadGetTopK<maxLength, blockSize>(
+ topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
- blockReduce<maxLength, blockSize>
- (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
+ blockReduce<maxLength, blockSize>(
+ shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
__syncthreads();
if (tid == 0) {
for (int i = 0; i < topkSize; i++) {
- if (*--topIds == label[blockIdx.x]) {
- recResult[blockIdx.x] = 0;
- break;
- }
- recResult[blockIdx.x] = 1.0f;
+ if (*--topIds == label[blockIdx.x]) {
+ recResult[blockIdx.x] = 0;
+ break;
+ }
+ recResult[blockIdx.x] = 1.0f;
}
}
}
-void hl_matrix_classification_error(real* topVal, int ldv,
- int* topIds,
- real* src, int lds,
- int dim,
- int topkSize,
- int numSamples,
- int* label,
- real* recResult) {
+void hl_matrix_classification_error(real* topVal,
+ int ldv,
+ int* topIds,
+ real* src,
+ int lds,
+ int dim,
+ int topkSize,
+ int numSamples,
+ int* label,
+ real* recResult) {
CHECK_NOTNULL(topVal);
CHECK_NOTNULL(topIds);
CHECK_NOTNULL(src);
@@ -456,9 +470,8 @@ void hl_matrix_classification_error(real* topVal, int ldv,
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
- KeMatrixTopKClassificationError<5, 256>
- <<< grid, threads, 0, STREAM_DEFAULT >>>
- (topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
+ KeMatrixTopKClassificationError<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
+ topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
CHECK_SYNC("hl_matrix_top_k classification error failed");
}
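
Note on hl_top_k.cu: the comment above KeMatrixTopK describes the scheme -- every thread keeps a small sorted beam of candidates, and the block repeatedly extracts the largest remaining head element across all beams (the job of blockReduce) until beamSize results are written. A sequential C++ sketch of that merge step (illustrative only; names are not from the patch):

#include <algorithm>
#include <cstdio>
#include <vector>

// Merge per-thread candidate lists into one global top-k by repeatedly
// taking the largest head element, the role blockReduce plays over shTopK[].
std::vector<float> merge_top_k(std::vector<std::vector<float>> lists, int k) {
  for (auto& l : lists) std::sort(l.rbegin(), l.rend());  // descending, like topK[]
  std::vector<size_t> head(lists.size(), 0);
  std::vector<float> result;
  while (static_cast<int>(result.size()) < k) {
    int best = -1;
    for (size_t i = 0; i < lists.size(); ++i) {
      if (head[i] == lists[i].size()) continue;  // this "thread" is exhausted
      if (best < 0 || lists[i][head[i]] > lists[best][head[best]])
        best = static_cast<int>(i);
    }
    if (best < 0) break;  // fewer than k candidates in total
    result.push_back(lists[best][head[best]++]);
  }
  return result;
}

int main() {
  std::vector<std::vector<float>> beams = {{0.9f, 0.1f}, {0.8f, 0.7f}, {0.5f}};
  for (float v : merge_top_k(beams, 3)) std::printf("%.1f ", v);  // 0.9 0.8 0.7
  std::printf("\n");
  return 0;
}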
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index e69c2ada5fd43518d15f3bca16f09b78c1a5b22d..d8012fba27bfca05e062e22d38d672bd395df7a6 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -12,13 +12,15 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
-proto_library(attr_type SRCS attr_type.proto)
-proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
-proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
+proto_library(attribute_proto SRCS attribute.proto)
+proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto)
+proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
-cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope)
+cc_library(attribute SRCS attribute.cc DEPS op_desc op_proto)
+
+cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
@@ -26,7 +28,7 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
-py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
+py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4c5790693b7e48396e945d09f4fdc72b86aa5978
--- /dev/null
+++ b/paddle/framework/attribute.cc
@@ -0,0 +1,85 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/attribute.h"
+
+#include <vector>
+
+namespace paddle {
+namespace framework {
+
+template <>
+AttrType AttrTypeID<int>() {
+ return INT;
+}
+template <>
+AttrType AttrTypeID<float>() {
+ return FLOAT;
+}
+template <>
+AttrType AttrTypeID<std::string>() {
+ return STRING;
+}
+template <>
+AttrType AttrTypeID<std::vector<int>>() {
+ return INTS;
+}
+template <>
+AttrType AttrTypeID