diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 0eeccbf7d8a1df17351c8914df6dabf005802787..0002a470d90f722e3f9106ca56d70e6bf2cea339 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -25,7 +25,12 @@ IF(NOT ${CBLAS_FOUND}) "${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "openblas library." FORCE) - SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs) + IF(APPLE) + SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -isysroot ${CMAKE_OSX_SYSROOT}") + SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs) + ELSE() + SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs) + ENDIF() IF(CMAKE_CROSSCOMPILING) IF(ANDROID) @@ -40,11 +45,11 @@ IF(NOT ${CBLAS_FOUND}) SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0) ELSEIF(RPI) # use hardfp - SET(OPENBLAS_COMMIT "v0.2.19") + SET(OPENBLAS_COMMIT "v0.2.20") SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0) ENDIF() ELSE() - SET(OPENBLAS_COMMIT "v0.2.19") + SET(OPENBLAS_COMMIT "v0.2.20") SET(OPTIONAL_ARGS "") IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$") SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64) diff --git a/doc/design/functions_operators_layers.md b/doc/design/functions_operators_layers.md new file mode 100644 index 0000000000000000000000000000000000000000..7a2e8fd0ace2e3f4462b15215de22c31e944b7cb --- /dev/null +++ b/doc/design/functions_operators_layers.md @@ -0,0 +1,99 @@ +# Design Doc: Functions, Operators, and Layers + +In a DL system, we can compose one or more fine grained operators into a coarse grained one. For example, the FC layer can be composed of a multiplication operator and an add operator. + +Historically, some fine grained operations are known as operators, and some coarse level ones are known as layers. But we need a well-defined separation. + +In general, operators are those very fine grained operations, e.g., mul and add. In the implementation, we can write them as C++ functions: + +```c++ +template T add(T x, T y) { return x + y; } +template T mul(T x, T y) { return x * y; } +``` + +Then we can wrap them into operators which are C++ classes and can be created from Python bindings by name. A C macro can do this. For example, the following macro invocation + +```c++ +#define MAKE_FUNCTION_OPERATOR(mul); +``` + +generates + +```c++ +template class mulOp : public OperatorBase {...}; +REGISTER_OP(mulOp, "mul"); +``` + +so that in Python we can create operator mul by: + +```python +X1 = Var() +X2 = Var() +Y = Var() +paddle.cpp.create_operator("mul", input=[X1, X2], output=Y) +``` + +Also, at the same time, we can compose a coarse level C++ operator class by composing functions `mul` and `add`: + +```c++ +template +class FCOp : public OperatorBase { + public: + void Run(...) { + add(mul(Input("X"), Input("W")), Input("b"); + } +}; +REGISTER_OP(FCOp, "fc"); +``` + +We need to support such composition in Python as well. To do so, we need a higher level Python wrapping of operator creation than `paddle.cpp.create_operator`. This higher level operator API should be compatible with the layer API. + +Let's explain using an example. Suppose that we are going to compose the FC using mul and add in Python, we'd like to have Python functions `mul` and `add` defined in module `operator`: + +```python +def operator.mul(X1, X2): + O = Var() + paddle.cpp.create_operator("mul", input={X1, Y1], output=O) + return O + +def operator.add(X1, X2): + O = Var() + paddle.cpp.create_operator("add", input={X1, X2], output=O) + return O +``` + +Above code snippets are automatically generated. Given them, users can define + +```python +def layer.fc(X): + W = Var() + b = Var() + return operator.add(operator.mul(X, W), b) +``` + +If we don't have `operator.mul` and `operator.add`, the definiton of `layer.fc` would be complicated: + +```python +def layer.fc(X): + W = Var() + b = Var() + O1 = Var() + paddle.cpp.create_operator("mul", input=[X, W], output=O1) + O2 = Var() + paddle.cpp.create_operator("add", input=[O1, b], output=O2) + return O2 +``` + +We'd like to have Python bindings to operators in package `paddle.operator`, and Python compositions of operators in package `paddle.layer`. So we have the following concepts in above illustrative example: + +``` +| C++ functions/functors | mul | add | | | +| C++ operator class | mulOp | addOp | FCOp | | +| Python binding | operator.mul | operator.add | operator.fc | | +| Python function | | | | layer.fc | +``` + +This is how we differentiate layer and operators in PaddlePaddle: + +- those defined in C++ and have a lightweighted Python wrapper in module `operators` are operators; whereas +- those who don't have C++ implementations but a Python implementation that compose C++ operators are known as layers. diff --git a/doc/design/if_else_op.md b/doc/design/if_else_op.md new file mode 100644 index 0000000000000000000000000000000000000000..7370c2a24fa644a64e738f202bac9b9209642e08 --- /dev/null +++ b/doc/design/if_else_op.md @@ -0,0 +1,59 @@ +IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`. + +```python +import paddle as pd + +x = var() +y = var() +cond = var() + +b = pd.create_ifop(inputs=[x], output_num=1) +with b.true_block(): + x = b.inputs(0) + z = operator.add(x, y) + b.set_output(0, operator.softmax(z)) + +out = b(cond) +``` + +If we want the output still has N instances, we can use IfElseOp with a default value, whose minibatch size must be N: + +```python +import paddle as pd + +x = var() +y = var() +cond = var() +default_value = var() +b = pd.create_ifelseop(inputs=[x], output_num=1) +with b.true_block(): + x = b.inputs(0) + z = operator.add(x, y) + b.set_output(0, operator.softmax(z)) + +with b.false_block(): + x = b.inputs(0) + z = layer.fc(x) + b.set_output(0, operator.softmax(z)) + +out = b(cond) +``` + +If only true_block is set in an IfElseOp, we can have a default value for false as: +```python +import paddle as pd + +x = var() +y = var() +cond = var() +default_value = var() +b = pd.create_ifelseop(inputs=[x], output_num=1, default_value) + +with b.true_block(): + x = b.inputs(0) + z = operator.add(x, y) + b.set_output(0, operator.softmax(z)) + +out = b(cond) +``` +where default_value is a list of vars for `cond` == False. diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md index 7f8da2da5a0d42ff065265c5d173d0e6167dc08a..ec79b7f42b2d70df8fcb25faca5bc3a4759e177c 100644 --- a/doc/howto/dev/new_op_cn.md +++ b/doc/howto/dev/new_op_cn.md @@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel { ```c++ namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); REGISTER_OP_CPU_KERNEL(mul_grad, ops::MulGradKernel); ``` - - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,注册`ops::MulOpGrad`,类型名为`mul_grad`, + - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,并且注册`ops::MulOpGrad`为其反向Op。 - `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。 - `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulKernel`类。 diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 9f84db72da24b0e678520b077f9cba7ffc2d589a..6b56d9ec8d3daae96aaaa04ed79cb637331e2281 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -173,6 +173,96 @@ extern void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride); +extern void hl_maxpool3D_forward(const int frameCnt, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real* tgtData, + real* maxPoolIdxData, + const int tgtStride); + +extern void hl_maxpool3D_backward(const int frameCnt, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real scaleA, + real scaleB, + real* targetGrad, + real* maxPoolIdxData, + const int outStride); + +extern void hl_avgpool3D_forward(const int frameCnt, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real* tgtData, + const int tgtStride); + +extern void hl_avgpool3D_backward(const int frameCnt, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + int paddingD, + int paddingH, + int paddingW, + real scaleA, + real scaleB, + real* backGrad, + const int outStride); + /** * @brief Bilinear interpolation forward. * @@ -275,4 +365,4 @@ extern void hl_maxout_backward(real* inGrad, size_t featLen, size_t groups); -#endif /* HL_CNN_H_ */ +#endif // HL_CNN_H_ diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 2bbb9fa8dfd5eeac9d55aa67a28ebfbffa2acd46..a76dbf0b6578de0606702ad1af227fbf6e1cd62e 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -87,6 +87,96 @@ inline void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride) {} +inline void hl_maxpool3D_forward(const int frameCnt, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real* tgtData, + real* maxPoolIdxData, + const int tgtStride) {} + +inline void hl_maxpool3D_backward(const int frameCnt, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real scaleA, + real scaleB, + real* targetGrad, + real* maxPoolIdxData, + const int outStride) {} + +inline void hl_avgpool3D_forward(const int frameCnt, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real* tgtData, + const int tgtStride) {} + +inline void hl_avgpool3D_backward(const int frameCnt, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real scaleA, + real scaleB, + real* backGrad, + const int outStride) {} + inline void hl_bilinear_forward(const real* inData, const size_t inImgH, const size_t inImgW, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index aac19b1ea566ad69f1f7374e393676c8debd9883..9ba3d142617537c0160f6dccb86ddca43ada15a5 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -353,6 +353,433 @@ void hl_avgpool_backward(const int frameCnt, CHECK_SYNC("hl_avgpool_backward failed"); } +__global__ void KeMaxPool3DForward(const int nthreads, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int ksizeD, + const int ksizeH, + const int ksizeW, + const int strideD, + const int strideH, + const int strideW, + const int padD, + const int padH, + const int padW, + real* tgtData, + real* maxPoolIdxData, + const int tgtStride) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int pw = index % pooledW; + int ph = (index / pooledW) % pooledH; + int pd = (index / pooledW / pooledH) % pooledD; + int c = (index / pooledW / pooledH / pooledD) % channels; + int frameNum = index / pooledW / pooledH / pooledD / channels; + int dstart = pd * strideD - padD; + int hstart = ph * strideH - padH; + int wstart = pw * strideW - padW; + int dend = min(dstart + ksizeD, depth); + int hend = min(hstart + ksizeH, height); + int wend = min(wstart + ksizeW, width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + real maxval = -FLT_MAX; + int maxIdx = -1; + inputData += (frameNum * channels + c) * depth * height * width; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (maxval < inputData[(d * height + h) * width + w]) { + maxval = inputData[(d * height + h) * width + w]; + maxIdx = (d * height + h) * width + w; + } + } + } + } + int tgtIndex = + index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride; + tgtData[tgtIndex] = maxval; + maxPoolIdxData[tgtIndex] = maxIdx; + } +} + +void hl_maxpool3D_forward(const int frameCnt, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int padD, + const int padH, + const int padW, + real* tgtData, + real* maxPoolIdxData, + const int tgtStride) { + int num_kernels = pooledD * pooledH * pooledW * channels * frameCnt; + int blocks = (num_kernels + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KeMaxPool3DForward<<>>(num_kernels, + inputData, + channels, + depth, + height, + width, + pooledD, + pooledH, + pooledW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + padD, + padH, + padW, + tgtData, + maxPoolIdxData, + tgtStride); + CHECK_SYNC("hl_maxpool3D_forward failed"); +} + +__global__ void KeMaxPool3DBackward(const int nthreads, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int padD, + const int padH, + const int padW, + real scaleA, + real scaleB, + real* targetGrad, + real* maxPoolIdxData, + const int outStride) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int offsetW = index % width; + int offsetH = (index / width) % height; + int offsetD = (index / width / height) % depth; + int offsetC = (index / width / height / depth) % channels; + int frameNum = index / width / height / depth / channels; + + int pdstart = + (offsetD + padD < sizeZ) ? 0 : (offsetD + padD - sizeZ) / strideD + 1; + int phstart = + (offsetH + padH < sizeY) ? 0 : (offsetH + padH - sizeY) / strideH + 1; + int pwstart = + (offsetW + padW < sizeX) ? 0 : (offsetW + padW - sizeX) / strideW + 1; + int pdend = min((offsetD + padD) / strideD + 1, pooledD); + int phend = min((offsetH + padH) / strideH + 1, pooledH); + int pwend = min((offsetW + padW) / strideW + 1, pooledW); + + real gradient = 0; + outGrad += ((frameNum * channels + offsetC) * pooledD * pooledH * pooledW); + maxPoolIdxData += + ((frameNum * channels + offsetC) * pooledD * pooledH * pooledW); + for (int pd = pdstart; pd < pdend; ++pd) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (((offsetD * height + offsetH) * width + offsetW) == + maxPoolIdxData[(pd * pooledH + ph) * pooledW + pw]) + gradient += outGrad[(pd * pooledH + ph) * pooledW + pw]; + } + } + } + targetGrad[index] = scaleA * gradient + scaleB * targetGrad[index]; + } +} + +void hl_maxpool3D_backward(const int frameCnt, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int outputD, + const int outputH, + const int outputW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real scaleA, + real scaleB, + real* targetGrad, + real* maxPoolIdxData, + const int outStride) { + int num_kernels = depth * height * width * channels * frameCnt; + int blocks = (num_kernels + 1024 - 1) / 1024; + + KeMaxPool3DBackward<<>>(num_kernels, + outGrad, + channels, + depth, + height, + width, + outputD, + outputH, + outputW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + scaleA, + scaleB, + targetGrad, + maxPoolIdxData, + outStride); + CHECK_SYNC("hl_maxpool3D_backward"); +} + +__global__ void KeAvgPool3DForward(const int nthreads, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int padD, + const int padH, + const int padW, + real* tgtData, + const int tgtStride) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int pw = index % pooledW; + int ph = (index / pooledW) % pooledH; + int pd = (index / pooledW / pooledH) % pooledD; + int c = (index / pooledW / pooledH / pooledD) % channels; + int frameNum = index / pooledW / pooledH / pooledD / channels; + int dstart = pd * strideD - padD; + int hstart = ph * strideH - padH; + int wstart = pw * strideW - padW; + int dend = min(dstart + sizeZ, depth + padD); + int hend = min(hstart + sizeY, height + padH); + int wend = min(wstart + sizeX, width + padW); + int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dend = min(dend, depth); + hend = min(hend, height); + wend = min(wend, width); + + real aveval = 0; + inputData += (frameNum * channels + c) * depth * height * width; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += inputData[(d * height + h) * width + w]; + } + } + } + int tgtIndex = + index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride; + tgtData[tgtIndex] = aveval / pool_size; + } +} + +void hl_avgpool3D_forward(const int frameCnt, + const real* inputData, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int paddingD, + const int paddingH, + const int paddingW, + real* tgtData, + const int tgtStride) { + int num_kernels = pooledD * pooledH * pooledW * channels * frameCnt; + int blocks = (num_kernels + 1024 - 1) / 1024; + KeAvgPool3DForward<<>>(num_kernels, + inputData, + channels, + depth, + height, + width, + pooledD, + pooledH, + pooledW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + tgtData, + tgtStride); + CHECK_SYNC("hl_avgpool3D_forward failed"); +} + +__global__ void KeAvgPool3DBackward(const int nthreads, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int pooledD, + const int pooledH, + const int pooledW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + const int padD, + const int padH, + const int padW, + real scaleA, + real scaleB, + real* tgtGrad, + const int outStride) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int offsetW = index % width + padW; + int offsetH = (index / width) % height + padH; + int offsetD = (index / width / height) % depth + padD; + int offsetC = (index / width / height / depth) % channels; + int frameNum = index / width / height / depth / channels; + + int pdstart = (offsetD < sizeZ) ? 0 : (offsetD - sizeZ) / strideD + 1; + int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1; + int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1; + int pdend = min(offsetD / strideD + 1, pooledD); + int phend = min(offsetH / strideH + 1, pooledH); + int pwend = min(offsetW / strideW + 1, pooledW); + + real gradient = 0; + outGrad += (frameNum * channels + offsetC) * pooledD * pooledH * pooledW; + + for (int pd = pdstart; pd < pdend; ++pd) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int dstart = pd * strideD - padD; + int hstart = ph * strideH - padH; + int wstart = pw * strideW - padW; + int dend = min(dstart + sizeZ, depth + padD); + int hend = min(hstart + sizeY, height + padH); + int wend = min(wstart + sizeX, width + padW); + int poolsize = (dend - dstart) * (hend - hstart) * (wend - wstart); + gradient += outGrad[(pd * pooledH + ph) * pooledW + pw] / poolsize; + } + } + } + tgtGrad[index] = scaleA * gradient + scaleB * tgtGrad[index]; + } +} + +void hl_avgpool3D_backward(const int frameCnt, + const real* outGrad, + const int channels, + const int depth, + const int height, + const int width, + const int outputD, + const int outputH, + const int outputW, + const int sizeZ, + const int sizeY, + const int sizeX, + const int strideD, + const int strideH, + const int strideW, + int paddingD, + int paddingH, + int paddingW, + real scaleA, + real scaleB, + real* backGrad, + const int outStride) { + int num_kernels = depth * height * width * channels * frameCnt; + int blocks = (num_kernels + 1024 - 1) / 1024; + + KeAvgPool3DBackward<<>>(num_kernels, + outGrad, + channels, + depth, + height, + width, + outputD, + outputH, + outputW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + scaleA, + scaleB, + backGrad, + outStride); + CHECK_SYNC("hl_avgpool3D_backward failed"); +} + __global__ void KeBilinearInterpFw(const real* in, const size_t inImgH, const size_t inImgW, diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md index 8aa6728a95bc464ab8884986f0cec6c817d3303b..9500c92a265d60a696e1e2c422d0f2bd1621ef71 100644 --- a/paddle/framework/backward.md +++ b/paddle/framework/backward.md @@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro: ```cpp -REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad); +REGISTER_OP(mul, MulOp, MulOpMaker, MulOpGrad); ``` `mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively. diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index f100c4d05489ac3bd4ceb5f11ae871985f0e5d83..bf8b11e5f5ae801621f84bdbeffb5c4cf2dd8905 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -127,8 +127,8 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker { public: FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("x", "x"); - AddOutput("out", "out"); + AddInput("Src", "x"); + AddOutput("Dst", "out"); AddComment(""); } }; @@ -138,7 +138,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "x").AsDuplicable(); - AddOutput("Y", "y"); + AddOutput("Out", "out"); AddComment(""); } }; @@ -148,16 +148,14 @@ class AddOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, - f::NOP); -REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); -REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); +REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, f::NOP); +REGISTER_OP(mul, f::NOP, f::MulOpMaker, f::NOP); +REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); -REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP); +REGISTER_OP(add, f::NOP, f::AddOpMaker, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); -REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, - f::NOP); +REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, f::NOP); TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp( diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 902c2655e9182d74a48ad13e17a39a3304d5fa57..8a817a3e13ca64d6f8df566891a1059995e041ae 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -54,8 +54,8 @@ TEST(GradOpBuilder, AddTwo) { EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); } -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, f::NOP); TEST(GradOpBuilder, MutiInOut) { std::shared_ptr test_op(f::OpRegistry::CreateOp( diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 2b178907747b3911292b070b65160a24c120b726..71eac4a10b34c3010a2758120c25754af58f669d 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -19,25 +19,24 @@ namespace paddle { namespace framework { -LODTensor::LOD LODTensor::LOD::SliceLevels(size_t level_begin, - size_t level_end) const { +LOD SliceLevels(const LOD& in, size_t level_begin, size_t level_end) { LOD new_lod; new_lod.reserve(level_end - level_begin); for (size_t i = level_begin; i < level_end; i++) { - new_lod.emplace_back(at(i)); + new_lod.emplace_back(in.at(i)); } return new_lod; } -LODTensor::LOD LODTensor::LOD::SliceInLevel(size_t level, size_t elem_begin, - size_t elem_end) const { +LOD SliceInLevel(const LOD& in, size_t level, size_t elem_begin, + size_t elem_end) { // slice the lod. LOD new_lod; - new_lod.reserve(size() - level); - auto start = this->at(level)[elem_begin]; - auto end = this->at(level)[elem_end]; + new_lod.reserve(in.size() - level); + auto start = in.at(level)[elem_begin]; + auto end = in.at(level)[elem_end]; - for (auto it = this->begin() + level; it != this->end(); it++) { + for (auto it = in.begin() + level; it != in.end(); it++) { auto it_begin = std::find(it->begin(), it->end(), start); auto it_end = std::find(it_begin, it->end(), end); PADDLE_ENFORCE(it_begin != it->end(), "error in parsing lod info"); @@ -49,11 +48,11 @@ LODTensor::LOD LODTensor::LOD::SliceInLevel(size_t level, size_t elem_begin, [start](int v) { return v - start; }); PADDLE_ENFORCE_EQ(new_lod.back().front(), 0, "error in slice LOD"); } - PADDLE_ENFORCE_LE(new_lod.size(), this->size()); + PADDLE_ENFORCE_LE(new_lod.size(), in.size()); return new_lod; } -bool operator==(const LODTensor::LOD& a, const LODTensor::LOD& b) { +bool operator==(const LOD& a, const LOD& b) { if (a.size() != b.size()) { return false; } @@ -70,9 +69,27 @@ bool operator==(const LODTensor::LOD& a, const LODTensor::LOD& b) { } } } - return true; } +void LODTensor::SliceLevels(size_t level_begin, size_t level_end) { + auto new_lod = framework::SliceLevels(lod_, level_begin, level_end); + lod_ = new_lod; +} + +void LODTensor::SliceInLevel(size_t level, size_t elem_begin, size_t elem_end) { + PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level, + NumLevels()); + PADDLE_ENFORCE(elem_begin < NumElements(level), + "element begin [%d] out of range [%d]", elem_begin, + NumElements(level)); + PADDLE_ENFORCE(elem_end < NumElements(level) + 1, + "element end [%d] out of range [%d]", elem_end, + NumElements(level)); + + auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end); + lod_ = new_lod; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 9e27aec38d336db8a4f0adbed098d299aa741356..9e6b6b4aca41ed464292b56bf6f2d27514f874f7 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -15,7 +15,7 @@ #pragma once #include -#if !defined(PADDLE_ONLY_CPU) +#ifndef PADDLE_ONLY_CPU #include #include #endif @@ -27,33 +27,39 @@ namespace paddle { namespace framework { +#ifdef PADDLE_ONLY_CPU +template +using Vector = std::vector; +#else +template +using Vector = thrust::host_vector; +#endif + +using LOD = std::vector>; + +LOD SliceLevels(const LOD& in, size_t level_begin, size_t level_end); + +LOD SliceInLevel(const LOD& in, size_t level, size_t elem_begin, + size_t elem_end); + +bool operator==(const LOD& a, const LOD& b); + /* * LODTensor (Level of details Tensor) * see https://en.wikipedia.org/wiki/Level_of_details for reference. */ -class LODTensor : public Tensor { +class LODTensor { public: -// Level save offsets of each unit. -#ifdef PADDLE_ONLY_CPU - template - using Vector = std::vector; -#else - template - using Vector = thrust::host_vector; -#endif - // LoD stores offsets of each level of units, the largest units level first, - // then the smaller units level. Each Level stores the offsets of units in - // Tesor. - class LOD : public std::vector> { - public: - LOD SliceLevels(size_t level_begin, size_t level_end) const; - LOD SliceInLevel(size_t level, size_t elem_begin, size_t elem_end) const; - }; - LODTensor() {} - explicit LODTensor(const LOD &lod) : lod_(lod) {} + LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {} + + void set_lod(const LOD& lod) { lod_ = lod; } - virtual Tensor *Clone() const { return new LODTensor(lod_); } + void set_tensor(Tensor* tensor) { tensor_ = tensor; } + + Tensor& tensor() { return *tensor_; } + + LOD lod() { return lod_; } /* * Get a element from LOD. @@ -79,71 +85,23 @@ class LODTensor : public Tensor { PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level, NumLevels()); // the last offset is the end of last element - return lod_[level].size() - 1; + return (lod_)[level].size() - 1; } /* - * Slice of levels[level_begin:level_end], with tensor shared. + * Slice of levels[level_begin:level_end] */ - template - LODTensor SliceLevels(size_t level_begin, size_t level_end) const; + void SliceLevels(size_t level_begin, size_t level_end); /* - * Slice of elements of a level, [elem_begin: elem_end], with tensor shared. + * Slice of elements of a level, [elem_begin: elem_end] * @note: low performance in slice lod_. */ - template - LODTensor SliceInLevel(size_t level, size_t elem_begin, - size_t elem_end) const; - - /* - * Copy other's lod_'s content, free to mutate. - */ - void CopyLOD(const LODTensor &other) { lod_ = other.lod_; } - /* - * Determine whether LODTensor has a valid LOD info. - */ - const LOD &lod() const { return lod_; } - LOD *mutable_lod() { return &lod_; } - - virtual ~LODTensor() {} + void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end); private: LOD lod_; + Tensor* tensor_; // not owned }; - -bool operator==(const LODTensor::LOD &a, const LODTensor::LOD &b); - -template -LODTensor LODTensor::SliceLevels(size_t level_begin, size_t level_end) const { - auto new_lod = lod_.SliceLevels(level_begin, level_end); - // slice levels just need to update LOD info, each level will contains the - // whole tensor_, so no need to modify tensor_. - LODTensor new_tensor(new_lod); - new_tensor.ShareDataWith(*this); - return new_tensor; -} - -template -LODTensor LODTensor::SliceInLevel(size_t level, size_t elem_begin, - size_t elem_end) const { - PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level, - NumLevels()); - PADDLE_ENFORCE(elem_begin < NumElements(level), - "element begin [%d] out of range [%d]", elem_begin, - NumElements(level)); - PADDLE_ENFORCE(elem_end < NumElements(level) + 1, - "element end [%d] out of range [%d]", elem_end, - NumElements(level)); - - auto new_lod = lod_.SliceInLevel(level, elem_begin, elem_end); - - // slice elements just need to update LOD info, because offsets are not - // changed, so the original tensor_ can be reused. - LODTensor new_tensor(new_lod); - new_tensor.ShareDataWith(*this); - return new_tensor; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor.md b/paddle/framework/lod_tensor.md new file mode 100644 index 0000000000000000000000000000000000000000..8dfe3ee823084cb8c38550a82e761a741eabe135 --- /dev/null +++ b/paddle/framework/lod_tensor.md @@ -0,0 +1,122 @@ +# Design Doc: LoD (Level-of-Detail) Tensor + +PaddlePaddle's RNN doesn't require that all instances have the same length. To do so, we introduce an extension to Tensor, namely, LoD Tensor. + +## Challenge of Variable-length Inputs + +People usually represent a mini-batch by a Tensor. For example, a mini-batch of 32 images, each of size 32x32, is a 10x32x32 Tensor. So a transformation, T, of all images can be a matrix multiplication of the 32x32xO-dimensional tensor T and the 10x32x32 Tensor. + +Another example is that each mini-batch contains 32 sentences, where each word is a D-dimensional one-hot vector. If all sentences have the same length L, we can represent this mini-batch by a 32xLxD tensor. However, in most cases, sentences have variable lengths, and we will need an index data structure to record these variable lengths. + +## LoD as a Solution + +### Mini-Batch of variable-length sentenses + +Let's imagine a mini-batch of 3 variable lengths sentences, containing 3, 1, and 2 words respectively. We can represent it by a (3+1+2)xD tensor plus some index information: + +``` + 3 +3 1 2 +||| | || +``` + +Each `|` represents a D-dimensional word vectors. The number 3 on top indicate 3 sentences, and numbers 3, 1, and 2 on the second level represent the number of words in each sentence. + +### Mini-Batch of variable-length videos + +This approach generalizes to the case where elements are not words, but higher dimensional objects, like images. Suppose that a mini-batch contains videos of the same frame size 640x480. If a mini-batch contains 3 videos of 3, 1, and 2 frames respectively. The underlying tensor is of size (3+1+2)x640x480. The index information illustrates as: + +``` + 3 +3 1 2 +口口口 口 口口 +``` + +where each `口` represents an image. + +### Mini-Batch of fixed-size images + +Let's get back to a typical example, image classification, where each mini-batch has M fixed-sized images. The LoD Tensor representation is + +``` + M +1 1 1 1 1 +口口口口 ... 口 +``` + +The many 1's on the second level seem duplicated. For this particular case of 2 levels and the second level always have length 1, we can ignore the LoD index. + +### Design and summarization + +In summary, as long as that the essential elements (words or images) have the same size, we can represent mini-batches by a LoD Tensor: + +- The underlying tensor has size LxD1xD2x..., where D1xD2... is the size of the essential elements, and +- the first dimension size L has an additon property -- a LoD index as a nested vector: + + ```c++ + typedef std::vector > LoD; + ``` + +- The LoD index can is not necessary when there are only two levels and all elements of the second level have length 1. + +## Slicing of LoD Tensor + +Consider that we have a network with three levels of RNN: the top level one handles articles, the second level one handles sentences, and the basic level one handles words. This network requires that mini-batches represented by 4 level LoD Tensor, for example, + +``` + 3 +3 1 2 +3 2 4 1 2 3 +||| || |||| | || ||| +``` + +To allow each level of RNN to handle its input, we define **the slicing of a LoD Tensor is defined as getting the j-th sequence on level i, or the -slice** + +For example, the <2,1>-slice of above slice is + +``` +2 +|| +``` + +and the <1,2>-slice of above example is + +``` +2 +2 3 +|| ||| +``` + +Let's go on slicing this slice. Its <1,1>-slice is + +``` +3 +||| +``` + +### The General Slicing Algorithm + +The algorithm, with over-simplified data structure, is defined as + +```c++ +typedef vector > LoD; + +struct LoDTensor { + LoD lod_; + float* tensor_; +}; + +LoDTensor Slice(const LoDTensor& lodt, int level, int sequence) { + +} +``` + +### Slicing the Top Level + +Please be aware that an RNN operator only slices the top level of a LoD Tensor to get the step inputs. + +```c++ +LoDTensor Slice(const LoDTensor& lodt, int sequence) { + +} +``` diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 2881136ced6ef957a192e303e529b9b2867b3dda..9a351605edb5013bdab2c6193bdd9ce401acc937 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -24,13 +24,12 @@ namespace framework { class LODTensorTester : public ::testing::Test { public: virtual void SetUp() override { - lod_tensor.reset(new LODTensor); // tensor's batch_size: 30 // 3 levels // 0 10 20 // 0 5 10 15 20 // 0 2 5 7 10 12 15 20 - LODTensor::LOD lod; + LOD lod; lod.push_back(std::vector{0, 10, 20}); lod.push_back(std::vector{0, 5, 10, 15, 20}); lod.push_back(std::vector{0, 2, 5, 7, 10, 12, 15, 17, 20}); @@ -41,75 +40,65 @@ class LODTensorTester : public ::testing::Test { // malloc memory tensor.mutable_data(place); - lod_tensor.reset(new LODTensor(lod)); - lod_tensor->Resize({20 /*batch size*/, 128 /*dim*/}); - - lod_tensor->ShareDataWith(tensor); - // lod_tensor->ShareDataWith(tensor); + lod_tensor.set_lod(lod); + lod_tensor.set_tensor(&tensor); } protected: - std::unique_ptr lod_tensor; platform::CPUPlace place; Tensor tensor; + LODTensor lod_tensor; }; -TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor->NumLevels(), 3UL); } +TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor.NumLevels(), 3UL); } TEST_F(LODTensorTester, NumElements) { - ASSERT_EQ(lod_tensor->NumElements(0), 2UL); - ASSERT_EQ(lod_tensor->NumElements(1), 4UL); - ASSERT_EQ(lod_tensor->NumElements(2), 8UL); + ASSERT_EQ(lod_tensor.NumElements(0), 2UL); + ASSERT_EQ(lod_tensor.NumElements(1), 4UL); + ASSERT_EQ(lod_tensor.NumElements(2), 8UL); } TEST_F(LODTensorTester, SliceLevels) { // slice 1 level for (size_t level = 0; level < 3UL; ++level) { - auto new_lod_tensor = lod_tensor->SliceLevels(level, level + 1); + LODTensor new_lod_tensor = lod_tensor; + new_lod_tensor.SliceLevels(level, level + 1); ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL); - ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level)); - // ASSERT_EQ(new_lod_tensor, *lod_tensor); + ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } // slice 2 level for (size_t level = 0; level < 2UL; ++level) { - auto new_lod_tensor = lod_tensor->SliceLevels(level, level + 2); + LODTensor new_lod_tensor = lod_tensor; + new_lod_tensor.SliceLevels(level, level + 2); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); - ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level)); - ASSERT_EQ(new_lod_tensor.NumElements(1), - lod_tensor->NumElements(level + 1)); - ASSERT_EQ(new_lod_tensor.data(), lod_tensor->data()); + ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); + ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1)); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } } TEST_F(LODTensorTester, SliceInLevel) { size_t level = 0; - auto new_lod_tensor = lod_tensor->SliceInLevel(level, 0, 2); + LODTensor new_lod_tensor = lod_tensor; + new_lod_tensor.SliceInLevel(level, 0, 2); EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL); EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL); EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL); EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL); - ASSERT_EQ(new_lod_tensor.data(), lod_tensor->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); level = 1; - new_lod_tensor = lod_tensor->SliceInLevel(level, 0, 2); + new_lod_tensor = lod_tensor; + new_lod_tensor.SliceInLevel(level, 0, 2); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL); - ASSERT_EQ(new_lod_tensor.data(), lod_tensor->data()); -} - -TEST_F(LODTensorTester, ShareLOD) { - LODTensor new_lod_tensor; - new_lod_tensor.CopyLOD(*lod_tensor); - ASSERT_EQ(new_lod_tensor.lod(), lod_tensor->lod()); -} - -TEST_F(LODTensorTester, CopyLOD) { - LODTensor new_lod_tensor; - new_lod_tensor.CopyLOD(*lod_tensor); - bool equals = std::equal(lod_tensor->lod().begin(), lod_tensor->lod().end(), - new_lod_tensor.lod().begin()); - ASSERT_TRUE(equals); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } } // namespace framework diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 94245c6c44aca962b0db890947a9dc5550ac0799..b98d8f23a14cf6fbe787953ad16b5c9ab99222ad 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -80,9 +80,19 @@ class OpInfoMap { } const OpInfo& Get(const std::string& type) const { + auto op_info_ptr = GetNullable(type); + PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered", + type); + return *op_info_ptr; + } + + const OpInfo* GetNullable(const std::string& type) const { auto it = map_.find(type); - PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); - return it->second; + if (it == map_.end()) { + return nullptr; + } else { + return &it->second; + } } template diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 2d09cde41e3f5086279f9441e0fdc52549bed5ab..64c7f23ab6b79bad9533f566ca39db3cfd5ac5c5 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -33,8 +33,7 @@ namespace framework { class OpRegistry { public: template - static void RegisterOp(const std::string& op_type, - const std::string& grad_op_type) { + static void RegisterOp(const std::string& op_type) { PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); OpInfo op_info; @@ -43,9 +42,9 @@ class OpRegistry { const VariableNameMap& outputs, const AttributeMap& attrs) { return new OpType(type, inputs, outputs, attrs); }; - op_info.grad_op_type_ = grad_op_type; if (std::type_index(typeid(ProtoMakerType)) != std::type_index(typeid(NOPMaker))) { + op_info.grad_op_type_ = op_type + "_grad"; op_info.proto_ = new OpProto; op_info.checker_ = new OpAttrChecker; auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); @@ -55,15 +54,14 @@ class OpRegistry { op_info.proto_->IsInitialized(), "Fail to initialize %s's OpProto, because %s is not initialized", op_type, op_info.proto_->InitializationErrorString()); + // register gradient op + RegisterOp(op_info.grad_op_type_); } else { + op_info.grad_op_type_ = ""; op_info.proto_ = nullptr; op_info.checker_ = nullptr; } OpInfoMap::Instance().Insert(op_type, op_info); - // register gradient op - if (!grad_op_type.empty()) { - RegisterOp(grad_op_type, ""); - } } static std::unique_ptr CreateOp(const std::string& type, @@ -92,10 +90,8 @@ class Registrar { template class OpRegistrar : public Registrar { public: - explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } - OpRegistrar(const char* op_type, const char* grad_op_type) { - OpRegistry::RegisterOp(op_type, - grad_op_type); + explicit OpRegistrar(const char* op_type) { + OpRegistry::RegisterOp(op_type); } }; @@ -121,8 +117,7 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register Operator. */ -#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ - grad_op_class) \ +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ class _OpClass_##op_type##_ : public op_class { \ @@ -137,14 +132,14 @@ class OpKernelRegistrar : public Registrar { }; \ static ::paddle::framework::OpRegistrar< \ _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ - __op_registrar_##op_type##__(#op_type, #grad_op_type); \ + __op_registrar_##op_type##__(#op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ return 0; \ } #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ - REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) + REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP) /** * Macro to register OperatorKernel. diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 7abbde610f1e9c530393b9a9cabe40b826712212..790cfc4746b1d34da413fa3c29a266f962c6dde6 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -33,12 +33,12 @@ ExecutionContext::GetEigenDevice() const { } #endif -const std::string& OperatorBase::Input(const std::string& name) const { +std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); - PADDLE_ENFORCE_EQ(ins.size(), 1UL, + PADDLE_ENFORCE_LE(ins.size(), 1UL, "Op %s input %s should contain only one variable", type_, name); - return ins[0]; + return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( @@ -49,12 +49,12 @@ const std::vector& OperatorBase::Inputs( return it->second; } -const std::string& OperatorBase::Output(const std::string& name) const { +std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); - PADDLE_ENFORCE_EQ(outs.size(), 1UL, + PADDLE_ENFORCE_LE(outs.size(), 1UL, "Op %s output %s should contain only one variable", type_, name); - return outs[0]; + return outs.empty() ? kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( @@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type, const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { - static std::atomic gUniqId(0UL); - for (auto& output : outputs_) { - for (auto& output_name : output.second) { - if (output_name == kTempVarName) { - output_name += type_; - output_name += "@"; - output_name += std::to_string(gUniqId.fetch_add(1)); - } - } - } + GenerateTemporaryNames(); + CheckAllInputOutputSet(); } std::vector OperatorBase::OutputVars(bool has_intermediate) const { @@ -156,6 +148,35 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { return ret_val; } +void OperatorBase::CheckAllInputOutputSet() const { + auto& info_map = OpInfoMap::Instance(); + auto* op_info = info_map.GetNullable(Type()); + if (op_info == nullptr || op_info->proto_ == nullptr) return; + + for (auto& in : op_info->Proto().inputs()) { + PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), + "Type %s's input %s is not set", Type(), in.name()); + } + + for (auto& out : op_info->Proto().outputs()) { + PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), + "Type %s's output %s is not set", Type(), out.name()); + } +} + +void OperatorBase::GenerateTemporaryNames() { + static std::atomic gUniqId(0UL); + for (auto& output : outputs_) { + for (auto& output_name : output.second) { + if (output_name == kTempVarName) { + output_name += type_; + output_name += "@"; + output_name += std::to_string(gUniqId.fetch_add(1)); + } + } + } +} + void OpProtoAndCheckerMaker::Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 8397570d26f06f0238e9c5afc85d721df7679257..590e335fdc8843ed9edd01a09605163de93f52d9 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -95,12 +95,12 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` - const std::string& Input(const std::string& name) const; + std::string Input(const std::string& name) const; //! Get a input which has multiple variables. const std::vector& Inputs(const std::string& name) const; //! Get a output with argument's name described in `op_proto` - const std::string& Output(const std::string& name) const; + std::string Output(const std::string& name) const; //! Get an output which has multiple variables. //! TODO add a vector_view to prevent memory copy. const std::vector& Outputs(const std::string& name) const; @@ -127,6 +127,10 @@ class OperatorBase { // IG (Inputs Gradients) VariableNameMap outputs_; AttributeMap attrs_; + + private: + void GenerateTemporaryNames(); + void CheckAllInputOutputSet() const; }; // Macro for define a clone method. @@ -238,11 +242,13 @@ class InferShapeContext { } const Variable* InputVar(const std::string& name) const { - return scope_.FindVar(op_.Input(name)); + auto ipt = op_.Input(name); + return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); } Variable* OutputVar(const std::string& name) const { - return scope_.FindVar(op_.Output(name)); + auto opt = op_.Output(name); + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); } const std::vector MultiInputVar( @@ -250,9 +256,11 @@ class InferShapeContext { auto names = op_.Inputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); + }); return res; } @@ -260,24 +268,24 @@ class InferShapeContext { auto names = op_.Outputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); + }); return res; } template const T* Input(const std::string& name) const { auto* var = InputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); } template T* Output(const std::string& name) const { auto var = OutputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); } template @@ -288,10 +296,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiInput(%s:%s) should not be nullptr", name, - sub_name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); }); return res; } @@ -304,10 +309,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiOutput(%s:%s) should not be nullptr.", name, - sub_name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); }); return res; } diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 7d7263b899afb7a2128548f264065a8013b6f0c9..7893e233b776425a61d9e3edd43d944a27743188 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -117,6 +117,8 @@ inline void Tensor::CopyFrom(const Tensor& src, memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size, 0); } + PADDLE_ENFORCE(cudaStreamSynchronize(0), + "cudaStreamSynchronize failed in Tensor CopyFrom"); #endif } diff --git a/paddle/gserver/layers/CrossEntropyOverBeam.cpp b/paddle/gserver/layers/CrossEntropyOverBeam.cpp index 4acc077035b17fdf5ec06e0d4d916fa0a62f6cba..578bdbbe72120abccc63ed13d11e1dec65d41e44 100644 --- a/paddle/gserver/layers/CrossEntropyOverBeam.cpp +++ b/paddle/gserver/layers/CrossEntropyOverBeam.cpp @@ -223,7 +223,7 @@ void CrossEntropyOverBeam::checkInputs() { << inputLayers_[i * 3]->getName() << " should be a nested sequence"; CHECK_EQ(getInputValue(i * 3 + 1)->getWidth(), beamSize_); - CHECK_EQ(scores.getNumSequences(), batchSize_); + CHECK_EQ(batchSize_, static_cast(scores.getNumSequences())); CHECK_EQ(scores.getNumSubSequences(), selCandidates.getBatchSize()); } else { CHECK(scores.hasSeq()) << "input " << i << " " @@ -231,10 +231,10 @@ void CrossEntropyOverBeam::checkInputs() { << " should be a sequence"; batchSize_ = scores.getNumSequences(); beamSize_ = getInputValue(i * 3 + 1)->getWidth(); - CHECK_EQ(batchSize_, selCandidates.getBatchSize()); + CHECK_EQ(batchSize_, static_cast(selCandidates.getBatchSize())); } CHECK_EQ(1U, scores.value->getWidth()); - CHECK_EQ(batchSize_, goldSeq.getBatchSize()); + CHECK_EQ(batchSize_, static_cast(goldSeq.getBatchSize())); } } @@ -377,8 +377,8 @@ void CrossEntropyOverBeam::forward(PassType passType) { MatrixPtr outputValue = getOutputValue(); for (size_t i = 0; i < batchSize_; ++i) { - beamCosts_[i].setData( - std::move(std::make_shared(beamPerSeq_[i])), beamSize_); + BeamExpansionPtr ptr = std::make_shared(beamPerSeq_[i]); + beamCosts_[i].setData(std::move(ptr), beamSize_); outputValue->getData()[i] = beamCosts_[i].forward(); } } diff --git a/paddle/gserver/layers/Pool3DLayer.cpp b/paddle/gserver/layers/Pool3DLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..199f21adb1a5923b590e4f0e716fc67effb2a2d1 --- /dev/null +++ b/paddle/gserver/layers/Pool3DLayer.cpp @@ -0,0 +1,178 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Pool3DLayer.h" +#include "PoolProjectionLayer.h" +#include "paddle/utils/Logging.h" + +namespace paddle { + +REGISTER_LAYER(pool3d, Pool3DLayer); + +bool Pool3DLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + /* the size of inputs for pool-layer is 1 */ + CHECK_EQ(config_.inputs_size(), 1); + + const PoolConfig& conf = config_.inputs(0).pool_conf(); + poolType_ = conf.pool_type(); + channels_ = conf.channels(); + + sizeX_ = conf.size_x(); + sizeY_ = conf.size_y(); + sizeZ_ = conf.size_z(); + + strideW_ = conf.stride(); + strideH_ = conf.stride_y(); + strideD_ = conf.stride_z(); + + imgSizeW_ = conf.img_size(); + imgSizeH_ = conf.img_size_y(); + imgSizeD_ = conf.img_size_z(); + + paddingW_ = conf.padding(); + paddingH_ = conf.padding_y(); + paddingD_ = conf.padding_z(); + + outputW_ = conf.output_x(); + outputH_ = conf.output_y(); + outputD_ = conf.output_z(); + + return true; +} + +size_t Pool3DLayer::getSize() { + CHECK_EQ(inputLayers_.size(), 1UL); + + size_t layerSize = 0; + outputD_ = outputSize(imgSizeD_, sizeZ_, paddingD_, strideD_, false); + outputH_ = outputSize(imgSizeH_, sizeY_, paddingH_, strideH_, false); + outputW_ = outputSize(imgSizeW_, sizeX_, paddingW_, strideW_, false); + + layerSize = outputD_ * outputH_ * outputW_ * channels_; + getOutput().setFrameHeight(outputH_); + getOutput().setFrameWidth(outputW_); + getOutput().setFrameDepth(outputD_); + return layerSize; +} + +void Pool3DLayer::forward(PassType passType) { + Layer::forward(passType); + const MatrixPtr& inMat = inputLayers_[0]->getOutputValue(); + size_t batchSize = inMat->getHeight(); + size_t outWidth = getSize(); + resetOutput(batchSize, outWidth); + Matrix::resizeOrCreate(maxPoolIdx_, batchSize, outWidth, false, useGpu_); + const MatrixPtr outMat = getOutputValue(); + + if (poolType_ == "avg") { + outMat->avgPool3DForward(*inMat, + channels_, + imgSizeD_, + imgSizeH_, + imgSizeW_, + outputD_, + outputH_, + outputW_, + sizeZ_, + sizeY_, + sizeX_, + strideD_, + strideH_, + strideW_, + paddingD_, + paddingH_, + paddingW_); + } else if (poolType_ == "max") { + outMat->maxPool3DForward(*inMat, + *maxPoolIdx_, + channels_, + imgSizeD_, + imgSizeH_, + imgSizeW_, + outputD_, + outputH_, + outputW_, + sizeZ_, + sizeY_, + sizeX_, + strideD_, + strideH_, + strideW_, + paddingD_, + paddingH_, + paddingW_); + } else { + LOG(FATAL) << "Unknown pool type: " << poolType_; + } + forwardActivation(); +} + +void Pool3DLayer::backward(const UpdateCallback& callback) { + backwardActivation(); + + (void)callback; + if (NULL == getInputGrad(0)) return; + MatrixPtr inMat = inputLayers_[0]->getOutputValue(); + MatrixPtr inGradMat = inputLayers_[0]->getOutputGrad(); + MatrixPtr outMat = getOutputValue(); + MatrixPtr outGradMat = getOutputGrad(); + + if (poolType_ == "avg") { + inGradMat->avgPool3DBackward(*outGradMat, + imgSizeD_, + imgSizeH_, + imgSizeW_, + outputD_, + outputH_, + outputW_, + sizeZ_, + sizeY_, + sizeZ_, + strideD_, + strideH_, + strideW_, + paddingD_, + paddingH_, + paddingW_, + 1.0, + 1.0); + } else if (poolType_ == "max") { + inGradMat->maxPool3DBackward(*outGradMat, + *maxPoolIdx_, + imgSizeD_, + imgSizeH_, + imgSizeW_, + outputD_, + outputH_, + outputW_, + sizeZ_, + sizeY_, + sizeZ_, + strideD_, + strideH_, + strideW_, + paddingD_, + paddingH_, + paddingW_, + 1.0, + 1.0); + } else { + LOG(FATAL) << "Unknown pool type: " << poolType_; + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/Pool3DLayer.h b/paddle/gserver/layers/Pool3DLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..8329a02f571bf3b5422134c756c248f77fd517b1 --- /dev/null +++ b/paddle/gserver/layers/Pool3DLayer.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "Layer.h" +#include "paddle/math/MathUtils.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * @brief Basic parent layer of pooling + * Pools the input within regions + */ +class Pool3DLayer : public Layer { +public: + explicit Pool3DLayer(const LayerConfig& config) : Layer(config) {} + ~Pool3DLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback) override; + size_t getSize(); + +protected: + int channels_; + int sizeX_, sizeY_, sizeZ_; + int strideW_, strideH_, strideD_; + int paddingW_, paddingH_, paddingD_; + int imgSizeW_, imgSizeH_, imgSizeD_; + int outputW_, outputH_, outputD_; + std::string poolType_; + MatrixPtr maxPoolIdx_; +}; +} // namespace paddle diff --git a/paddle/gserver/layers/PrintLayer.cpp b/paddle/gserver/layers/PrintLayer.cpp index 0a1e17b9aa57b373f0df6e079341729539f4e193..e83ae34bbe7d31b9bb7c16bc3fa84db7bd4e33d2 100644 --- a/paddle/gserver/layers/PrintLayer.cpp +++ b/paddle/gserver/layers/PrintLayer.cpp @@ -48,7 +48,16 @@ public: << inputLayers_.size() << ") at " << getName(); } s << format.substr(pos); - LOG(INFO) << s.str(); + + const std::string delimiter("\n"); + std::string content = s.str(); + std::string::size_type foundPos = 0; + std::string::size_type prevPos = 0; + while ((foundPos = content.find(delimiter, prevPos)) != std::string::npos) { + LOG(INFO) << content.substr(prevPos, foundPos - prevPos); + prevPos = foundPos + delimiter.size(); + } + LOG(INFO) << content.substr(prevPos); } void backward(const UpdateCallback& callback) override {} diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index c54f3ef965f4bed730cf5e0e82130b0416cb34c7..a831ffbc73fbd6ad42fa31b2d6d583718474e59b 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1246,6 +1246,75 @@ TEST(Layer, PoolLayer) { #endif } +void setPool3DConfig(TestConfig* config, + PoolConfig* pool, + const string& poolType) { + // filter size + const int NUM_FILTERS = 16; + const int FILTER_SIZE = 3; + const int FILTER_SIZE_Y = 3; + const int FILTER_SIZE_Z = 3; + const int CHANNELS = 16; + + (*config).biasSize = 0; + (*config).layerConfig.set_type("pool3d"); + (*config).layerConfig.set_num_filters(NUM_FILTERS); + + int kw = FILTER_SIZE, kh = FILTER_SIZE_Y, kd = FILTER_SIZE_Z; + int pw = 0, ph = 0, pd = 0; + int sw = 2, sh = 2, sd = 2; + + pool->set_pool_type(poolType); + pool->set_pool_type("avg"); + pool->set_channels(CHANNELS); + pool->set_size_x(kw); + pool->set_size_y(kh); + pool->set_size_z(kd); + pool->set_padding(0); + pool->set_padding_y(0); + pool->set_padding_z(0); + pool->set_stride(sw); + pool->set_stride_y(sh); + pool->set_stride_z(sd); + pool->set_start(0); + int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false); + int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false); + int od = outputSize(pool->img_size_z(), kd, pd, sd, /* caffeMode */ false); + pool->set_output_x(ow); + pool->set_output_y(oh); + pool->set_output_z(od); +} + +void testPool3DLayer(const string& poolType, bool trans, bool useGpu) { + TestConfig config; + config.inputDefs.push_back({INPUT_DATA, "layer_0", 11664, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + PoolConfig* pool = input->mutable_pool_conf(); + + const int IMAGE_SIZE = 9; + const int IMAGE_SIZE_Y = 9; + const int IMAGE_SIZE_Z = 9; + + pool->set_img_size(IMAGE_SIZE); + pool->set_img_size_y(IMAGE_SIZE_Y); + pool->set_img_size_z(IMAGE_SIZE_Z); + + setPool3DConfig(&config, pool, poolType); + config.layerConfig.set_size(pool->output_x() * pool->output_y() * + pool->channels()); + + testLayerGrad(config, "pool3d", 100, trans, useGpu); +} + +TEST(Layer, Pool3DLayer) { + testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ false); + testPool3DLayer("max", /* trans= */ false, /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ true); + testPool3DLayer("max", /* trans= */ false, /* useGpu= */ true); +#endif +} + void testSppLayer(const string& poolType, const int pyramidHeight, bool trans, diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 579a0f3cf32b803c1d3ac2af57517ad6490f31ef..8bc42571f7c141aa31e18d0504b95b2ed4f0da77 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1190,6 +1190,221 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, outGrad.getStride()); } +void GpuMatrix::maxPool3DForward(Matrix& inputMat, + Matrix& maxPoolIdx, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW) { + CHECK(inputMat.useGpu_) << "Matrix type are not correct"; + + real* inputData = inputMat.getData(); + real* maxPoolIdxData = maxPoolIdx.getData(); + size_t num = inputMat.getHeight(); + size_t width = imgSizeW; + size_t height = imgSizeH; + size_t depth = imgSizeD; + CHECK(depth * height * width * channels == inputMat.getWidth()); + CHECK(height_ == inputMat.getHeight()); + CHECK(width_ == outputD * outputH * outputW * channels); + + hl_maxpool3D_forward(num, + inputData, + channels, + depth, + height, + width, + outputD, + outputH, + outputW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + getData(), + maxPoolIdxData, + getStride()); +} + +void GpuMatrix::maxPool3DBackward(Matrix& outGrad, + Matrix& maxPoolIdx, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput) { + CHECK(outGrad.useGpu_ && maxPoolIdx.useGpu_) << "Matrix type are not equal"; + + real* outDiff = outGrad.getData(); + real* maxPoolIdxData = maxPoolIdx.getData(); + size_t frameNum = getHeight(); + size_t channels = outGrad.getWidth() / outputD / outputH / outputW; + size_t width = imgSizeW; + size_t height = imgSizeH; + size_t depth = imgSizeD; + CHECK(depth * height * width * channels == getWidth()); + CHECK(width_ == depth * width * height * channels); + CHECK(outGrad.getHeight() == maxPoolIdx.getHeight() && + outGrad.getWidth() == maxPoolIdx.getWidth()); + + hl_maxpool3D_backward(frameNum, + outDiff, + channels, + depth, + height, + width, + outputD, + outputH, + outputW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + scaleTargets, + scaleOutput, + getData(), + maxPoolIdxData, + outGrad.getStride()); +} + +void GpuMatrix::avgPool3DForward(Matrix& inputMat, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW) { + CHECK(inputMat.useGpu_) << "Matrix type are not equal"; + + real* inputData = inputMat.getData(); + size_t frameNum = inputMat.getHeight(); + size_t height = imgSizeH; + size_t width = imgSizeW; + size_t depth = imgSizeD; + CHECK(depth * height * width * channels == inputMat.getWidth()); + CHECK(height_ == inputMat.getHeight()); + CHECK(width_ == outputD * outputH * outputW * channels); + + hl_avgpool3D_forward(frameNum, + inputData, + channels, + depth, + height, + width, + outputD, + outputH, + outputW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + getData(), + getStride()); +} + +void GpuMatrix::avgPool3DBackward(Matrix& outGrad, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput) { + CHECK(outGrad.useGpu_) << "Matrix type are not equal"; + + real* outDiff = outGrad.getData(); + size_t frameNum = outGrad.getHeight(); + size_t channels = outGrad.getWidth() / outputD / outputH / outputW; + size_t height = imgSizeH; + size_t width = imgSizeW; + size_t depth = imgSizeD; + CHECK(depth * height * width * channels == width_); + CHECK(height_ == outGrad.getHeight()); + CHECK(outGrad.getWidth() == outputD * outputH * outputW * channels); + + hl_avgpool3D_backward(frameNum, + outDiff, + channels, + depth, + height, + width, + outputD, + outputH, + outputW, + sizeZ, + sizeY, + sizeX, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + scaleTargets, + scaleOutput, + getData(), + outGrad.getStride()); +} + void GpuMatrix::maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { @@ -1996,6 +2211,276 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } } +void CpuMatrix::maxPool3DForward(Matrix& inputMat, + Matrix& maxPoolIdx, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW) { + real* inputData = inputMat.getData(); + real* outData = getData(); + real* maxPoolIdxData = maxPoolIdx.getData(); + size_t num = inputMat.getHeight(); + size_t inWidth = imgSizeW; + size_t inHeight = imgSizeH; + size_t inDepth = imgSizeD; + CHECK(inHeight * inWidth * inDepth == inputMat.getWidth() / channels); + CHECK_EQ(num, this->getHeight()); + CHECK_EQ(channels * outputH * outputW * outputD, this->getWidth()); + size_t outStride = getStride(); + + /* initialize the data_ */ + for (size_t i = 0; i < height_; i++) { + for (size_t j = 0; j < width_; j++) { + outData[(i)*outStride + j] = -(real)FLT_MAX; + maxPoolIdxData[(i)*outStride + j] = -1; + } + } + + /* pool max one by one */ + for (size_t n = 0; n < num; ++n) { // frame by frame + if (!isContiguous()) { + outData = getData() + n * outStride; + maxPoolIdxData = maxPoolIdx.getData() + n * outStride; + } + for (size_t c = 0; c < channels; ++c) { // channel by channel + for (size_t pd = 0; pd < outputD; ++pd) { + for (size_t ph = 0; ph < outputH; ++ph) { + for (size_t pw = 0; pw < outputW; ++pw) { + int dstart = pd * strideD - paddingD; + int hstart = ph * strideH - paddingH; + int wstart = pw * strideW - paddingW; + int dend = std::min(dstart + sizeZ, inDepth); + int hend = std::min(hstart + sizeY, inHeight); + int wend = std::min(wstart + sizeX, inWidth); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + int maxIdx = -1; + real maxOutData = outData[(pd * outputH + ph) * outputW + pw]; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (maxOutData < + inputData[(d * inHeight + h) * inWidth + w]) { + maxOutData = inputData[(d * inHeight + h) * inWidth + w]; + maxIdx = (d * inHeight + h) * inWidth + w; + } + } + } + } + outData[(pd * outputH + ph) * outputW + pw] = maxOutData; + maxPoolIdxData[(pd * outputH + ph) * outputW + pw] = maxIdx; + } + } + } + // compute offset + inputData += inDepth * inHeight * inWidth; + outData += outputD * outputH * outputW; + maxPoolIdxData += outputD * outputH * outputW; + } + } +} + +void CpuMatrix::maxPool3DBackward(Matrix& outGrad, + Matrix& maxPoolIdx, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput) { + size_t num = getHeight(); + size_t channels = size_t(width_ / imgSizeD / imgSizeH / imgSizeW); + CHECK(maxPoolIdx.getHeight() == outGrad.getHeight() && + maxPoolIdx.getWidth() == outGrad.getWidth()); + + real* tgtGrad = getData(); + real* otGrad = outGrad.getData(); + real* maxPoolIdxData = maxPoolIdx.getData(); + size_t outStride = outGrad.getStride(); + + for (size_t n = 0; n < num; ++n) { + if (!outGrad.isContiguous()) { + otGrad = outGrad.getData() + n * outStride; + maxPoolIdxData = maxPoolIdx.getData() + n * outStride; + } + for (size_t c = 0; c < channels; ++c) { + for (size_t pd = 0; pd < outputD; ++pd) { + for (size_t ph = 0; ph < outputH; ++ph) { + for (size_t pw = 0; pw < outputW; ++pw) { + const size_t index = (pd * outputH + ph) * outputW + pw; + const size_t tgtIdx = static_cast(maxPoolIdxData[index]); + tgtGrad[tgtIdx] = + scaleTargets * tgtGrad[tgtIdx] + scaleOutput * otGrad[index]; + } + } + } + // offset + tgtGrad += imgSizeD * imgSizeH * imgSizeW; + otGrad += outputD * outputH * outputW; + maxPoolIdxData += outputD * outputH * outputW; + } + } +} + +void CpuMatrix::avgPool3DForward(Matrix& input, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW) { + // The main loop + size_t num = input.getHeight(); + size_t inDepth = imgSizeD; + size_t inHeight = imgSizeH; + size_t inWidth = imgSizeW; + CHECK(inDepth * inHeight * inWidth * channels == input.getWidth()); + CHECK(outputD * outputH * outputW * channels * num == height_ * width_); + real* tgtData = getData(); + real* inData = input.getData(); + + for (size_t n = 0; n < num; ++n) { + if (!isContiguous()) { + tgtData = data_ + n * getStride(); + } + for (size_t c = 0; c < channels; ++c) { + for (size_t pd = 0; pd < outputD; ++pd) { + for (size_t ph = 0; ph < outputH; ++ph) { + for (size_t pw = 0; pw < outputW; ++pw) { + int dstart = pd * strideD - paddingD; + int hstart = ph * strideH - paddingH; + int wstart = pw * strideW - paddingW; + int dend = std::min(dstart + sizeZ, inDepth + paddingD); + int hend = std::min(hstart + sizeY, inHeight + paddingH); + int wend = std::min(wstart + sizeX, inWidth + paddingW); + int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + dend = std::min(dend, static_cast(inDepth)); + hend = std::min(hend, static_cast(inHeight)); + wend = std::min(wend, static_cast(inWidth)); + + CHECK(poolSize); + tgtData[(pd * outputH + ph) * outputW + pw] = 0; // clear + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + tgtData[(pd * outputH + ph) * outputW + pw] += + inData[(d * inHeight + h) * inWidth + w]; + } + } + } + tgtData[(pd * outputH + ph) * outputW + pw] /= poolSize; + } + } + } + // compute offset + inData += inDepth * inHeight * inWidth; + tgtData += outputD * outputH * outputW; + } + } +} + +void CpuMatrix::avgPool3DBackward(Matrix& input, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput) { + size_t num = input.getHeight(); + size_t channels = input.getWidth() / outputD / outputH / outputW; + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == getWidth()); + real* inData = input.getData(); + real* outData = getData(); + + for (size_t n = 0; n < num; ++n) { + if (!input.isContiguous()) { + inData = input.getData() + n * input.getStride(); + } + for (size_t c = 0; c < channels; ++c) { + for (size_t pd = 0; pd < outputD; ++pd) { + for (size_t ph = 0; ph < outputH; ++ph) { + for (size_t pw = 0; pw < outputW; ++pw) { + int dstart = pd * strideD - paddingD; + int hstart = ph * strideH - paddingH; + int wstart = pw * strideW - paddingW; + int dend = std::min(dstart + sizeZ, imgSizeD + paddingD); + int hend = std::min(hstart + sizeY, imgSizeH + paddingH); + int wend = std::min(wstart + sizeX, imgSizeW + paddingW); + int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + dend = std::min(dend, static_cast(imgSizeD)); + hend = std::min(hend, static_cast(imgSizeH)); + wend = std::min(wend, static_cast(imgSizeW)); + CHECK(poolSize); + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + outData[(d * imgSizeH + h) * imgSizeW + w] += + inData[(pd * outputH + ph) * outputW + pw] / poolSize; + } + } + } + } + } + } + // offset + outData += imgSizeD * imgSizeH * imgSizeW; + inData += outputD * outputH * outputW; + } + } +} + /** * Input: one or more sequences. Each sequence contains some instances. * Output: output size is the number of input sequences (NOT input instances). diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index cc3a56f279cc6a17104c681a51f1ca907143fc44..431d4e071072317c8fdfdc4f0d13e7cd4e3d062b 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -928,15 +928,102 @@ public: size_t paddingW) { LOG(FATAL) << "Not implemeted"; } - /** - * Input: one or more sequences. Each sequence contains some instances. - * - * Output: output size is the number of input sequences (NOT input - * instances). - * - * output[i] is set to max_input[i]. + * Pooling 3D forward operation, pick out the largest element + * in the sizeX of value */ + virtual void maxPool3DForward(Matrix& inputMat, + Matrix& maxPoolIdx, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW) { + LOG(FATAL) << "Not implemeted"; + } + + virtual void maxPool3DBackward(Matrix& outGrad, + Matrix& maxPoolIdx, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput) { + LOG(FATAL) << "Not implemeted"; + } + + virtual void avgPool3DForward(Matrix& input, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW) { + LOG(FATAL) << "Not implemeted"; + } + + virtual void avgPool3DBackward(Matrix& input, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput) { + LOG(FATAL) << "Not implemeted"; + } + + /** + * Input: one or more sequences. Each sequence contains some instances. + * + * Output: output size is the number of input sequences (NOT input + * instances). + * + * output[i] is set to max_input[i]. + */ virtual void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { @@ -1384,6 +1471,82 @@ public: size_t paddingH, size_t paddingW); + void maxPool3DForward(Matrix& inputMat, + Matrix& maxPoolIdx, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW); + + void maxPool3DBackward(Matrix& outGrad, + Matrix& maxPoolIdx, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput); + + void avgPool3DForward(Matrix& input, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW); + + void avgPool3DBackward(Matrix& input, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput); + void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); @@ -1575,6 +1738,82 @@ public: size_t paddingH, size_t paddingW); + void maxPool3DForward(Matrix& inputMat, + Matrix& maxPoolIdx, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW); + + void maxPool3DBackward(Matrix& outGrad, + Matrix& maxPoolIdx, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput); + + void avgPool3DForward(Matrix& input, + size_t channels, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW); + + void avgPool3DBackward(Matrix& input, + size_t imgSizeD, + size_t imgSizeH, + size_t imgSizeW, + size_t outputD, + size_t outputH, + size_t outputW, + size_t sizeZ, + size_t sizeY, + size_t sizeX, + size_t strideD, + size_t strideH, + size_t strideW, + size_t paddingD, + size_t paddingH, + size_t paddingW, + real scaleTargets, + real scaleOutput); + void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 3abe4484dbc86711c5798b0900a47d09a0d47299..103f06acc57d7a23f019f5e713f6cacf2179e9e0 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1204,6 +1204,399 @@ TEST(Matrix, warpCTC) { } } +void testMaxPool3DFwdBwd(int numSamples, + int channels, + int imgSizeD, + int imgSizeH, + int imgSizeW, + int ksizeD, + int ksizeH, + int ksizeW, + int strideD, + int strideH, + int strideW, + int padD, + int padH, + int padW) { + int outD = outputSize(imgSizeD, ksizeD, padD, strideD, true); + int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true); + int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true); + + int inWidth = channels * imgSizeD * imgSizeH * imgSizeW; + MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true); + + int outWidth = channels * outD * outH * outW; + MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true); + MatrixPtr maxIdx = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr maxIdxGpu = GpuMatrix::create(numSamples, outWidth, false, true); + + input->randomizeUniform(); + target->randomizeUniform(); + inputGpu->copyFrom(*input); + targetGpu->copyFrom(*target); + + target->maxPool3DForward(*input, + *maxIdx, + channels, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW); + targetGpu->maxPool3DForward(*inputGpu, + *maxIdxGpu, + channels, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW); + MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false); + targetCheck->copyFrom(*targetGpu); + checkMatrixEqual(target, targetCheck); + + MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true); + MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpuGrad = + GpuMatrix::create(numSamples, outWidth, false, true); + + inputGrad->randomizeUniform(); + targetGrad->randomizeUniform(); + inputGpuGrad->copyFrom(*inputGrad); + targetGpuGrad->copyFrom(*targetGrad); + + inputGrad->maxPool3DBackward(*targetGrad, + *maxIdx, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW, + 1.0, + 1.0); + inputGpuGrad->maxPool3DBackward(*targetGpuGrad, + *maxIdxGpu, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW, + 1.0, + 1.0); + MatrixPtr targetBwdCheck = + CpuMatrix::create(numSamples, inWidth, false, false); + targetBwdCheck->copyFrom(*inputGpuGrad); + checkMatrixEqual(inputGrad, targetBwdCheck); +} + +void testAvgPool3DFwdBwd(int numSamples, + int channels, + int imgSizeD, + int imgSizeH, + int imgSizeW, + int ksizeD, + int ksizeH, + int ksizeW, + int strideD, + int strideH, + int strideW, + int padD, + int padH, + int padW) { + int outD = outputSize(imgSizeD, ksizeD, padD, strideD, true); + int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true); + int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true); + + int inWidth = imgSizeD * imgSizeH * imgSizeW * channels; + MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true); + + int outWidth = channels * outD * outH * outW; + MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true); + + input->randomizeUniform(); + target->randomizeUniform(); + inputGpu->copyFrom(*input); + targetGpu->copyFrom(*target); + + target->avgPool3DForward(*input, + channels, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW); + + targetGpu->avgPool3DForward(*inputGpu, + channels, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW); + + TensorCheckErr(*target, *targetGpu); + + MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false); + MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true); + MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false); + MatrixPtr targetGpuGrad = + GpuMatrix::create(numSamples, outWidth, false, true); + + inputGrad->randomizeUniform(); + targetGrad->randomizeUniform(); + inputGpuGrad->copyFrom(*inputGrad); + targetGpuGrad->copyFrom(*targetGrad); + + inputGrad->avgPool3DBackward(*targetGrad, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW, + 1.0, + 1.0); + + inputGpuGrad->avgPool3DBackward(*targetGpuGrad, + imgSizeD, + imgSizeH, + imgSizeW, + outD, + outH, + outW, + ksizeD, + ksizeH, + ksizeW, + strideD, + strideH, + strideW, + padD, + padH, + padW, + 1.0, + 1.0); + TensorCheckErr(*inputGrad, *inputGpuGrad); +} + +// TODO(yi): I noticed many such blindly combinatorial tests in this +// file. They are no help to locate defects at all. +TEST(Matrix, Pool3DFwdBwd) { + for (auto numSamples : {1, 3}) { + for (auto channels : {3}) { + for (auto imgSizeD : {9, 16}) { + for (auto imgSizeH : {9, 32}) { + for (auto imgSizeW : {9, 32}) { + for (auto sizeX : {3}) { + for (auto sizeY : {3}) { + for (auto sizeZ : {3}) { + for (auto sD : {2}) { + for (auto sH : {2}) { + for (auto sW : {2}) { + for (auto pD : {0, (sizeZ - 1) / 2}) { + for (auto pH : {0, (sizeY - 1) / 2}) { + for (auto pW : {0, (sizeX - 1) / 2}) { + VLOG(3) << " numSamples=" << numSamples + << " channels=" << channels + << " imgSizeD=" << imgSizeD + << " imgSizeH=" << imgSizeH + << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX + << " sizeY=" << sizeY + << " sizeZ=" << sizeZ << " strideD=" << sD + << " strideH=" << sH << " strideW=" << sW + << " padingD=" << pD << " padingH=" << pH + << " padingW=" << pW; + + testMaxPool3DFwdBwd(numSamples, + channels, + imgSizeD, + imgSizeH, + imgSizeW, + sizeX, + sizeY, + sizeZ, + sD, + sH, + sW, + pD, + pH, + pW); + testAvgPool3DFwdBwd(numSamples, + channels, + imgSizeD, + imgSizeH, + imgSizeW, + sizeX, + sizeY, + sizeZ, + sD, + sH, + sW, + pD, + pH, + pW); + } + } + } + } + } + } + } + } + } + } + } + } + } + } + + // for (auto numSamples : {1, 3}) { + // for (auto channels : {1, 3}) { + // for (auto imgSizeD : {9,16}) { + // for (auto imgSizeH : {9, 32}) { + // for (auto imgSizeW : {9, 32}) { + // for (auto sizeX : {2, 3}) { + // for (auto sizeY : {2, 3}) { + // for (auto sizeZ : {2,3}){ + // for (auto sD : {1, 2}) { + // for (auto sH : {1, 2}) { + // for (auto sW : {1, 2}) { + // for (auto pD : {0, (sizeZ - 1) / 2}){ + // for (auto pH : {0, (sizeY - 1) / 2}) { + // for (auto pW : {0, (sizeX - 1) / 2}) { + // VLOG(3) << " numSamples=" << numSamples + // << " channels=" << channels + // << " imgSizeD=" << imgSizeD + // << " imgSizeH=" << imgSizeH + // << " imgSizeW=" << imgSizeW + // << " sizeX=" << sizeX + // << " sizeY=" << sizeY + // << " sizeZ=" << sizeZ + // << " strideD=" << sD + // << " strideH=" << sH + // << " strideW=" << sW + // << " padingD=" << pD + // << " padingH=" << pH + // << " padingW=" << pW; + // + // testMaxPool3DFwdBwd(numSamples, + // channels, + // imgSizeD, + // imgSizeH, + // imgSizeW, + // sizeX, + // sizeY, + // sizeZ, + // sD, + // sH, + // sW, + // pD, + // pH, + // pW); + // testAvgPool3DFwdBwd(numSamples, + // channels, + // imgSizeD, + // imgSizeH, + // imgSizeW, + // sizeX, + // sizeY, + // sizeZ, + // sD, + // sH, + // sW, + // pD, + // pH, + // pW); + // } + // } + // } + // } + // } + // } + // } + // } + // } + // } + // } + // } + // } + // } +} + void testMatrixCol2Vol(int depth, int height, int width) { int channel = 3; int filterX = 3, filterY = 4, filterZ = 5; @@ -1303,6 +1696,5 @@ TEST(Matrix, col2Vol) { } } } -/////// #endif diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f0fd12f1b5276d033ea086c60c80616fb1be7585..e5efcccb0e219a1c9df888cfec7f8902806676d4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -1,7 +1,10 @@ +file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") +string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") function(op_library TARGET) # op_library is a function to create op library. The interface is same as # cc_library. But it handle split GPU/CPU code and link some common library # for ops. + set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE) set(cc_srcs) set(cu_srcs) set(op_common_deps operator op_registry) @@ -43,33 +46,26 @@ endfunction() add_subdirectory(math) -cc_test(gather_test SRCS gather_test.cc DEPS tensor) -op_library(gather_op SRCS gather_op.cc gather_op.cu) - -cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) -op_library(scatter_op SRCS scatter_op.cc scatter_op.cu) - -cc_library(net_op SRCS net_op.cc DEPS op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) - -op_library(add_op SRCS add_op.cc add_op.cu) - -op_library(mean_op SRCS mean_op.cc mean_op.cu) +list(REMOVE_ITEM GENERAL_OPS + net_op + minus_op + mul_op + recurrent_op + scale_op) +op_library(net_op SRCS net_op.cc) +op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op) op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) -op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) +op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc + DEPS framework_proto tensor operator net_op) +op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op) -op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) -op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) -op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) -op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) +foreach(src ${GENERAL_OPS}) + op_library(${src} SRCS ${src}.cc ${src}.cu) +endforeach() -op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) +set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") -op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc - DEPS framework_proto tensor op_registry operator net_op) -op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) -op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu) -op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op) -op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op) +cc_test(gather_test SRCS gather_test.cc DEPS tensor) +cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) +cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 8ab748ed71e9a5dc0ee0259a78a2b886870bec5b..6384d8c8ce13dae8b58ed1069d496dd8e93eaa8a 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -57,7 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad); +REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, ops::AddOpGrad); REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index ab1e1c101a10e09a81f7785d2f1514822e3bdf15..ac76326262c88e2014cf64f7fb73b5a7338ab3e9 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -67,8 +67,7 @@ OnehotCrossEntropy Operator. namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOp); + ops::OnehotCrossEntropyOpMaker, ops::OnehotCrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, ops::OnehotCrossEntropyOpKernel); REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 123bed296c462c30bddd3bfbd530098fdbfe4856..07fa704824174f939e459093b245036771d9cd4f 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -63,8 +63,7 @@ Out = X[Index] } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, - ops::GatherGradOp); +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOp); REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 94d40890a765413e88a35a6ad995ca97ac84dcda..c3108ba8ec7ad85bd3485c135bf03e514bc66cd1 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, - lookup_table_grad, ops::LookupTableOpGrad); + ops::LookupTableOpGrad); REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel); REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index d3d0e55a674587fb04f43f24d0790de4358f035a..e66e0abb25f9b933025a6d098ed9dd9eb18a47a5 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); REGISTER_OP_CPU_KERNEL(mean_grad, diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 1eee9644babbdfac68821ca774845ad8ebbd5aee..b4afebcd97a8efff70aaaa85bc2ec5455ddd05c5 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -81,7 +81,6 @@ class MinusGradOp : public NetOp { USE_OP(scale); USE_OP_ITSELF(identity); namespace ops = paddle::operators; -REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, - ops::MinusGradOp); +REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradOp); REGISTER_OP_CPU_KERNEL(minus, ops::MinusKernel); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 173cc3850ca9d97200e272ec59d1bd3fe09b5053..559d19e6bdc083fffebe1c82a0bebbb18dd134fd 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -84,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); REGISTER_OP_CPU_KERNEL(mul_grad, ops::MulGradKernel); diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 6825dce332adc0dc11dda187d1bd367875b8603e..63de91254f4b75587cb2fb29aeb8ff7358ba8e76 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, - rowwise_add_grad, ops::RowwiseAddGradOp); + ops::RowwiseAddGradOp); REGISTER_OP_CPU_KERNEL( rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 8e96a74c94ab7ff4d8c3266695e5157aff67905b..4e039688d4d74f2a101fc91c747bd1e6ebec7ad2 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -97,7 +97,7 @@ class IdentityOp : public NetOp { namespace ops = paddle::operators; -REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, scale_grad, +REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradOp); REGISTER_OP_CPU_KERNEL(scale, ops::ScaleKernel); diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index f901edefa22dc9a252e87116df756d04767a7162..35c185ad80f93d1005c1616dcffd2e61bcd54222 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -77,8 +77,7 @@ Out[Index] = Ref[Index] + Updates } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad, - ops::ScatterGradOp); +REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, ops::ScatterGradOp); REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index 761c6de8d4d2150b30b97b58da95da3d5f33db63..f35b7023845bac52887d81a8f5c496cb5e7193aa 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -53,8 +53,7 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, - ops::SigmoidOpGrad); +REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, ops::SigmoidOpGrad); REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 40c51a64c49bc064f55975ef6ced1d54070f1291..471bb288fb20f113aefb2a9e13eb805b161b0631 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -62,8 +62,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad, - ops::SoftmaxOpGrad); +REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 37e186a408ff5f560b5878e3e51ea81ca5810bc7..00030050700bfb2cee224124d090b0027d456ba0 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -2,21 +2,5 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python backward - sgd_op - gather_op - scatter_op - add_op - mul_op - rowwise_add_op - sigmoid_op - softmax_op - mean_op - cross_entropy_op - recurrent_op - uniform_random_op - gaussian_random_op - fill_zeros_like_op - lookup_table_op - scale_op - minus_op) + ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 36b7803f073950437c938b54e6b5677b0c359151..4ddf023780c704cb10c51ee9e5d7cb63420f9d73 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -133,6 +133,12 @@ message PoolConfig { // if not set, use padding optional uint32 padding_y = 13; + + optional uint32 size_z = 14 [ default = 1 ]; + optional uint32 stride_z = 15 [ default = 1 ]; + optional uint32 output_z = 16 [ default = 1 ]; + optional uint32 img_size_z = 17 [ default = 1 ]; + optional uint32 padding_z = 18 [ default = 1 ]; } message SppConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 049efe6cfc1dcfa77504f7cd6a5fbc6bf610c3f0..152a56190c1ffddbf9590ed8f71308ceb88403f4 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -938,6 +938,31 @@ class Pool(Cfg): self.add_keys(locals()) +@config_class +class Pool3d(Cfg): + def __init__( + self, + pool_type, + channels, + size_x, + size_y=None, + size_z=None, + start=None, + stride=None, # 1 by defalut in protobuf + stride_y=None, + stride_z=None, + padding=None, # 0 by defalut in protobuf + padding_y=None, + padding_z=None): + self.add_keys(locals()) + self.filter_size_y = size_y if size_y else size_x + self.filter_size_z = size_z if size_z else size_x + self.padding_y = padding_y if padding_y else padding + self.padding_z = padding_z if padding_z else padding + self.stride_y = stride_y if stride_y else stride + self.stride_z = stride_z if stride_z else stride + + @config_class class SpatialPyramidPool(Cfg): def __init__(self, pool_type, pyramid_height, channels): @@ -1253,6 +1278,45 @@ def parse_pool(pool, input_layer_name, pool_conf, ceil_mode): pool_conf.stride_y, not ceil_mode) +def parse_pool3d(pool, input_layer_name, pool_conf, ceil_mode): + pool_conf.pool_type = pool.pool_type + config_assert(pool.pool_type in ['max-projection', 'avg-projection'], + "pool-type %s is not in " + "['max-projection', 'avg-projection']" % pool.pool_type) + + pool_conf.channels = pool.channels + + pool_conf.size_x = pool.size_x + pool_conf.stride = pool.stride + pool_conf.padding = pool.padding + + pool_conf.size_y = default(pool.size_y, pool_conf.size_x) + pool_conf.size_z = default(pool.size_z, pool_conf.size_x) + pool_conf.stride_y = default(pool.stride_y, pool_conf.stride) + pool_conf.stride_z = default(pool.stride_z, pool_conf.stride) + pool_conf.padding_y = default(pool.padding_y, pool_conf.padding) + pool_conf.padding_z = default(pool.padding_z, pool_conf.padding) + + pool_conf.img_size, pool_conf.img_size_y, pool_conf.img_size_z = \ + get_img3d_size(input_layer_name, pool.channels) + + config_assert(not pool.start, "start is deprecated in pooling.") + + if pool.padding is not None: + pool_conf.padding = pool.padding + pool_conf.padding_y = default(pool.padding_y, pool_conf.padding) + pool_conf.padding_z = default(pool.padding_z, pool_conf.padding) + pool_conf.output_x = cnn_output_size(pool_conf.img_size, pool_conf.size_x, + pool_conf.padding, pool_conf.stride, + not ceil_mode) + pool_conf.output_y = cnn_output_size(pool_conf.img_size_y, pool_conf.size_y, + pool_conf.padding_y, + pool_conf.stride_y, not ceil_mode) + pool_conf.output_z = cnn_output_size(pool_conf.img_size_z, pool_conf.size_z, + pool_conf.padding_z, + pool_conf.stride_z, not ceil_mode) + + def parse_spp(spp, input_layer_name, spp_conf): parse_image(spp, input_layer_name, spp_conf.image_conf) spp_conf.pool_type = spp.pool_type @@ -1897,9 +1961,9 @@ class DataLayer(LayerBase): def __init__(self, name, size, + depth=None, height=None, width=None, - depth=None, device=None): super(DataLayer, self).__init__( name, 'data', size, inputs=[], device=device) @@ -2215,6 +2279,35 @@ class PoolLayer(LayerBase): pool_conf.channels) +@config_layer('pool3d') +class Pool3DLayer(LayerBase): + def __init__(self, name, inputs, ceil_mode=True, **xargs): + super(Pool3DLayer, self).__init__( + name, 'pool3d', 0, inputs=inputs, **xargs) + for input_index in xrange(len(self.inputs)): + input_layer = self.get_input_layer(input_index) + pool_conf = self.config.inputs[input_index].pool_conf + parse_pool3d(self.inputs[input_index].pool, input_layer.name, + pool_conf, ceil_mode) + self.set_cnn_layer(name, pool_conf.output_z, pool_conf.output_y, + pool_conf.output_x, pool_conf.channels) + + def set_cnn_layer(self, + input_layer_name, + depth, + height, + width, + channels, + is_print=True): + size = depth * height * width * channels + self.set_layer_size(size) + self.set_layer_height_width(height, width) + self.set_layer_depth(depth) + if is_print: + print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" % + (input_layer_name, channels, depth, height, width, size)) + + @config_layer('spp') class SpatialPyramidPoolLayer(LayerBase): def __init__(self, name, inputs, **xargs): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 3de2f2f01da3778b3ae86f22e5c39f5193c7ccce..2bd274fad2ab7eed0902ffe944c6e0670f963233 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -137,7 +137,8 @@ __all__ = [ 'clip_layer', 'slice_projection', 'seq_slice_layer', - 'kmax_sequence_score_layer', + 'kmax_seq_score_layer', + 'img_pool3d_layer', 'scale_shift_layer', 'img_conv3d_layer', ] @@ -168,6 +169,7 @@ class LayerType(object): EXCONVTRANS_LAYER = 'exconvt' CUDNNCONV_LAYER = 'cudnn_conv' POOL_LAYER = 'pool' + POOL3D_LAYER = 'pool3d' BATCH_NORM_LAYER = 'batch_norm' NORM_LAYER = 'norm' SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm' @@ -900,7 +902,7 @@ def mixed_layer(size=0, @layer_support() -def data_layer(name, size, height=None, width=None, depth=None, +def data_layer(name, size, depth=None, height=None, width=None, layer_attr=None): """ Define DataLayer For NeuralNetwork. @@ -938,8 +940,8 @@ def data_layer(name, size, height=None, width=None, depth=None, num_filters = None if height is not None and width is not None: num_filters = size / (width * height * depth) - assert num_filters * width * height*depth == size, \ - "size=%s width=%s height=%s depth=%s" % (size, width, height, depth) + assert num_filters * width * height * depth == size, \ + "size=%s width=%s height=%s depth=%s" % (size, width, height, depth) return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters) @@ -2663,6 +2665,146 @@ def img_pool_layer(input, size=l.config.size) +@wrap_name_default("pool3d") +@layer_support() +def img_pool3d_layer(input, + pool_size, + name=None, + num_channels=None, + pool_type=None, + stride=1, + padding=0, + layer_attr=None, + pool_size_y=None, + stride_y=None, + padding_y=None, + pool_size_z=None, + stride_z=None, + padding_z=None, + ceil_mode=True): + """ + Image pooling Layer. + + The details of pooling layer, please refer ufldl's pooling_ . + + .. _pooling: http://ufldl.stanford.edu/tutorial/supervised/Pooling/ + + - ceil_mode=True: + + .. math:: + + w = 1 + int(ceil(input\_width + 2 * padding - pool\_size) / float(stride)) + h = 1 + int(ceil(input\_height + 2 * padding\_y - pool\_size\_y) / float(stride\_y)) + d = 1 + int(ceil(input\_depth + 2 * padding\_z - pool\_size\_z) / float(stride\_z)) + + - ceil_mode=False: + + .. math:: + + w = 1 + int(floor(input\_width + 2 * padding - pool\_size) / float(stride)) + h = 1 + int(floor(input\_height + 2 * padding\_y - pool\_size\_y) / float(stride\_y)) + d = 1 + int(floor(input\_depth + 2 * padding\_z - pool\_size\_z) / float(stride\_z)) + + The example usage is: + + .. code-block:: python + + maxpool = img_pool3d_layer(input=conv, + pool_size=3, + num_channels=8, + stride=1, + padding=1, + pool_type=MaxPooling()) + + :param padding: pooling padding width. + :type padding: int|tuple|list + :param name: name of pooling layer + :type name: basestring. + :param input: layer's input + :type input: LayerOutput + :param pool_size: pooling window width + :type pool_size: int|tuple|list + :param num_channels: number of input channel. + :type num_channels: int + :param pool_type: pooling type. MaxPooling or AvgPooling. Default is + MaxPooling. + :type pool_type: BasePoolingType + :param stride: stride width of pooling. + :type stride: int|tuple|list + :param layer_attr: Extra Layer attribute. + :type layer_attr: ExtraLayerAttribute + :param ceil_mode: Wether to use ceil mode to calculate output height and with. + Defalut is True. If set false, Otherwise use floor. + + :type ceil_mode: bool + :return: LayerOutput object. + :rtype: LayerOutput + """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + + if pool_type is None: + pool_type = MaxPooling() + elif isinstance(pool_type, AvgPooling): + pool_type.name = 'avg' + + type_name = pool_type.name + '-projection' \ + if ( + isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \ + else pool_type.name + + if isinstance(pool_size, collections.Sequence): + assert len(pool_size) == 3 + pool_size, pool_size_y, pool_size_z = pool_size + else: + pool_size_y = pool_size + pool_size_z = pool_size + + if isinstance(stride, collections.Sequence): + assert len(stride) == 3 + stride, stride_y, stride_z = stride + else: + stride_y = stride + stride_z = stride + + if isinstance(padding, collections.Sequence): + assert len(padding) == 3 + padding, padding_y, padding_y = padding + else: + padding_y = padding + padding_z = padding + + l = Layer( + name=name, + type=LayerType.POOL3D_LAYER, + inputs=[ + Input( + input.name, + pool=Pool3d( + pool_type=type_name, + channels=num_channels, + size_x=pool_size, + start=None, + stride=stride, + padding=padding, + size_y=pool_size_y, + stride_y=stride_y, + padding_y=padding_y, + size_z=pool_size_z, + stride_z=stride_z, + padding_z=padding_z)) + ], + ceil_mode=ceil_mode, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, + LayerType.POOL_LAYER, + parents=[input], + num_filters=num_channels, + size=l.config.size) + + @wrap_name_default("spp") @layer_support() def spp_layer(input, @@ -5852,7 +5994,7 @@ def cross_entropy_over_beam(input, name=None): Note that, if gold falls off the beam at search step t, then the cost is calculated over the beam at step t. - This cost layer always works together with kmax_sequence_score_layer, + This cost layer always works together with kmax_seq_score_layer, sub_nested_seq_layer, and sequence_slice_layer to trim the input to form a sub-search space. @@ -6455,14 +6597,14 @@ def seq_slice_layer(input, starts, ends, name=None): @wrap_name_default() @layer_support() -def kmax_sequence_score_layer(input, name=None, beam_size=1): +def kmax_seq_score_layer(input, name=None, beam_size=1): """ This layer accepts one input which are scores over a sequence or a nested sequence, and returns indices of beam_size sequences with highest scores. .. code-block:: python - kmax_indices = kmax_sequence_score_layer(input=input_layer, beam_size) + kmax_indices = kmax_seq_score_layer(input=input_layer, beam_size) :param name: The Layer Name. @@ -6475,10 +6617,10 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): :return: LayerOutput object. :rtype: LayerOutput """ - assert isinstance(input, LayerOutput), ("kmax_sequence_score_layer " + assert isinstance(input, LayerOutput), ("kmax_seq_score_layer " "accepts only one input.") assert input.size == 1, ( - "input of kmax_sequence_score_layer is a score" + "input of kmax_seq_score_layer is a score " "over a sequence or a nested sequence, so its width must be 1.") Layer( diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 7982d34bc5c0d4d99b0a8468ab0db86134764eab..df872a90ff388f0d96cef44763dbd076bc768ab9 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -8,7 +8,8 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer -test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer -test_seq_slice_layer test_cross_entropy_over_beam test_conv3d_layer test_deconv3d_layer) +test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer +test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer +test_conv3d_layer test_deconv3d_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cross_entropy_over_beam.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cross_entropy_over_beam.protostr index c43fc48e222044b65d83b6162e7dc3954e119887..a602569697e91b11b8d421ac359c2e523a00fa98 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cross_entropy_over_beam.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cross_entropy_over_beam.protostr @@ -12,7 +12,7 @@ layers { active_type: "" } layers { - name: "__kmax_sequence_score_layer_0__" + name: "__kmax_seq_score_layer_0__" type: "kmax_seq_score" active_type: "" inputs { @@ -29,7 +29,7 @@ layers { input_layer_name: "sentence_states" } inputs { - input_layer_name: "__kmax_sequence_score_layer_0__" + input_layer_name: "__kmax_seq_score_layer_0__" } } layers { @@ -44,7 +44,7 @@ layers { bias_parameter_name: "___fc_layer_0__.wbias" } layers { - name: "__kmax_sequence_score_layer_1__" + name: "__kmax_seq_score_layer_1__" type: "kmax_seq_score" active_type: "" inputs { @@ -61,7 +61,7 @@ layers { input_layer_name: "__sub_nested_seq_layer_0__" } inputs { - input_layer_name: "__kmax_sequence_score_layer_1__" + input_layer_name: "__kmax_seq_score_layer_1__" } select_first: true } @@ -77,7 +77,7 @@ layers { bias_parameter_name: "___fc_layer_1__.wbias" } layers { - name: "__kmax_sequence_score_layer_2__" + name: "__kmax_seq_score_layer_2__" type: "kmax_seq_score" active_type: "" inputs { @@ -111,7 +111,7 @@ layers { input_layer_name: "sentence_scores" } inputs { - input_layer_name: "__kmax_sequence_score_layer_0__" + input_layer_name: "__kmax_seq_score_layer_0__" } inputs { input_layer_name: "sentences_ids" @@ -120,7 +120,7 @@ layers { input_layer_name: "__fc_layer_0__" } inputs { - input_layer_name: "__kmax_sequence_score_layer_1__" + input_layer_name: "__kmax_seq_score_layer_1__" } inputs { input_layer_name: "start_ids" @@ -129,7 +129,7 @@ layers { input_layer_name: "__fc_layer_1__" } inputs { - input_layer_name: "__kmax_sequence_score_layer_2__" + input_layer_name: "__kmax_seq_score_layer_2__" } inputs { input_layer_name: "end_ids" @@ -185,13 +185,13 @@ sub_models { name: "root" layer_names: "sentence_states" layer_names: "sentence_scores" - layer_names: "__kmax_sequence_score_layer_0__" + layer_names: "__kmax_seq_score_layer_0__" layer_names: "__sub_nested_seq_layer_0__" layer_names: "__fc_layer_0__" - layer_names: "__kmax_sequence_score_layer_1__" + layer_names: "__kmax_seq_score_layer_1__" layer_names: "__seq_slice_layer_0__" layer_names: "__fc_layer_1__" - layer_names: "__kmax_sequence_score_layer_2__" + layer_names: "__kmax_seq_score_layer_2__" layer_names: "sentences_ids" layer_names: "start_ids" layer_names: "end_ids" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr index 3d32220bfbf5f4c67f88303cb9773ecfa484da4b..f93d368c8687573db80106b9cc4defa56a881e46 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr @@ -17,7 +17,7 @@ layers { bias_parameter_name: "___fc_layer_0__.wbias" } layers { - name: "__kmax_sequence_score_layer_0__" + name: "__kmax_seq_score_layer_0__" type: "kmax_seq_score" active_type: "" inputs { @@ -46,14 +46,14 @@ parameters { initial_smart: false } input_layer_names: "input_seq" -output_layer_names: "__kmax_sequence_score_layer_0__" +output_layer_names: "__kmax_seq_score_layer_0__" sub_models { name: "root" layer_names: "input_seq" layer_names: "__fc_layer_0__" - layer_names: "__kmax_sequence_score_layer_0__" + layer_names: "__kmax_seq_score_layer_0__" input_layer_names: "input_seq" - output_layer_names: "__kmax_sequence_score_layer_0__" + output_layer_names: "__kmax_seq_score_layer_0__" is_recurrent_layer_group: false } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pooling3D_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pooling3D_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..8eb98593f6f692a445cf5088e101e9da3763b41d --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pooling3D_layer.protostr @@ -0,0 +1,123 @@ +type: "nn" +layers { + name: "data_2d" + type: "data" + size: 6000 + active_type: "" + height: 20 + width: 10 +} +layers { + name: "pool___2d" + type: "pool" + size: 840 + active_type: "" + inputs { + input_layer_name: "data_2d" + pool_conf { + pool_type: "avg-projection" + channels: 30 + size_x: 5 + stride: 3 + output_x: 4 + img_size: 10 + padding: 1 + size_y: 5 + stride_y: 3 + output_y: 7 + img_size_y: 20 + padding_y: 1 + } + } + height: 7 + width: 4 +} +layers { + name: "data_3d_1" + type: "data" + size: 60000 + active_type: "" + height: 20 + width: 10 + depth: 10 +} +layers { + name: "pool_3d_1" + type: "pool3d" + size: 3360 + active_type: "" + inputs { + input_layer_name: "data_3d_1" + pool_conf { + pool_type: "avg-projection" + channels: 30 + size_x: 5 + stride: 3 + output_x: 4 + img_size: 10 + padding: 1 + size_y: 5 + stride_y: 3 + output_y: 7 + img_size_y: 20 + padding_y: 1 + size_z: 5 + stride_z: 3 + output_z: 4 + img_size_z: 10 + padding_z: 1 + } + } + height: 7 + width: 4 + depth: 4 +} +layers { + name: "pool_3d_2" + type: "pool3d" + size: 3360 + active_type: "" + inputs { + input_layer_name: "data_3d_1" + pool_conf { + pool_type: "max-projection" + channels: 30 + size_x: 5 + stride: 3 + output_x: 4 + img_size: 10 + padding: 1 + size_y: 5 + stride_y: 3 + output_y: 7 + img_size_y: 20 + padding_y: 1 + size_z: 5 + stride_z: 3 + output_z: 4 + img_size_z: 10 + padding_z: 1 + } + } + height: 7 + width: 4 + depth: 4 +} +input_layer_names: "data_2d" +output_layer_names: "pool___2d" +output_layer_names: "pool_3d_1" +output_layer_names: "pool_3d_2" +sub_models { + name: "root" + layer_names: "data_2d" + layer_names: "pool___2d" + layer_names: "data_3d_1" + layer_names: "pool_3d_1" + layer_names: "pool_3d_2" + input_layer_names: "data_2d" + output_layer_names: "pool___2d" + output_layer_names: "pool_3d_1" + output_layer_names: "pool_3d_2" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_cross_entropy_over_beam.py b/python/paddle/trainer_config_helpers/tests/configs/test_cross_entropy_over_beam.py index 240e703dc904e718c2c1ddaf2b6d7dccb4dabf41..4a5bdf1181dc4538418a8b89b41a1ff713e423c8 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_cross_entropy_over_beam.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_cross_entropy_over_beam.py @@ -7,14 +7,14 @@ beam_size = 5 # the first beam expansion. sentence_states = data_layer(name="sentence_states", size=32) sentence_scores = data_layer(name="sentence_scores", size=1) -topk_sentence_ids = kmax_sequence_score_layer( +topk_sentence_ids = kmax_seq_score_layer( input=sentence_scores, beam_size=beam_size) # the second beam expansion. topk_sen = sub_nested_seq_layer( input=sentence_states, selected_indices=topk_sentence_ids) start_pos_scores = fc_layer(input=topk_sen, size=1, act=LinearActivation()) -topk_start_pos_ids = kmax_sequence_score_layer( +topk_start_pos_ids = kmax_seq_score_layer( input=sentence_scores, beam_size=beam_size) # the final beam expansion. @@ -22,7 +22,7 @@ topk_start_spans = seq_slice_layer( input=topk_sen, starts=topk_start_pos_ids, ends=None) end_pos_scores = fc_layer( input=topk_start_spans, size=1, act=LinearActivation()) -topk_end_pos_ids = kmax_sequence_score_layer( +topk_end_pos_ids = kmax_seq_score_layer( input=end_pos_scores, beam_size=beam_size) # define the cost diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py index 48d0cd55da2481743de66ea95190c0856e7ddc39..171da10f75dae03eed7e110d0efd07d6a18e1ecf 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py @@ -4,6 +4,6 @@ from paddle.trainer_config_helpers import * data = data_layer(name="input_seq", size=128) scores = fc_layer(input=data, size=1, act=ExpActivation()) -kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5) +kmax_seq_id = kmax_seq_score_layer(input=scores, beam_size=5) outputs(kmax_seq_id) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_pooling3D_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_pooling3D_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbb921d41986e711d5b8b31caab1f8b6bdc47b8 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_pooling3D_layer.py @@ -0,0 +1,38 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=100, learning_rate=1e-5) + +data_2d = data_layer(name='data_2d', size=6000, height=20, width=10) + +pool_2d = img_pool_layer( + name="pool___2d", + input=data_2d, + num_channels=30, + pool_size=5, + stride=3, + padding=1, + pool_type=AvgPooling()) +outputs(pool_2d) + +data_3d = data_layer( + name='data_3d_1', size=60000, depth=10, height=20, width=10) + +pool_3d_1 = img_pool3d_layer( + name="pool_3d_1", + input=data_3d, + num_channels=30, + pool_size=5, + stride=3, + padding=1, + pool_type=AvgPooling()) +outputs(pool_3d_1) + +pool_3d_2 = img_pool3d_layer( + name="pool_3d_2", + input=data_3d, + num_channels=30, + pool_size=[5, 5, 5], + stride=[3, 3, 3], + padding=[1, 1, 1], + pool_type=MaxPooling()) +outputs(pool_3d_2) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_seq_select_layers.py b/python/paddle/trainer_config_helpers/tests/configs/test_sub_nested_seq_select_layer.py similarity index 100% rename from python/paddle/trainer_config_helpers/tests/configs/test_seq_select_layers.py rename to python/paddle/trainer_config_helpers/tests/configs/test_sub_nested_seq_select_layer.py diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 5bea980611904b37a4a5d4e2cbbee13503a61ff0..1c8d8f4b2f626bea5d9a44d01de7c2c9c45dc2fb 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -78,6 +78,8 @@ def init(**kwargs): if 'use_gpu' in kwargs: cp.g_command_config_args['use_gpu'] = kwargs['use_gpu'] + if 'use_mkldnn' in kwargs: + cp.g_command_config_args['use_mkldnn'] = kwargs['use_mkldnn'] assert 'parallel_nn' not in kwargs, ("currently 'parallel_nn' is not " "supported in v2 APIs.")