提交 58561d8f 编写于 作者: D dongzhihong

Merge remote-tracking branch 'origin/develop' into random_op

...@@ -36,8 +36,8 @@ include(simd) ...@@ -36,8 +36,8 @@ include(simd)
################################ Configurations ####################################### ################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
......
...@@ -74,8 +74,6 @@ if(WITH_MKLDNN) ...@@ -74,8 +74,6 @@ if(WITH_MKLDNN)
set(OPENMP_FLAGS "-fopenmp") set(OPENMP_FLAGS "-fopenmp")
set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS}) set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS}) set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
else() else()
......
...@@ -42,29 +42,21 @@ macro(add_style_check_target TARGET_NAME) ...@@ -42,29 +42,21 @@ macro(add_style_check_target TARGET_NAME)
if(WITH_STYLE_CHECK) if(WITH_STYLE_CHECK)
set(SOURCES_LIST ${ARGN}) set(SOURCES_LIST ${ARGN})
list(REMOVE_DUPLICATES SOURCES_LIST) list(REMOVE_DUPLICATES SOURCES_LIST)
list(SORT SOURCES_LIST)
foreach(filename ${SOURCES_LIST}) foreach(filename ${SOURCES_LIST})
set(LINT ON)
foreach(pattern ${IGNORE_PATTERN}) foreach(pattern ${IGNORE_PATTERN})
if(filename MATCHES ${pattern}) if(filename MATCHES ${pattern})
message(STATUS "DROP LINT ${filename}") list(REMOVE_ITEM SOURCES_LIST ${filename})
set(LINT OFF)
endif() endif()
endforeach() endforeach()
if(LINT MATCHES ON)
# cpplint code style
get_filename_component(base_filename ${filename} NAME)
set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint)
add_custom_command(OUTPUT ${CUR_GEN} PRE_BUILD
COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
"--filter=${STYLE_FILTER}"
"--write-success=${CUR_GEN}" ${filename}
DEPENDS ${filename} ${PROJ_ROOT}/paddle/scripts/cpplint.py
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${base_filename}.cpplint DEPENDS ${CUR_GEN})
add_dependencies(${TARGET_NAME} ${base_filename}.cpplint)
endif()
endforeach() endforeach()
if(SOURCES_LIST)
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
"--filter=${STYLE_FILTER}"
${SOURCES_LIST}
COMMENT "cpplint: Checking source code style"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endif() endif()
endmacro() endmacro()
...@@ -7,7 +7,7 @@ INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any) ...@@ -7,7 +7,7 @@ INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add( ExternalProject_Add(
extern_lib_any extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/thelink2012/any.git" GIT_REPOSITORY "https://github.com/PaddlePaddle/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020" GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020"
PREFIX ${ANY_SOURCE_DIR} PREFIX ${ANY_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
......
...@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) ...@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
ExternalProject_Add( ExternalProject_Add(
extern_gflags extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/gflags/gflags.git" # TODO(yiwang): The annoying warnings mentioned in
# https://github.com/PaddlePaddle/Paddle/issues/3277 are caused by
# gflags. I fired a PR https://github.com/gflags/gflags/pull/230
# to fix it. Before it gets accepted by the gflags team, we use
# my personal fork, which contains above fix, temporarily. Let's
# change this back to the official Github repo once my PR is
# merged.
GIT_REPOSITORY "https://github.com/wangkuiyi/gflags.git"
PREFIX ${GFLAGS_SOURCES_DIR} PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -69,8 +69,13 @@ ENDIF(NOT ${CBLAS_FOUND}) ...@@ -69,8 +69,13 @@ ENDIF(NOT ${CBLAS_FOUND})
MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}") MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}")
INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
ADD_LIBRARY(cblas STATIC IMPORTED) # FIXME(gangliao): generate cblas target to track all high performance
SET_PROPERTY(TARGET cblas PROPERTY IMPORTED_LOCATION ${CBLAS_LIBRARIES}) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
ADD_LIBRARY(cblas STATIC ${dummyfile})
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
IF(NOT ${CBLAS_FOUND}) IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas) ADD_DEPENDENCIES(cblas extern_openblas)
LIST(APPEND external_project_dependencies cblas) LIST(APPEND external_project_dependencies cblas)
......
...@@ -115,7 +115,7 @@ set(COMMON_FLAGS ...@@ -115,7 +115,7 @@ set(COMMON_FLAGS
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=sign-compare -Wno-error=sign-compare
-Wno-error=unused-local-typedefs -Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in Pybind11 -Wno-error=parentheses-equality # Warnings in pybind11
) )
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
...@@ -195,6 +195,7 @@ endif() ...@@ -195,6 +195,7 @@ endif()
# Modern gpu architectures: Pascal # Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0") if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
list(APPEND CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
endif() endif()
# Custom gpu architecture # Custom gpu architecture
......
...@@ -403,3 +403,16 @@ function(py_proto_compile TARGET_NAME) ...@@ -403,3 +403,16 @@ function(py_proto_compile TARGET_NAME)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS}) protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs}) add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction() endfunction()
function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_PACKAGE_DIR}
python2 ${py_test_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
# Intel® MKL-DNN on PaddlePaddle: Design Doc
我们计划将Intel深度神经网络数学库(**MKL-DNN**\[[1](#references)\])集成到PaddlePaddle,充分展现英特尔平台的优势,有效提升PaddlePaddle在英特尔架构上的性能。
我们短期内的基本目标是:
- 完成常用layer的MKL-DNN实现。
- 完成常见深度神经网络VGG,GoogLeNet 和 ResNet的MKL-DNN实现。
## Contents
- [Overview](#overview)
- [Actions](#actions)
- [CMake](#cmake)
- [Layers](#layers)
- [Activations](#activations)
- [Unit Tests](#unit-tests)
- [Protobuf Messages](#protobuf-messages)
- [Python API](#python-api)
- [Demos](#demos)
- [Benchmarking](#benchmarking)
- [Others](#others)
- [Design Concerns](#design-concerns)
## Overview
我们会把MKL-DNN作为第三方库集成进PaddlePaddle,整体框架图
<div align="center">
<img src="image/overview.png" width=350><br/>
Figure 1. PaddlePaddle on IA.
</div>
## Actions
我们把集成方案大致分为了如下几个方面。
### CMake
我们会在`CMakeLists.txt`中会添加`WITH_MKLDNN`的选项,当设置这个值为`ON`的时候会启用编译MKL-DNN功能。同时会自动开启OpenMP用于提高MKL-DNN的性能。
同时,我们会引入`WITH_MKLML`选项,用于选择是否使用MKL-DNN自带的MKLML安装包。这个安装包可以独立于MKL-DNN使用,但是建议在开启MKL-DNN的同时也打开MKLML的开关,这样才能发挥最好的性能。
所以,我们会在`cmake/external`目录新建`mkldnn.cmake``mklml.cmake`文件,它们会在编译PaddlePaddle的时候下载对应的软件包,并放到PaddlePaddle的third party目录中。
**备注**:当`WITH_MKLML=ON`的时候,会优先使用这个包作为PaddlePaddle的CBLAS和LAPACK库,所以会稍微改动`cmake/cblas.cmake`中的逻辑。
### Layers
所有MKL-DNN相关的C++ layers,都会按照PaddlePaddle的目录结构存放在
`paddle/gserver/layers`中,并且文件名都会一以*Mkldnn*开头。
所有MKL-DNN的layers都会继承于一个叫做`MkldnnLayer`的父类,该父类继承于PaddlePaddle的基类`Layer`
### Activations
由于在PaddlePaddle中,激活函数是独立于layer概念的,所以会在`paddle/gserver/activations`目录下添加一个`MkldnnActivation.h`文件定义一些用于MKL-DNN的接口,实现方法还是会在`ActivationFunction.cpp`文件。
### Unit Tests
会在`paddle/gserver/test`目录下添加`test_Mkldnn.cpp``MkldnnTester.*`用于MKL-DNN的测试。
Activation的测试,计划在PaddlePaddle原有的测试文件上直接添加新的测试type。
### Protobuf Messages
根据具体layer的需求可能会在`proto/ModelConfig.proto`里面添加必要的选项。
### Python API
目前只考虑**v1 API**
计划在`python/paddle/trainer/config_parser.py`里面添加`use_mkldnn`这个选择,方便用户选择使用MKL-DNN的layers。
具体实现方式比如:
```python
use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
if use_mkldnn
self.layer_type = mkldnn_*
```
所有MKL-DNN的layer type会以*mkldnn_*开头,以示区分。
并且可能在`python/paddle/trainer_config_helper`目录下的`activations.py ``layers.py`里面添加必要的MKL-DNN的接口。
### Demos
会在`v1_api_demo`目录下添加一个`mkldnn`的文件夹,里面放入一些用于MKL-DNN测试的demo脚本。
### Benchmarking
会考虑添加部分逻辑在`benchmark/paddle/image/run.sh`,添加使用MKL-DNN的测试。
### Others
1. 如果在使用MKL-DNN的情况下,会把CPU的Buffer对齐为64。
2. 深入PaddlePaddle,寻找有没有其他可以优化的可能,进一步优化。比如可能会用OpenMP改进SGD的更新性能。
## Design Concerns
为了更好的符合PaddlePaddle的代码风格\[[2](#references)\],同时又尽可能少的牺牲MKL-DNN的性能\[[3](#references)\]
我们总结出一些特别需要注意的点:
1. 使用**deviceId_**。为了尽可能少的在父类Layer中添加变量或者函数,我们决定使用已有的`deviceId_`变量来区分layer的属性,定义`-2``MkldnnLayer`特有的设备ID。
2. 重写父类Layer的**init**函数,修改`deviceId_``-2`,代表这个layer是用于跑在MKL-DNN的环境下。
3. 创建`MkldnnMatrix`,用于管理MKL-DNN会用到的相关memory函数、接口以及会用的到格式信息。
4. 创建`MkldnnBase`,定义一些除了layer和memory相关的类和函数。包括MKL-DNN会用到`MkldnnStream``CpuEngine`,和未来可能还会用到`FPGAEngine`等。
5.**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue``mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。
6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。
7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。
## References
1. [Intel Math Kernel Library for Deep Neural Networks (Intel MKL-DNN)](https://github.com/01org/mkl-dnn "Intel MKL-DNN")
2. [原来的方案](https://github.com/PaddlePaddle/Paddle/pull/3096)会引入**nextLayer**的信息。但是在PaddlePaddle中,无论是重构前的layer还是重构后的op,都不会想要知道next layer/op的信息。
3. MKL-DNN的高性能格式与PaddlePaddle原有的`NCHW`不同(PaddlePaddle中的CUDNN部分使用的也是`NCHW`,所以不存在这个问题),所以需要引入一个转换方法,并且只需要在必要的时候转换这种格式,才能更好的发挥MKL-DNN的性能。
add_python_test(test_swig_api py_test(testTrain SRCS testTrain.py)
testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py testTrainer.py) py_test(testMatrix SRCS testMatrix.py)
py_test(testVector SRCS testVector.py)
py_test(testTrainer SRCS testTrainer.py)
py_test(testArguments SRCS testArguments.py)
py_test(testGradientMachine SRCS testGradientMachine.py)
...@@ -22,14 +22,14 @@ namespace framework { ...@@ -22,14 +22,14 @@ namespace framework {
template <> template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>(); return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice& Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>(); return *device_context_->get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif
......
...@@ -174,7 +174,11 @@ class OperatorContext { ...@@ -174,7 +174,11 @@ class OperatorContext {
template <typename T> template <typename T>
T* Output(const size_t index) const { T* Output(const size_t index) const {
auto var = OutputVar(index); auto var = OutputVar(index);
PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); PADDLE_ENFORCE(
var != nullptr,
"Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope",
index, op_.outputs_[index]);
return var->GetMutable<T>(); return var->GetMutable<T>();
} }
...@@ -252,7 +256,7 @@ struct EigenDeviceConverter<platform::GPUPlace> { ...@@ -252,7 +256,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
class ExecutionContext : public OperatorContext { class ExecutionContext : public OperatorContext {
public: public:
ExecutionContext(const OperatorBase* op, const Scope& scope, ExecutionContext(const OperatorBase* op, const Scope& scope,
const platform::DeviceContext& device_context) const platform::DeviceContext* device_context)
: OperatorContext(op, scope), device_context_(device_context) {} : OperatorContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
...@@ -262,7 +266,7 @@ class ExecutionContext : public OperatorContext { ...@@ -262,7 +266,7 @@ class ExecutionContext : public OperatorContext {
platform::Place GetPlace() const { return device_context_->GetPlace(); } platform::Place GetPlace() const { return device_context_->GetPlace(); }
const platform::DeviceContext& device_context_; const platform::DeviceContext* device_context_;
}; };
class OpKernel { class OpKernel {
...@@ -311,7 +315,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -311,7 +315,7 @@ class OperatorWithKernel : public OperatorBase {
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......
...@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase { ...@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
CHECK_EQ(groups_, (size_t)1);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo")); algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
// algorithm_ = nnp_convolution_algorithm_auto;
transform_strategy_ = nnp_convolution_transform_strategy_compute; transform_strategy_ = nnp_convolution_transform_strategy_compute;
nnp_status status = nnp_initialize(); nnp_status status = nnp_initialize();
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
...@@ -67,8 +65,7 @@ public: ...@@ -67,8 +65,7 @@ public:
} }
} }
virtual void check(const BufferArgs& inputs, void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape(); const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape(); const TensorShape& output = outputs[0].shape();
...@@ -91,8 +88,8 @@ public: ...@@ -91,8 +88,8 @@ public:
size_t filterHeight = getFilterHeight(filter); size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter); size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1]; size_t outputChannels = output[1];
// size_t outputHeight = output[2]; size_t outputHeight = output[2];
// size_t outputWidth = output[3]; size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = (size_t)paddingH(), nnp_padding padding = {.top = (size_t)paddingH(),
...@@ -171,49 +168,58 @@ public: ...@@ -171,49 +168,58 @@ public:
} }
} }
size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth;
size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
if (batchSize == 1) { if (batchSize == 1) {
nnp_status status = for (size_t g = 0; g < groups_; g++) {
nnp_convolution_inference(algorithm_, nnp_status status =
transform_strategy_, nnp_convolution_inference(algorithm_,
inputChannels, transform_strategy_,
outputChannels, inputChannels / groups_,
inputSize, outputChannels / groups_,
padding, inputSize,
kernelSize, padding,
outputSubsampling, kernelSize,
inputData, outputSubsampling,
filterData, inputData + inputOffset * g,
nullptr, /* bias */ filterData + filterOffset * g,
outputData, nullptr, /* bias */
bufferPtr, outputData + outputOffset * g,
sizePtr, bufferPtr,
nnp_activation_identity, sizePtr,
nullptr, nnp_activation_identity,
threadpool_, /* threadpool */ nullptr,
nullptr); threadpool_, /* threadpool */
CHECK_EQ(status, nnp_status_success); nullptr);
CHECK_EQ(status, nnp_status_success);
}
} else { } else {
// only supports stride = 1 for (size_t g = 0; g < groups_; g++) {
CHECK_EQ(strideH(), 1); // only supports stride = 1
CHECK_EQ(strideW(), 1); CHECK_EQ(strideH(), 1);
nnp_status status = nnp_convolution_output(algorithm_, CHECK_EQ(strideW(), 1);
batchSize, nnp_status status =
inputChannels, nnp_convolution_output(algorithm_,
outputChannels, batchSize,
inputSize, inputChannels / groups_,
padding, outputChannels / groups_,
kernelSize, inputSize,
inputData, padding,
filterData, kernelSize,
nullptr, /* bias */ inputData + inputOffset * g,
outputData, filterData + filterOffset * g,
bufferPtr, nullptr, /* bias */
sizePtr, outputData + outputOffset * g,
nnp_activation_identity, bufferPtr,
nullptr, sizePtr,
threadpool_, /* threadpool */ nnp_activation_identity,
nullptr); nullptr,
CHECK_EQ(status, nnp_status_success); threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
} }
} }
......
...@@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, ...@@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
convGradFilterType = "GemmConvGradFilter"; convGradFilterType = "GemmConvGradFilter";
} }
if (FLAGS_use_nnpack) { if (FLAGS_use_nnpack && !isDeconv_) {
CHECK_EQ(isDeconv_, false);
createFunction(forward_, createFunction(forward_,
"NNPACKConv", "NNPACKConv",
FuncConfig() FuncConfig()
......
# gserver pacakge unittests # gserver pacakge unittests
file(GLOB_RECURSE GSERVER_HEADER RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.h")
file(GLOB_RECURSE GSERVER_SOURCES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cpp")
add_style_check_target(paddle_gserver ${GSERVER_SOURCES})
add_style_check_target(paddle_gserver ${GSERVER_HEADER})
################### test_ProtoDataProvider ############ ################### test_ProtoDataProvider ############
add_unittest_without_exec(test_ProtoDataProvider add_unittest_without_exec(test_ProtoDataProvider
test_ProtoDataProvider.cpp) test_ProtoDataProvider.cpp)
......
...@@ -20,8 +20,8 @@ namespace operators { ...@@ -20,8 +20,8 @@ namespace operators {
class AddOp : public OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1);
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of AddOp must all be set"); "Inputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
......
...@@ -23,12 +23,16 @@ class MulOp : public OperatorWithKernel { ...@@ -23,12 +23,16 @@ class MulOp : public OperatorWithKernel {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
auto dim0 = ctx.Input<Tensor>(0)->dims(); auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims(); auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2, PADDLE_ENFORCE_EQ(dim0.size(), 2,
"The input of mul op must be matrix"); "input X(%s) should be a tensor with 2 dims, a matrix",
PADDLE_ENFORCE( ctx.op_.Input("X"));
dim0[1] == dim1[0], PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("Y"));
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output"); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "The mul op takes only one output");
ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]}); ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]});
} }
}; };
......
...@@ -36,6 +36,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { ...@@ -36,6 +36,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
InitMemories(step_scopes[0], true /*infer_shape_mode*/); InitMemories(step_scopes[0], true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (size_t i = 0; i < seq_len_; i++) { for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) { if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1, rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
...@@ -56,6 +57,7 @@ void RecurrentAlgorithm::Run(const Scope& scope, ...@@ -56,6 +57,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
Variable* net = scope.FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) { for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// create output alias variables
if (step_id > 0) { if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/); false /*infer_shape_mode*/);
...@@ -67,22 +69,31 @@ void RecurrentAlgorithm::Run(const Scope& scope, ...@@ -67,22 +69,31 @@ void RecurrentAlgorithm::Run(const Scope& scope,
} }
void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// TODO(xxx) Only two scopes are needed for inference, this case will be // TODO(superjom) Only two scopes are needed for inference, this case will be
// supported later. // supported later.
auto step_scopes = auto step_scopes_var = scope.FindVar(arg_->step_scopes);
scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>(); PADDLE_ENFORCE(step_scopes_var != nullptr, "");
auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
// Now all variables in scope must be created outside of op.
auto net_var = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
arg_->step_net);
auto net_op = net_var->GetMutable<NetOp>();
PADDLE_ENFORCE(!net_op->outputs_.empty(), "net_op has no outputs");
if (seq_len_ > step_scopes->size()) { if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope(); auto& step_scope = scope.NewScope();
// Now all variables in scope must be created outside of op. // create step net's temp inputs
auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>();
for (auto& input : net_op->inputs_) { for (auto& input : net_op->inputs_) {
// the weight are located in parent scope // the weight are located in parent scope
if (!step_scope.FindVar(input)) step_scope.NewVar(input); if (!step_scope.FindVar(input))
step_scope.NewVar(input)->GetMutable<Tensor>();
} }
for (auto& output : net_op->outputs_) { // create stepnet's outputs
for (const auto& output : net_op->outputs_) {
step_scope.NewVar(output); step_scope.NewVar(output);
} }
step_scopes->emplace_back(&step_scope); step_scopes->emplace_back(&step_scope);
...@@ -100,6 +111,7 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope, ...@@ -100,6 +111,7 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>(); Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
if (infer_shape_mode) { if (infer_shape_mode) {
pre_mem->Resize(boot_mem->dims()); pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
} else { } else {
pre_mem->ShareDataWith<float>(*boot_mem); pre_mem->ShareDataWith<float>(*boot_mem);
} }
......
...@@ -53,11 +53,13 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes, ...@@ -53,11 +53,13 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.", PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
outlinks[i].external); outlinks[i].external);
Tensor* output = output_var->GetMutable<Tensor>(); Tensor* output = output_var->GetMutable<Tensor>();
if (infer_shape_mode) { if (infer_shape_mode) {
fmw::DDim step_dims = step_scopes[0] auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
->FindVar(outlinks[i].internal) PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
->GetMutable<Tensor>() outlinks[i].internal);
->dims(); fmw::DDim step_dims =
step_scope_var->template GetMutable<Tensor>()->dims();
std::vector<int> dims_vec = vectorize(step_dims); std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len); dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(fmw::make_ddim(dims_vec)); output->Resize(fmw::make_ddim(dims_vec));
...@@ -79,14 +81,15 @@ void LinkMemories(const std::vector<Scope*>& scopes, ...@@ -79,14 +81,15 @@ void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::MemoryAttr>& memories, const std::vector<rnn::MemoryAttr>& memories,
const size_t step_id, const int offset, const size_t step_id, const int offset,
bool infer_shape_mode) { bool infer_shape_mode) {
PADDLE_ENFORCE(step_id < scopes.size(), PADDLE_ENFORCE_LT(step_id, scopes.size(),
"step [%d] is out of range of step scopes' size [%d]", step_id, "step [%d] is out of range of step scopes' size [%d]",
scopes.size()); step_id, scopes.size());
PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0, PADDLE_ENFORCE_GE(static_cast<int>(step_id) + offset, 0,
"offset [%d] must be large than -[%d]", offset, step_id); "offset [%d] must be large than -[%d]", offset, step_id);
PADDLE_ENFORCE(step_id + offset < scopes.size(), PADDLE_ENFORCE_LT(
"offset [%d] is out of range, it must be less than (%d - %d)", step_id + offset, scopes.size(),
offset, scopes.size(), step_id); "offset [%d] is out of range, it must be less than (%d - %d)", offset,
scopes.size(), step_id);
auto scope = scopes[step_id]; auto scope = scopes[step_id];
auto linked_scope = scopes[step_id + offset]; auto linked_scope = scopes[step_id + offset];
for (auto& attr : memories) { for (auto& attr : memories) {
......
...@@ -37,10 +37,8 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker { ...@@ -37,10 +37,8 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
class SigmoidOpGrad : public OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override {} void InferShape(const InferShapeContext &ctx) const override {
std::string DebugString() const override { ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
LOG(INFO) << "SigmoidGrad";
return "";
} }
}; };
...@@ -51,3 +49,5 @@ REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker); ...@@ -51,3 +49,5 @@ REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad); REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sigmoid_grad,
ops::SigmoidGradKernel<ops::CPUPlace, float>);
...@@ -16,3 +16,5 @@ ...@@ -16,3 +16,5 @@
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sigmoid_grad,
ops::SigmoidGradKernel<ops::GPUPlace, float>);
...@@ -27,6 +27,7 @@ class SigmoidKernel : public OpKernel { ...@@ -27,6 +27,7 @@ class SigmoidKernel : public OpKernel {
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
// The clipping is used in Paddle's raw implenmention
auto X = EigenVector<T>::Flatten(*input); auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output); auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
...@@ -34,5 +35,23 @@ class SigmoidKernel : public OpKernel { ...@@ -34,5 +35,23 @@ class SigmoidKernel : public OpKernel {
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
} }
}; };
template <typename Place, typename T>
class SigmoidGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto Y_t = context.Input<Tensor>("Y");
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
dX_t->mutable_data<T>(context.GetPlace());
auto dX = EigenVector<T>::Flatten(*dX_t);
auto Y = EigenVector<T>::Flatten(*Y_t);
auto dY = EigenVector<T>::Flatten(*dY_t);
dX.device(context.GetEigenDevice<Place>()) = dY * Y * (1. - Y);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
add_python_test(test_framework py_test(test_net SRCS test_net.py)
test_protobuf.py
test_scope.py py_test(test_fc_op SRCS test_fc_op.py)
test_default_scope_funcs.py py_test(test_scope SRCS test_scope.py)
test_op_creation_methods.py
test_net.py py_test(test_tensor SRCS test_tensor.py)
test_tensor.py py_test(test_mul_op SRCS test_mul_op.py)
test_fc_op.py
test_add_two_op.py py_test(test_network SRCS test_network.py)
test_sgd_op.py py_test(test_mean_op SRCS test_mean_op.py)
test_mul_op.py
test_mean_op.py py_test(test_protobuf SRCS test_protobuf.py)
test_sigmoid_op.py
test_softmax_op.py py_test(test_add_two_op SRCS test_add_two_op.py)
test_rowwise_add_op.py py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
test_random_op.py py_test(test_softmax_op SRCS test_softmax_op.py)
test_network.py
gradient_checker.py) py_test(gradient_checker SRCS gradient_checker.py)
py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_op_creation_methods SRCS test_op_creation_methods.py)
...@@ -33,23 +33,28 @@ class OpTestMeta(type): ...@@ -33,23 +33,28 @@ class OpTestMeta(type):
for place in places: for place in places:
for in_name in func.all_input_args: for in_name in func.all_input_args:
if hasattr(self, in_name): if hasattr(self, "inputs") and in_name in self.inputs:
kwargs[in_name] = in_name kwargs[in_name] = in_name
var = scope.new_var(in_name).get_tensor() var = scope.new_var(in_name).get_tensor()
arr = getattr(self, in_name) arr = self.inputs[in_name]
var.set_dims(arr.shape) var.set_dims(arr.shape)
var.set(arr, place) var.set(arr, place)
else: else:
kwargs[in_name] = "@EMPTY@" kwargs[in_name] = "@EMPTY@"
for out_name in func.all_output_args: for out_name in func.all_output_args:
if hasattr(self, out_name): if not hasattr(self, "outputs"):
kwargs[out_name] = out_name raise ValueError(
scope.new_var(out_name).get_tensor() "The test op must set self.outputs dict.")
if out_name not in self.outputs:
raise ValueError("The %s is not in self.outputs dict." %
(out_name))
kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args: for attr_name in func.all_attr_args:
if hasattr(self, attr_name): if hasattr(self, "attrs") and attr_name in self.attrs:
kwargs[attr_name] = getattr(self, attr_name) kwargs[attr_name] = self.attrs[attr_name]
op = func(**kwargs) op = func(**kwargs)
...@@ -60,7 +65,7 @@ class OpTestMeta(type): ...@@ -60,7 +65,7 @@ class OpTestMeta(type):
for out_name in func.all_output_args: for out_name in func.all_output_args:
actual = numpy.array(scope.find_var(out_name).get_tensor()) actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name) expect = self.outputs[out_name]
numpy.isclose(actual, expect) numpy.isclose(actual, expect)
obj.test_all = test_all obj.test_all = test_all
......
...@@ -12,9 +12,11 @@ class TestAddOp(unittest.TestCase): ...@@ -12,9 +12,11 @@ class TestAddOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "add_two" self.type = "add_two"
self.X = numpy.random.random((102, 105)).astype("float32") self.inputs = {
self.Y = numpy.random.random((102, 105)).astype("float32") 'X': numpy.random.random((102, 105)).astype("float32"),
self.Out = self.X + self.Y 'Y': numpy.random.random((102, 105)).astype("float32")
}
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
class TestAddGradOp(unittest.TestCase): class TestAddGradOp(unittest.TestCase):
......
...@@ -7,15 +7,17 @@ class TestSGD(unittest.TestCase): ...@@ -7,15 +7,17 @@ class TestSGD(unittest.TestCase):
__metaclass__ = OpTestMeta __metaclass__ = OpTestMeta
def setUp(self): def setUp(self):
# TODO this unit test is not passed
self.type = "onehot_cross_entropy" self.type = "onehot_cross_entropy"
batch_size = 100 batch_size = 100
class_num = 10 class_num = 10
self.X = numpy.random.random((batch_size, class_num)).astype("float32") X = numpy.random.random((batch_size, class_num)).astype("float32")
self.label = 5 * numpy.ones(batch_size).astype("int32") label = 5 * numpy.ones(batch_size).astype("int32")
self.inputs = {'X': X, 'label': label}
Y = [] Y = []
for i in range(0, batch_size): for i in range(0, batch_size):
Y.append(-numpy.log(self.X[i][self.label[i]])) Y.append(-numpy.log(X[i][label[i]]))
self.Y = numpy.array(Y).astype("float32") self.outputs = {'Y': numpy.array(Y).astype("float32")}
# TODO(superjom) add gradient check # TODO(superjom) add gradient check
......
...@@ -8,8 +8,8 @@ class TestMeanOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestMeanOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "mean" self.type = "mean"
self.X = np.random.random((32, 784)).astype("float32") self.inputs = {'X': np.random.random((32, 784)).astype("float32")}
self.Out = np.mean(self.X) self.outputs = {'Out': np.mean(self.inputs['X'])}
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,9 +8,11 @@ class TestMulOp(unittest.TestCase): ...@@ -8,9 +8,11 @@ class TestMulOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "mul" self.type = "mul"
self.X = np.random.random((32, 84)).astype("float32") self.inputs = {
self.Y = np.random.random((84, 100)).astype("float32") 'X': np.random.random((32, 84)).astype("float32"),
self.Out = np.dot(self.X, self.Y) 'Y': np.random.random((84, 100)).astype("float32")
}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
if __name__ == '__main__': if __name__ == '__main__':
......
import logging
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import unittest import unittest
import numpy as np import numpy as np
...@@ -7,10 +8,9 @@ ops = creation.op_creations ...@@ -7,10 +8,9 @@ ops = creation.op_creations
def create_tensor(scope, name, shape): def create_tensor(scope, name, shape):
tensor = scope.create_var(name).get_tensor() tensor = scope.new_var(name).get_tensor()
tensor.set_dims(shape) tensor.set_dims(shape)
tensor.alloc_float() tensor.set(np.random.random(shape), core.CPUPlace())
tensor.set(np.random.random(shape))
return tensor return tensor
...@@ -31,40 +31,36 @@ class TestRNN(unittest.TestCase): ...@@ -31,40 +31,36 @@ class TestRNN(unittest.TestCase):
- h - h
''' '''
input_dim = 30
batch_size = 50
weight_dim = 15
sent_len = 11
def init(self): def init(self):
input_dim = 30
batch_size = 50
weight_dim = 15
self.scope = core.Scope(None)
# create vars
create_tensor(self.scope, "x", [batch_size, input_dim])
create_tensor(self.scope, "W", [input_dim, weight_dim])
create_tensor(self.scope, "U", [weight_dim, weight_dim])
create_tensor(self.scope, "h_boot", [batch_size, weight_dim])
x_alias = "x@alias"
y_alias = "y@alias"
memory = "h@alias"
prememory = "h@pre"
output = "rnn_out"
output_alias = "rnn_out@alias"
# create step net
stepnet_var = self.scope.create_var("stepnet")
stepnet = stepnet_var.get_net()
# stepnet = core.Net.create()
x_fc_op = ops.fc(X=x_alias, W="W", Y="Wx")
h_fc_op = ops.fc(X=prememory, W="U", Y="Uh")
sum_op = ops.add_two(X="Wx", Y="Uh", Out="sum")
sig_op = ops.sigmoid(X="sum", Y=memory)
stepnet.add_op(x_fc_op)
stepnet.add_op(h_fc_op)
stepnet.add_op(sum_op)
stepnet.add_op(sig_op)
stepnet.complete_add_op(True)
self.scope = core.Scope()
self.create_global_variables()
self.create_step_net()
rnn_op = self.create_rnn_op()
ctx = core.DeviceContext.create(core.CPUPlace())
print 'infer_shape'
rnn_op.infer_shape(self.scope)
rnn_op.run(self.scope, ctx)
def create_global_variables(self):
# create inlink
create_tensor(self.scope, "x",
[self.sent_len, self.batch_size, self.input_dim])
create_tensor(self.scope, "W", [self.input_dim, self.input_dim])
create_tensor(self.scope, "U", [self.input_dim, self.input_dim])
create_tensor(self.scope, "h_boot", [self.batch_size, self.input_dim])
self.scope.new_var("step_scopes")
self.scope.new_var("h@alias")
self.scope.new_var("h")
def create_rnn_op(self):
# create RNNOp # create RNNOp
rnnop = ops.recurrent_op( rnnop = ops.recurrent_op(
# inputs # inputs
...@@ -72,17 +68,27 @@ class TestRNN(unittest.TestCase): ...@@ -72,17 +68,27 @@ class TestRNN(unittest.TestCase):
boot_memories=["h_boot"], boot_memories=["h_boot"],
step_net="stepnet", step_net="stepnet",
# outputs # outputs
outlinks=[output], outlinks=["h"],
step_scopes="step_scopes", step_scopes="step_scopes",
# attributes # attributes
inlink_alias=["x@alias"], inlink_alias=["x@alias"],
outlink_alias=[output_alias], outlink_alias=["h@alias"],
pre_memories=[prememory], pre_memories=["h@pre"],
memories=[memory]) memories=["h@alias"])
return rnnop
def create_step_net(self):
var = self.scope.new_var("stepnet")
stepnet = var.get_net()
ctx = core.DeviceContext.cpu_context() x_fc_op = ops.fc(X="x@alias", W="W", Y="Wx")
rnnop.infer_shape(self.scope) h_fc_op = ops.fc(X="h@pre", W="U", Y="Uh")
rnnop.run(self.scope, ctx) sum_op = ops.add_two(X="Wx", Y="Uh", Out="sum")
sig_op = ops.sigmoid(X="sum", Y="h@alias")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.add_op(op)
stepnet.complete_add_op(True)
def test_recurrent(self): def test_recurrent(self):
self.init() self.init()
......
...@@ -8,9 +8,11 @@ class TestRowwiseAddOp(unittest.TestCase): ...@@ -8,9 +8,11 @@ class TestRowwiseAddOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "rowwise_add" self.type = "rowwise_add"
self.X = np.random.random((32, 84)).astype("float32") self.inputs = {
self.b = np.random.random(84).astype("float32") 'X': np.random.random((32, 84)).astype("float32"),
self.Out = np.add(self.X, self.b) 'b': np.random.random(84).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,10 +8,13 @@ class TestSGD(unittest.TestCase): ...@@ -8,10 +8,13 @@ class TestSGD(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "sgd" self.type = "sgd"
self.param = numpy.random.random((102, 105)).astype("float32") w = numpy.random.random((102, 105)).astype("float32")
self.grad = numpy.random.random((102, 105)).astype("float32") g = numpy.random.random((102, 105)).astype("float32")
self.learning_rate = 0.1 lr = 0.1
self.param_out = self.param - self.learning_rate * self.grad
self.inputs = {'param': w, 'grad': g}
self.attrs = {'learning_rate': lr}
self.outputs = {'param_out': w - lr * g}
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -8,9 +8,12 @@ class TestSigmoidOp(unittest.TestCase): ...@@ -8,9 +8,12 @@ class TestSigmoidOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "sigmoid" self.type = "sigmoid"
self.X = np.random.random((32, 100)).astype("float32") self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
self.Y = 1 / (1 + np.exp(-self.X)) self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
#class TestSigmoidGradOp(unittest.TestCase):
#TODO(qingqing) add unit test
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,8 +19,10 @@ class TestSoftmaxOp(unittest.TestCase): ...@@ -19,8 +19,10 @@ class TestSoftmaxOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "softmax" self.type = "softmax"
self.X = np.random.random((32, 100)).astype("float32") self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
self.Y = np.apply_along_axis(stable_softmax, 1, self.X) self.outputs = {
'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
class TestSoftmaxGradOp(unittest.TestCase): class TestSoftmaxGradOp(unittest.TestCase):
......
if (NOT APPLE) if (NOT APPLE)
# The Mac OS X backend will not be able to function correctly if Python is # The Mac OS X backend will not be able to function correctly if Python is
# not installed as a framework. # not installed as a framework.
add_python_test(test_ploter test_ploter.py) py_test(test_ploter SRCS test_ploter.py)
endif() endif()
add_python_test(reader_tests creator_test.py decorator_test.py) py_test(creator_test SRCS creator_test.py)
py_test(decorator_test SRCS decorator_test.py)
add_python_test(test_v2_api test_data_feeder.py test_op.py test_parameters.py py_test(test_op SRCS test_op.py)
test_layer.py test_rnn_layer.py test_topology.py test_image.py) py_test(test_image SRCS test_image.py)
py_test(test_layer SRCS test_layer.py)
py_test(test_topology SRCS test_topology.py)
py_test(test_rnn_layer SRCS test_rnn_layer.py)
py_test(test_parameters SRCS test_parameters.py)
py_test(test_data_feeder SRCS test_data_feeder.py)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册