提交 2f33f74a 编写于 作者: D Dong Zhihong

Merge remote-tracking branch 'origin/develop' into feature/evaluator

......@@ -21,7 +21,7 @@ third_party/
cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
python/paddle/v2/fluid/core.so
paddle/pybind/pybind.h
CMakeFiles
cmake_install.cmake
......
......@@ -121,6 +121,7 @@ paddle_error paddle_matrix_get_shape(paddle_matrix mat,
paddle_matrix paddle_matrix_create_sparse(
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) {
#ifndef PADDLE_MOBILE_INFERENCE
auto ptr = new paddle::capi::CMatrix();
ptr->mat = paddle::Matrix::createSparseMatrix(
height,
......@@ -131,6 +132,9 @@ paddle_matrix paddle_matrix_create_sparse(
false,
useGpu);
return ptr;
#else
return nullptr;
#endif
}
paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
......@@ -140,6 +144,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
uint64_t colSize,
float* valueArray,
uint64_t valueSize) {
#ifndef PADDLE_MOBILE_INFERENCE
if (mat == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat);
if (rowArray == nullptr || colArray == nullptr ||
......@@ -160,4 +165,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
} else {
return kPD_NOT_SUPPORTED;
}
#else
return kPD_NOT_SUPPORTED;
#endif
}
......@@ -48,6 +48,7 @@ PD_API paddle_matrix paddle_matrix_create(uint64_t height,
* @param isBinary is binary (either 1 or 0 in matrix) or not.
* @param useGpu is using GPU or not.
* @return paddle_matrix.
* @note Mobile inference does not support this interface.
*/
PD_API paddle_matrix paddle_matrix_create_sparse(
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu);
......@@ -129,6 +130,7 @@ PD_API paddle_error paddle_matrix_get_shape(paddle_matrix mat,
* NULL if the matrix is binary.
* @param [in] valueSize length of value array. Zero if the matrix is binary.
* @return paddle_error
* @note Mobile inference does not support this interface.
*/
PD_API paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
int* rowArray,
......
......@@ -27,7 +27,9 @@ if(WITH_GPU)
set_source_files_properties(${CUDA_CXX_SOURCES}
PROPERTIES COMPILE_FLAGS "-D__NVCC__")
else()
if (NOT MOBILE_INFERENCE)
set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc)
endif()
endif()
set(CUDA_CU_SOURCES
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h"
/**
* @brief Maximum pool forward.
* @brief Maximum pool forward with Mask output.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
......@@ -35,7 +35,7 @@ limitations under the License. */
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
* @param[out] maskData the location indices of select max data.
*/
extern void hl_maxpool_forward(const int frameCnt,
const real* inputData,
......@@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride);
const int tgtStride,
real* maskData = NULL);
/**
* @brief Maximum pool backward.
......
......@@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {}
const int tgtStride,
real* MaskData) {}
inline void hl_maxpool_backward(const int frameCnt,
const real* inputData,
......
......@@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH,
const int offsetW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
......@@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0);
wstart = max(wstart, 0);
real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w])
maxval = inputData[h * width + w];
if (maxval < inputData[h * width + w]) {
max_index = h * width + w;
maxval = inputData[max_index];
}
}
}
int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval;
if (maskData != NULL) {
maskData[tgtIndex] = max_index;
}
}
}
......@@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
real* maskData) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1);
......@@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
paddingH,
paddingW,
tgtData,
tgtStride);
tgtStride,
maskData);
CHECK_SYNC("hl_maxpool_forward failed");
}
......
......@@ -61,6 +61,7 @@ public:
// function arguments
strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings");
dilations_ = config.get<std::vector<size_t>>("dilations");
groups_ = config.get<size_t>("groups");
// number of inputs and outputs
......@@ -118,6 +119,7 @@ protected:
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
std::vector<size_t> dilations_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
......@@ -133,6 +135,10 @@ protected:
inline int paddingW() const { return paddings_[1]; }
inline int dilationH() const { return dilations_[0]; }
inline int dilationW() const { return dilations_[1]; }
// A temporary memory in convolution calculation.
MemoryHandlePtr memory_;
......
......@@ -79,45 +79,59 @@ void Convolution(const std::string& conv1,
if (outputChannels < inputChannels) continue;
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
for (size_t dilation : {1, 3}) {
if (padding >= filterSize) break;
size_t filterS = (filterSize - 1) * dilation + 1;
// NNPACK only supports stride = 1 if batchSize > 1
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
batchSize > 1 && stride > 1)
break;
if (inputSize + 2 * padding < filterS) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize << " stride=" << stride
<< " padding=" << padding;
if ((conv1 == "NaiveConv-CPU" || conv2 == "NaiveConv-CPU" ||
conv1 == "NNPACKConv-CPU" ||
conv2 == "NNPACKConv-CPU") &&
dilation > 1)
break;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
// NNPACK only supports stride = 1 if batchSize > 1
if ((conv1 == "NNPACKConv-CPU" ||
conv2 == "NNPACKConv-CPU") &&
batchSize > 1 && stride > 1)
break;
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
size_t outputSize =
(inputSize - filterS + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
function(test, input, filter, output);
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {dilation, dilation};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
function(test, input, filter, output);
}
}
}
}
......@@ -144,6 +158,7 @@ void Convolution2(const std::string& conv1,
for (size_t outputChannels : {7}) {
size_t stride = 1;
size_t padding = 0;
size_t dilation = 1;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
......@@ -162,6 +177,7 @@ void Convolution2(const std::string& conv1,
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {dilation, dilation};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
......@@ -169,6 +185,7 @@ void Convolution2(const std::string& conv1,
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("dilations", dilations)
.set("algo", (std::string) "auto"));
TensorShape input{
......@@ -223,6 +240,7 @@ void DepthwiseConvolution(const std::string& conv1,
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {1, 1};
size_t groups = inputChannels;
Compare2Function<DType1, DType2> test(
conv1,
......@@ -231,6 +249,7 @@ void DepthwiseConvolution(const std::string& conv1,
.set("paddings", paddings)
.set("strides", strides)
.set("groups", groups)
.set("dilations", dilations)
.set("algo", (std::string) "auto"));
TensorShape input{
......
......@@ -100,7 +100,9 @@ public:
strideH(),
strideW(),
paddingH(),
paddingW());
paddingW(),
dilationH(),
dilationW());
} else {
colData = inputData + g * inputOffset;
}
......@@ -223,7 +225,9 @@ public:
strideH(),
strideW(),
paddingH(),
paddingW());
paddingW(),
dilationH(),
dilationW());
}
}
inputGrad += inputChannels * inputHeight * inputWidth;
......@@ -310,7 +314,9 @@ public:
strideH(),
strideW(),
paddingH(),
paddingW());
paddingW(),
dilationH(),
dilationW());
} else {
colData = inputData + g * inputOffset;
}
......
......@@ -78,7 +78,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth);
int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
};
template <ColFormat Format, DeviceType Device, class T>
......@@ -91,7 +93,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth);
int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
};
} // namespace paddle
......@@ -31,7 +31,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -47,8 +49,8 @@ public:
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
......@@ -81,7 +83,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -97,8 +101,8 @@ public:
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 &&
......@@ -134,7 +138,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -147,9 +153,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
int imRowOffset = outputH * strideHeight +
filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels +
channel) *
......@@ -189,7 +196,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -202,9 +211,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
int imRowOffset = outputH * strideHeight +
filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels +
channel) *
......
......@@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
int strideW,
int paddingH,
int paddingW,
int dilationH,
int dilationW,
int height_col,
int width_col,
T* data_col) {
......@@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) {
int rIdx = int(h_in + i);
int cIdx = int(w_in + j);
int rIdx = int(h_in + i * dilationH);
int cIdx = int(w_in + j * dilationW);
if ((rIdx - (int)paddingH) >= (int)height ||
(rIdx - (int)paddingH) < 0 ||
(cIdx - (int)paddingW) >= (int)width ||
......@@ -77,7 +79,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -102,6 +106,8 @@ public:
strideWidth,
paddingHeight,
paddingWidth,
dilationHeight,
dilationWidth,
outputHeight,
outputWidth,
colData);
......@@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t dilationH,
size_t dilationW,
size_t height_col,
size_t width_col,
T* data_im) {
......@@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
int filterH = (blockH - 1) * dilationH + 1;
int filterW = (blockW - 1) * dilationW + 1;
if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width - 2 * paddingW) &&
(h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output
int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
(w < (int)filterW) ? 0 : (w - int(filterW)) / (int)strideW + 1;
int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
(h < (int)filterH) ? 0 : (h - (int)filterH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH * blockW) +
(h - h_col * (int)strideH) * (int)blockW +
(w - w_col * (int)strideW);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
int h_k = (h - h_col * strideH);
int w_k = (w - w_col * strideW);
if (h_k % dilationH == 0 && w_k % dilationW == 0) {
h_k /= dilationH;
w_k /= dilationW;
int c_col =
(((c * blockH + h_k) * blockW + w_k) * height_col + h_col) *
width_col +
w_col;
val += data_col[c_col];
}
}
}
h -= paddingH;
......@@ -173,7 +192,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -205,6 +226,8 @@ public:
strideWidth,
paddingHeight,
paddingWidth,
dilationHeight,
dilationWidth,
outputHeight,
outputWidth,
imData);
......@@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
int strideWidth,
int paddingHeight,
int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight,
int outputWidth) {
int swId = blockIdx.x;
......@@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int widthOffset =
idx * dilationHeight + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationWidth + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth;
......@@ -273,7 +300,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -312,6 +341,8 @@ public:
strideWidth,
paddingHeight,
paddingWidth,
dilationHeight,
dilationWidth,
outputHeight,
outputWidth);
CHECK_SYNC("Im2ColFunctor GPU failed");
......@@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
int strideWidth,
int paddingHeight,
int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight,
int outputWidth) {
int swId = blockIdx.x;
......@@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int widthOffset =
idx * dilationWidth + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationHeight + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth;
......@@ -372,7 +407,9 @@ public:
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
......@@ -411,6 +448,8 @@ public:
strideWidth,
paddingHeight,
paddingWidth,
dilationHeight,
dilationWidth,
outputHeight,
outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed");
......
......@@ -29,82 +29,98 @@ void TestIm2ColFunctor() {
for (size_t filterWidth : {3, 7}) {
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
if (inputHeight <= filterHeight || inputWidth <= filterWidth)
break;
if (padding >= filterHeight || padding >= filterWidth) break;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) / stride;
TensorShape imShape =
TensorShape({channels, inputHeight, inputWidth});
TensorShape colShape1 = TensorShape({channels,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
TensorShape colShape2 = TensorShape({outputHeight,
outputWidth,
channels,
filterHeight,
filterWidth});
size_t height = channels * filterHeight * filterWidth;
size_t width = outputHeight * outputWidth;
VectorPtr input1 = Vector::create(imShape.getElements(), false);
VectorPtr input2 = Vector::create(imShape.getElements(), false);
MatrixPtr output1 = Matrix::create(height, width, false, false);
MatrixPtr output2 = Matrix::create(width, height, false, false);
input1->uniform(0.001, 1);
input2->copyFrom(*input1);
Im2ColFunctor<kCFO, Device, T> im2Col1;
Im2ColFunctor<kOCF, Device, T> im2Col2;
im2Col1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding);
im2Col2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding);
// The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF.
MatrixPtr test;
output2->transpose(test, true);
autotest::TensorCheckErr(*output1, *test);
Col2ImFunctor<kCFO, Device, T> col2Im1;
Col2ImFunctor<kOCF, Device, T> col2Im2;
col2Im1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding);
col2Im2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding);
autotest::TensorCheckErr(*input1, *input2);
for (size_t dilation : {1, 3}) {
size_t filterSizeH = (filterHeight - 1) * dilation + 1;
size_t filterSizeW = (filterWidth - 1) * dilation + 1;
if (inputHeight + 2 * padding < filterSizeH ||
inputWidth + 2 * padding < filterSizeW)
break;
if (padding >= filterSizeH || padding >= filterSizeW) break;
size_t outputHeight =
(inputHeight - filterSizeH + 2 * padding) / stride + 1;
size_t outputWidth =
(inputWidth - filterSizeW + 2 * padding) / stride + 1;
TensorShape imShape =
TensorShape({channels, inputHeight, inputWidth});
TensorShape colShape1 = TensorShape({channels,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
TensorShape colShape2 = TensorShape({outputHeight,
outputWidth,
channels,
filterHeight,
filterWidth});
size_t height = channels * filterHeight * filterWidth;
size_t width = outputHeight * outputWidth;
VectorPtr input1 =
Vector::create(imShape.getElements(), false);
VectorPtr input2 =
Vector::create(imShape.getElements(), false);
MatrixPtr output1 =
Matrix::create(height, width, false, false);
MatrixPtr output2 =
Matrix::create(width, height, false, false);
input1->uniform(0.001, 1);
input2->copyFrom(*input1);
Im2ColFunctor<kCFO, Device, T> im2Col1;
Im2ColFunctor<kOCF, Device, T> im2Col2;
im2Col1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding,
dilation,
dilation);
im2Col2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding,
dilation,
dilation);
// The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF.
MatrixPtr test;
output2->transpose(test, true);
autotest::TensorCheckErr(*output1, *test);
Col2ImFunctor<kCFO, Device, T> col2Im1;
Col2ImFunctor<kOCF, Device, T> col2Im2;
col2Im1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding,
dilation,
dilation);
col2Im2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding,
dilation,
dilation);
autotest::TensorCheckErr(*input1, *input2);
}
}
}
}
......
......@@ -85,9 +85,49 @@ if(MOBILE_INFERENCE)
gradientmachines/GradientMachineMode.cpp
gradientmachines/MultiGradientMachine.cpp)
# Remove useless layers
# Remove layers that used in training
list(REMOVE_ITEM GSERVER_SOURCES
layers/RecurrentLayerGroup.cpp)
layers/RecurrentLayerGroup.cpp
layers/CostLayer.cpp
layers/MultiBoxLossLayer.cpp
layers/WarpCTCLayer.cpp
layers/CTCLayer.cpp
layers/LinearChainCTC.cpp
layers/PrintLayer.cpp)
list(REMOVE_ITEM GSERVER_SOURCES
layers/OuterProdLayer.cpp
layers/SumToOneNormLayer.cpp
layers/ConvShiftLayer.cpp
layers/InterpolationLayer.cpp
layers/AgentLayer.cpp
layers/DotMulOperator.cpp
layers/GruStepLayer.cpp
layers/LstmStepLayer.cpp
layers/ConvexCombinationLayer.cpp
layers/Conv3DLayer.cpp
layers/DeConv3DLayer.cpp
layers/CropLayer.cpp
layers/CrossEntropyOverBeam.cpp
layers/DataNormLayer.cpp
layers/FeatureMapExpandLayer.cpp
layers/HierarchicalSigmoidLayer.cpp
layers/MultinomialSampler.cpp
layers/NCELayer.cpp
layers/KmaxSeqScoreLayer.cpp
layers/MDLstmLayer.cpp
layers/MultiplexLayer.cpp
layers/PadLayer.cpp
layers/Pool3DLayer.cpp
layers/ResizeLayer.cpp
layers/RotateLayer.cpp
layers/RowConvLayer.cpp
layers/RowL2NormLayer.cpp
layers/SamplingIdLayer.cpp
layers/ScaleShiftLayer.cpp
layers/SelectiveFullyConnectedLayer.cpp
layers/SpatialPyramidPoolLayer.cpp
layers/BilinearInterpLayer.cpp
layers/ClipLayer.cpp)
endif()
if(WITH_GPU)
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include "NeuralNetwork.h"
#include "hl_gpu.h"
#include "paddle/gserver/layers/AgentLayer.h"
#include "paddle/utils/CustomStackTrace.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
......@@ -28,6 +27,7 @@ limitations under the License. */
#ifndef PADDLE_MOBILE_INFERENCE
#include "MultiNetwork.h"
#include "RecurrentGradientMachine.h"
#include "paddle/gserver/layers/AgentLayer.h"
#endif
namespace paddle {
......@@ -192,9 +192,11 @@ void NeuralNetwork::init(const ModelConfig& config,
void NeuralNetwork::connect(LayerPtr agentLayer,
LayerPtr realLayer,
int height) {
#ifndef PADDLE_MOBILE_INFERENCE
AgentLayer* agent = dynamic_cast<AgentLayer*>(agentLayer.get());
CHECK_NOTNULL(agent);
agent->setRealLayer(realLayer, height);
#endif
}
void NeuralNetwork::connect(std::string agentLayerName,
......
......@@ -79,6 +79,10 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
std::vector<size_t> dilations = {(size_t)dilationY_[i],
(size_t)dilation_[i]};
bool useDilation = ((size_t)dilationY_[i] > 1 || (size_t)dilation_[i] > 1);
// Convolution Layer uses the GemmConv function by default.
convType = "GemmConv";
......@@ -97,13 +101,14 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
if ((filterSize_[i] == filterSizeY_[i]) &&
(filterSize_[i] == 3 || filterSize_[i] == 4) &&
(stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2)) {
(stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2) &&
!useDilation) {
convType = "NeonDepthwiseConv";
}
#endif
}
if (FLAGS_use_nnpack && !isDeconv_) {
if (FLAGS_use_nnpack && !isDeconv_ && !useDilation) {
createFunction(forward_,
"NNPACKConv",
FuncConfig()
......@@ -117,6 +122,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
......@@ -124,6 +130,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
......@@ -131,6 +138,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i]));
}
}
......
......@@ -98,6 +98,7 @@ ClassRegistrar<Layer, LayerConfig> Layer::registrar_;
LayerPtr Layer::create(const LayerConfig& config) {
std::string type = config.type();
#ifndef PADDLE_MOBILE_INFERENCE
// NOTE: As following types have illegal character '-',
// they can not use REGISTER_LAYER to registrar.
// Besides, to fit with old training models,
......@@ -106,7 +107,6 @@ LayerPtr Layer::create(const LayerConfig& config) {
return LayerPtr(new MultiClassCrossEntropy(config));
else if (type == "rank-cost")
return LayerPtr(new RankingCost(config));
#ifndef PADDLE_MOBILE_INFERENCE
else if (type == "auc-validation")
return LayerPtr(new AucValidation(config));
else if (type == "pnpair-validation")
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MaxPoolWithMaskLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
PoolLayer::init(layerMap, parameterMap);
setOutput("mask", &mask_);
return true;
}
size_t MaxPoolWithMaskLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0;
outputY_ = outputSize(imgSizeY_,
sizeY_,
confPaddingY_,
strideY_,
/* caffeMode */ false);
outputX_ = outputSize(imgSize_,
sizeX_,
confPadding_,
stride_,
/* caffeMode */ false);
layerSize = outputX_ * outputY_ * channels_;
getOutput().setFrameHeight(outputY_);
getOutput().setFrameWidth(outputX_);
return layerSize;
}
void MaxPoolWithMaskLayer::forward(PassType passType) {
size_t size = getSize();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
int batchSize = inputV->getHeight();
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
CHECK_EQ(size, outV->getWidth());
resetSpecifyOutput(mask_,
batchSize,
size,
/* isValueClean */ false,
/* isGradClean */ true);
MatrixPtr maskV = mask_.value;
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
channels_,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
confPaddingY_,
confPadding_,
maskV);
}
void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == getInputGrad(0)) {
return;
}
MatrixPtr outGrad = getOutputGrad();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
MatrixPtr outV = getOutputValue();
MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad();
inputGrad->maxPoolBackward(*inputV,
imgSizeY_,
imgSize_,
*outGrad,
*outV,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
1,
1,
confPaddingY_,
confPadding_);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "PoolLayer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief Basic parent layer of different kinds of pooling
*/
class MaxPoolWithMaskLayer : public PoolLayer {
protected:
Argument mask_;
public:
explicit MaxPoolWithMaskLayer(const LayerConfig& config)
: PoolLayer(config) {}
size_t getSize();
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
};
} // namespace paddle
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "PoolLayer.h"
#include "MaxPoolWithMaskLayer.h"
#include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h"
#ifdef PADDLE_WITH_CUDA
......@@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
return true;
}
......@@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
} else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config);
#endif
} else if (pool == "max-pool-with-mask") {
return new MaxPoolWithMaskLayer(config);
} else {
LOG(FATAL) << "Unknown pool type: " << pool;
return nullptr;
......
# gserver pacakge unittests
add_simple_unittest(test_LinearChainCRF)
add_simple_unittest(test_MultinomialSampler)
add_simple_unittest(test_RecurrentLayer)
if(NOT MOBILE_INFERENCE)
add_simple_unittest(test_MultinomialSampler)
endif()
function(gserver_test TARGET)
add_unittest_without_exec(${TARGET}
${TARGET}.cpp
......@@ -24,6 +27,7 @@ gserver_test(test_ConvUnify)
gserver_test(test_BatchNorm)
gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)
########## test_Mkldnn layers and activations ##########
if(WITH_MKLDNN)
......@@ -48,7 +52,7 @@ if(WITH_PYTHON)
endif()
############### test_WarpCTCLayer #######################
if(NOT WITH_DOUBLE)
if(NOT WITH_DOUBLE AND NOT MOBILE_INFERENCE)
add_unittest_without_exec(test_WarpCTCLayer
test_WarpCTCLayer.cpp)
......
......@@ -434,7 +434,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
int dilation = 1;
int dilation = 2;
if (type == "cudnn_conv") {
#if CUDNN_VERSION >= 6000
dilation = 2;
......@@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TEST(Layer, PoolLayer) {
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);
#ifdef PADDLE_WITH_CUDA
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
......@@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) {
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
#endif
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "LayerGradUtil.h"
#include "paddle/math/MathUtils.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle;
void setPoolConfig(TestConfig* config,
PoolConfig* pool,
const string& poolType) {
(*config).biasSize = 0;
(*config).layerConfig.set_type("pool");
(*config).layerConfig.set_num_filters(1);
int kw = 3, kh = 3;
int pw = 0, ph = 0;
int sw = 2, sh = 2;
pool->set_pool_type(poolType);
pool->set_channels(1);
pool->set_size_x(kw);
pool->set_size_y(kh);
pool->set_start(0);
pool->set_padding(pw);
pool->set_padding_y(ph);
pool->set_stride(sw);
pool->set_stride_y(sh);
int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false);
int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false);
pool->set_output_x(ow);
pool->set_output_y(oh);
}
void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat,
const string& poolType,
bool use_gpu,
MatrixPtr& maskMat) {
TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 25, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
PoolConfig* pool = input->mutable_pool_conf();
pool->set_img_size(5);
pool->set_img_size_y(5);
setPoolConfig(&config, pool, poolType);
config.layerConfig.set_size(pool->output_x() * pool->output_y() *
pool->channels());
config.layerConfig.set_name("MaxPoolWithMask");
std::vector<DataLayerPtr> dataLayers;
LayerMap layerMap;
vector<Argument> datas;
initDataLayer(config,
&dataLayers,
&datas,
&layerMap,
"MaxPoolWithMask",
1,
false,
use_gpu);
dataLayers[0]->getOutputValue()->copyFrom(*inputMat);
FLAGS_use_gpu = use_gpu;
std::vector<ParameterPtr> parameters;
LayerPtr maxPoolingWithMaskOutputLayer;
initTestLayer(config, &layerMap, &parameters, &maxPoolingWithMaskOutputLayer);
maxPoolingWithMaskOutputLayer->forward(PASS_GC);
checkMatrixEqual(maxPoolingWithMaskOutputLayer->getOutput("mask").value,
maskMat);
}
TEST(Layer, maxPoolingWithMaskOutputLayerFwd) {
bool useGpu = false;
MatrixPtr inputMat;
MatrixPtr maskMat;
real inputData[] = {0.1, 0.1, 0.5, 0.5, 1.1, 0.2, 0.2, 0.6, 0.1,
0.1, 0.3, 0.3, 0.7, 0.1, 0.1, 0.4, 0.4, 0.8,
0.8, 0.1, 1.0, 2.0, 3.0, 0.0, 9.0};
real maskData[] = {12, 4, 22, 24};
inputMat = Matrix::create(1, 25, false, useGpu);
maskMat = Matrix::create(1, 4, false, useGpu);
inputMat->setData(inputData);
maskMat->setData(maskData);
doOneMaxPoolingWithMaskOutputTest(
inputMat, "max-pool-with-mask", useGpu, maskMat);
#ifdef PADDLE_WITH_CUDA
useGpu = true;
inputMat = Matrix::create(1, 25, false, useGpu);
maskMat = Matrix::create(1, 4, false, useGpu);
inputMat->copyFrom(inputData, 25);
maskMat->copyFrom(maskData, 4);
doOneMaxPoolingWithMaskOutputTest(
inputMat, "max-pool-with-mask", useGpu, maskMat);
#endif
}
......@@ -1902,5 +1902,52 @@ void BaseMatrixT<real>::sumOfProducts(BaseMatrixT& b,
}
template class BaseMatrixT<real>;
#ifndef PADDLE_MOBILE_INFERENCE
template class BaseMatrixT<int>;
#else
template <>
void BaseMatrixT<int>::zero() {
applyUnary(unary::Zero<int>());
}
template <>
void BaseMatrixT<int>::assign(int p) {
applyUnary(unary::Assign<int>(p));
}
template <>
void BaseMatrixT<int>::isEqualTo(BaseMatrixT& b, int value) {
applyBinary(binary::IsEqual<int>(value), b);
}
template <>
void BaseMatrixT<int>::neg() {
applyUnary(unary::Neg<int>());
}
template <>
void BaseMatrixT<int>::abs2() {
applyUnary(unary::Abs<int>());
}
template <>
void BaseMatrixT<int>::add(int p) {
applyUnary(unary::Add<int>(p));
}
template <>
void BaseMatrixT<int>::add(int p1, int p2) {
applyUnary(unary::Add2<int>(p1, p2));
}
template <>
void BaseMatrixT<int>::applyL1(int learningRate, int decayRate) {
applyUnary(unary::ApplyL1<int>(learningRate * decayRate));
}
#endif
} // namespace paddle
......@@ -25,6 +25,19 @@ else()
message(STATUS "Compile with MKLDNNMatrix")
endif()
if(MOBILE_INFERENCE)
list(REMOVE_ITEM MATH_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/SIMDFunctions.cpp)
# Remove sparse
list(REMOVE_ITEM MATH_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/CpuSparseMatrix.h
${CMAKE_CURRENT_SOURCE_DIR}/SparseMatrix.h
${CMAKE_CURRENT_SOURCE_DIR}/SparseRowMatrix.h)
list(REMOVE_ITEM MATH_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/CpuSparseMatrix.cpp
${CMAKE_CURRENT_SOURCE_DIR}/SparseMatrix.cpp
${CMAKE_CURRENT_SOURCE_DIR}/SparseRowMatrix.cpp)
endif()
set(MATH_SOURCES
"${PADDLE_SOURCE_DIR}/paddle/math/BaseMatrix.cu"
"${PADDLE_SOURCE_DIR}/paddle/math/TrainingAlgorithmOp.cu"
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef PADDLE_MOBILE_INFERENCE
#include <cstddef>
#include "Matrix.h"
......@@ -309,3 +312,57 @@ private:
using Matrix::subMatrix;
};
} // namespace paddle
#else
#include "Matrix.h"
namespace paddle {
class CpuSparseMatrix : public Matrix {
public:
CpuSparseMatrix(size_t height,
size_t width,
size_t nnz, /* used to allocate space */
SparseValueType valueType = FLOAT_VALUE,
SparseFormat format = SPARSE_CSR,
bool trans = false)
: Matrix(NULL, height, width, trans, false) {}
CpuSparseMatrix(real* data,
int* rows,
int* cols,
size_t height,
size_t width,
size_t nnz,
SparseValueType valueType,
SparseFormat format,
bool trans)
: Matrix(NULL, height, width, trans, false) {}
real* getValue() const { return nullptr; }
size_t getColStartIdx(size_t i) const { return 0; }
size_t getRowStartIdx(size_t i) const { return 0; }
size_t getColNum(size_t i) const { return 0; }
int* getRowCols(size_t i) const { return nullptr; }
CpuSparseMatrixPtr getTmpSparseMatrix(size_t height, size_t width) {
return nullptr;
}
void resize(size_t newHeight,
size_t newWidth,
size_t newNnz, /* used to allocate space */
SparseValueType valueType,
SparseFormat format) {}
void resize(size_t newHeight, size_t newWidth) {}
MatrixPtr getTranspose() { return nullptr; }
void setRow(size_t row,
size_t colNum,
const unsigned int* cols,
const real* values) {}
};
} // namespace paddle
#endif
......@@ -451,6 +451,7 @@ void GpuMatrix::addSharedBias(Matrix& b, real scale) {
}
void GpuMatrix::collectBias(Matrix& a, real scale) {
#ifdef PADDLE_WITH_CUDA
CHECK_EQ(getHeight(), (size_t)1);
CHECK_EQ(width_, a.getWidth());
GpuSparseMatrix* sMatPtr = dynamic_cast<GpuSparseMatrix*>(&a);
......@@ -461,6 +462,7 @@ void GpuMatrix::collectBias(Matrix& a, real scale) {
hl_sparse_matrix_s A_d = sMatPtr->sMatrix_.get();
hl_sparse_matrix_column_sum(data, A_d, sMatPtr->getHeight(), width_, scale);
}
#endif
}
void GpuMatrix::collectSharedBias(Matrix& a, real scale) {
......@@ -552,6 +554,7 @@ void GpuMatrix::mul(const GpuSparseMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT) {
#ifdef PADDLE_WITH_CUDA
CHECK(isContiguous());
CHECK(b.isContiguous());
CHECK(b.useGpu_ == true) << "Matrix type are not equal";
......@@ -578,12 +581,14 @@ void GpuMatrix::mul(const GpuSparseMatrix& a,
b.height_,
scaleAB,
scaleT);
#endif
}
void GpuMatrix::mul(const GpuMatrix& a,
const GpuSparseMatrix& b,
real scaleAB,
real scaleT) {
#ifdef PADDLE_WITH_CUDA
CHECK(isContiguous());
CHECK(a.isContiguous());
CHECK(a.useGpu_ == true) << "Matrix type are not equal";
......@@ -622,6 +627,7 @@ void GpuMatrix::mul(const GpuMatrix& a,
scaleAB,
scaleT);
}
#endif
}
/* this = a*b */
......@@ -1028,15 +1034,23 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
MatrixPtr maskMatP) {
CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";
real* inputData = inputMat.getData();
real* maskData = NULL;
size_t frameNum = inputMat.getHeight();
CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth());
CHECK(height_ == inputMat.getHeight());
CHECK(width_ == outputH * outputW * channels);
if (maskMatP != NULL) {
CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal";
CHECK(outputH * outputW * channels == maskMatP->getWidth());
maskData = maskMatP->getData();
}
hl_maxpool_forward(frameNum,
inputData,
channels,
......@@ -1051,7 +1065,8 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
paddingH,
paddingW,
data_,
getStride());
getStride(),
maskData);
}
void GpuMatrix::maxPoolBackward(Matrix& inputMat,
......@@ -1548,6 +1563,7 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
}
void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
#ifdef PADDLE_WITH_CUDA
GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
......@@ -1563,9 +1579,11 @@ void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
hl_matrix_multi_binary_cross_entropy(
output_d, entropy_d, mat_d, height_, outputPtr->width_);
#endif
}
void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
#ifdef PADDLE_WITH_CUDA
GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
......@@ -1581,6 +1599,7 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
hl_matrix_multi_binary_cross_entropy_bp(
output_d, grad_d, mat_d, height_, width_);
#endif
}
void GpuMatrix::vol2Col(real* dataSrc,
......@@ -1973,9 +1992,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
MatrixPtr maskMatP) {
real* inputData = inputMat.getData();
real* outData = data_;
real* maskData = NULL;
size_t num = inputMat.getHeight();
size_t inLength = imgSizeH * imgSizeW;
size_t outLength = outputH * outputW;
......@@ -1984,6 +2005,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
CHECK_EQ(channels * outLength, this->getWidth());
size_t outStride = getStride();
if (maskMatP != NULL) {
maskData = maskMatP->getData();
CHECK_EQ(channels * outLength, maskMatP->getWidth());
}
/* initialize the data_ */
for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) {
......@@ -2005,10 +2031,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
int wstart = pw * strideW - paddingW;
int wend = std::min(wstart + sizeX, imgSizeW);
wstart = std::max(wstart, 0);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw] = std::max(
outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
if (maskData == NULL) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw] = std::max(
outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
}
}
} else {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
maskData[ph * outputW + pw] = h * imgSizeW + w;
}
}
}
}
}
......@@ -2016,6 +2053,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
// compute offset
inputData += inLength;
outData += outLength;
if (maskData != NULL) maskData += outLength;
}
}
}
......@@ -3226,6 +3265,7 @@ template void CpuMatrix::mul<CpuMatrix, CacheRowCpuMatrix>(CpuSparseMatrix* a,
real scaleAB,
real scaleT);
#ifndef PADDLE_MOBILE_INFERENCE
void SharedCpuMatrix::mul(CpuSparseMatrix* a,
CpuMatrix* b,
real scaleAB,
......@@ -3354,6 +3394,7 @@ void SharedCpuMatrix::initBlock(int blockNum) {
}
}
#endif
/* Add a (column) vector b to matrix a, column by column */
void CpuMatrix::addColumnVector(const Matrix& b) {
BaseMatrix::addColVector(const_cast<Matrix&>(b));
......
......@@ -861,7 +861,8 @@ public:
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value
* in the sizeX of value, if the maskMatP is not NULL, it will
* also caculate the location indices.
*/
virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
......@@ -874,7 +875,8 @@ public:
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
MatrixPtr maskMatP = NULL) {
LOG(FATAL) << "Not implemeted";
}
......@@ -1426,7 +1428,8 @@ public:
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
size_t paddingW,
MatrixPtr maskMatP);
void maxPoolBackward(Matrix& image,
size_t imgSizeH,
......@@ -1697,7 +1700,8 @@ public:
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
size_t paddingW,
MatrixPtr maskMatP);
void maxPoolBackward(Matrix& image,
size_t imgSizeH,
......@@ -2066,6 +2070,7 @@ public:
class SharedCpuMatrix : public CpuMatrix {
public:
#ifndef PADDLE_MOBILE_INFERENCE
/* blockNum is number of partitions of the matrix */
SharedCpuMatrix(int blockNum, size_t height, size_t width, bool trans = false)
: CpuMatrix(height, width, trans) {
......@@ -2111,6 +2116,7 @@ private:
ThreadLocal<CpuMatrixPtr> localBuf_;
ThreadLocal<std::vector<int>> localBufRows_;
ThreadLocal<std::vector<int>> blockSeq_;
#endif
};
typedef struct { unsigned int col; } sparse_non_value_t;
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef PADDLE_MOBILE_INFERENCE
#include <cstddef>
#include "CpuSparseMatrix.h"
#include "Matrix.h"
......@@ -237,3 +240,47 @@ private:
};
} // namespace paddle
#else
#include "CpuSparseMatrix.h"
namespace paddle {
class GpuSparseMatrix : public Matrix {
public:
GpuSparseMatrix(size_t height,
size_t width,
size_t nnz, /* used to allocate space */
SparseValueType valueType = FLOAT_VALUE,
SparseFormat format_ = SPARSE_CSR,
bool trans = false)
: Matrix(NULL, height, width, trans, false) {}
GpuSparseMatrix(real* value,
int* rows,
int* cols,
size_t height,
size_t width,
size_t nnz,
SparseValueType valueType,
SparseFormat format,
bool trans)
: Matrix(NULL, height, width, trans, true) {}
void resize(size_t newHeight,
size_t newWidth,
size_t newNnz, /* used to allocate space */
SparseValueType valueType,
SparseFormat format) {}
void resize(size_t newHeight, size_t newWidth) {}
MatrixPtr getTranspose() { return nullptr; }
void setRow(size_t row,
size_t colNum,
const unsigned int* cols,
const real* values) {}
};
} // namespace paddle
#endif
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#ifndef PADDLE_MOBILE_INFERENCE
#include <gflags/gflags.h>
#include <string.h>
#include <algorithm>
......@@ -313,3 +315,27 @@ private:
};
} // namespace paddle
#else
namespace paddle {
class SparseRowCpuMatrix : public CpuMatrix {
public:
void reserveStore() {}
void clearIndices() {}
};
class SparsePrefetchRowCpuMatrix : public SparseRowCpuMatrix {
public:
void setupIndices() {}
void addRows(MatrixPtr input) {}
void addRows(IVectorPtr ids) {}
};
class SparseAutoGrowRowCpuMatrix : public SparseRowCpuMatrix {};
class CacheRowCpuMatrix : public SparseAutoGrowRowCpuMatrix {};
class SparseRowIdsCpuMatrix : public CpuMatrix {};
} // namespace paddle
#endif
......@@ -3,8 +3,10 @@
add_simple_unittest(test_ExecViaCpu)
add_simple_unittest(test_SIMDFunctions)
add_simple_unittest(test_TrainingAlgorithm)
add_simple_unittest(test_SparseMatrix)
add_simple_unittest(test_RowBuffer)
if(NOT MOBILE_INFERENCE)
add_simple_unittest(test_SparseMatrix)
endif()
# TODO(yuyang18): Refactor TestUtil.cpp. Remove this cross module reference.
add_unittest(test_matrixCompare
......
......@@ -27,6 +27,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
framework::ExecutionContext ctx(*this, scope, dev_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
const size_t step_num = ids->size();
......
......@@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel {
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
......@@ -15,4 +15,9 @@
#include "paddle/operators/compare_op.h"
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, GPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, GPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);
......@@ -27,6 +27,24 @@ struct LessThanFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
};
template <typename T>
struct LessEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
};
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T>
struct EqualFunctor {
using ELEM_TYPE = T;
......
......@@ -27,15 +27,15 @@ template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process) {
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
......@@ -47,7 +47,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -87,11 +87,12 @@ template <typename PoolProcess, class T>
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_grad_process) {
PoolProcess pool_grad_process,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -110,7 +111,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -154,10 +155,11 @@ template <class T>
class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -176,7 +178,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -240,17 +242,17 @@ template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process) {
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int output_channels = output->dims()[1];
const int output_depth = output->dims()[2];
const int output_height = output->dims()[3];
const int output_width = output->dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
......@@ -265,7 +267,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -315,11 +317,12 @@ template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_grad_process) {
PoolProcess pool_grad_process,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -343,7 +346,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -398,10 +401,11 @@ template <class T>
class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -425,7 +429,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -498,15 +502,15 @@ template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* output, framework::Tensor* mask) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
......@@ -517,8 +521,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -563,13 +567,13 @@ template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_height = input_grad.dims()[2];
const int input_width = input_grad.dims()[3];
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input_grad->dims()[0];
const int input_height = input_grad->dims()[2];
const int input_width = input_grad->dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
......@@ -578,7 +582,7 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
......@@ -612,17 +616,17 @@ template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* output, framework::Tensor* mask) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int output_channels = output->dims()[1];
const int output_depth = output->dims()[2];
const int output_height = output->dims()[3];
const int output_width = output->dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
......@@ -636,8 +640,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
......@@ -691,14 +695,14 @@ template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_depth = input_grad.dims()[2];
const int input_height = input_grad.dims()[3];
const int input_width = input_grad.dims()[4];
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input_grad->dims()[0];
const int input_depth = input_grad->dims()[2];
const int input_height = input_grad->dims()[3];
const int input_width = input_grad->dims()[4];
const int output_channels = output_grad.dims()[1];
const int output_depth = output_grad.dims()[2];
const int output_height = output_grad.dims()[3];
......@@ -708,7 +712,7 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
......
此差异已折叠。
......@@ -88,60 +88,62 @@ template <typename Place, typename PoolProcess, typename T>
class Pool2dFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute);
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute, framework::Tensor* output);
};
template <typename Place, typename PoolProcess, typename T>
class Pool2dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute);
PoolProcess pool_compute, framework::Tensor* input_grad);
};
template <typename Place, class T>
class MaxPool2dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute);
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute, framework::Tensor* output);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute);
PoolProcess pool_compute, framework::Tensor* input_grad);
};
template <typename Place, class T>
class MaxPool3dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
};
/*
......@@ -155,38 +157,38 @@ template <typename Place, typename T>
class MaxPool2dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* output, framework::Tensor* mask);
};
template <typename Place, typename T>
class MaxPool2dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
};
template <typename Place, typename T>
class MaxPool3dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* output, framework::Tensor* mask);
};
template <typename Place, typename T>
class MaxPool3dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
};
} // namespace math
......
......@@ -75,16 +75,16 @@ class PoolKernel : public framework::OpKernel<T> {
Place, paddle::operators::math::MaxPool<T>, T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process, out);
}
} break;
case 3: {
......@@ -93,15 +93,15 @@ class PoolKernel : public framework::OpKernel<T> {
Place, paddle::operators::math::MaxPool<T>, T>
pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......@@ -142,30 +142,30 @@ class PoolGradKernel : public framework::OpKernel<T> {
if (pooling_type == "max") {
paddle::operators::math::MaxPool2dGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings);
pool2d_backward(context.device_context(), *in_x, *out, *out_grad,
ksize, strides, paddings, in_x_grad);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor<
Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
pool2d_backward(context.device_context(), *in_x, *out, *out_grad,
ksize, strides, paddings, pool_process, in_x_grad);
}
} break;
case 3: {
if (pooling_type == "max") {
paddle::operators::math::MaxPool3dGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings);
pool3d_backward(context.device_context(), *in_x, *out, *out_grad,
ksize, strides, paddings, in_x_grad);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor<
Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
pool3d_backward(context.device_context(), *in_x, *out, *out_grad,
ksize, strides, paddings, pool_process, in_x_grad);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......
......@@ -46,14 +46,14 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
pool2d_forward;
pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, out, mask);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
pool3d_forward;
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, out, mask);
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
}
......@@ -89,14 +89,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
pool2d_backward(context.device_context(), *out_grad, *mask, ksize,
strides, paddings, in_x_grad);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
pool3d_backward(context.device_context(), *out_grad, *mask, ksize,
strides, paddings, in_x_grad);
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
}
......
......@@ -200,7 +200,10 @@ void Parameter::setMat(ParameterType pType, int matType) {
false,
useGpu_);
}
} else if (matType == MAT_NORMAL_SHARED) {
}
#ifndef PADDLE_MOBILE_INFERENCE
// NOLINTNEXTLINE
else if (matType == MAT_NORMAL_SHARED) {
CHECK_EQ(height * width, bufs_[pType]->getSize());
size_t blockNum = 0;
CHECK(isGradShared(&blockNum));
......@@ -259,7 +262,10 @@ void Parameter::setMat(ParameterType pType, int matType) {
} else if (matType == MAT_SPARSE_ROW_AUTO_GROW) {
CHECK(isGradSparseUpdate());
mats_[pType] = std::make_shared<SparseAutoGrowRowCpuMatrix>(height, width);
} else {
}
#endif
// NOLINTNEXTLINE
else {
LOG(FATAL) << "Unsupported mat type" << matType;
}
}
......
......@@ -33,6 +33,7 @@ MatrixPtr makeRandomSparseMatrix(size_t height,
bool withValue,
bool useGpu,
bool equalNnzPerSample) {
#ifndef PADDLE_MOBILE_INFERENCE
std::vector<int64_t> ids(height);
std::vector<int64_t> indices(height + 1);
indices[0] = 0;
......@@ -84,6 +85,8 @@ MatrixPtr makeRandomSparseMatrix(size_t height,
}
return mat;
}
#endif
return nullptr;
}
void generateSequenceStartPositions(size_t batchSize,
......
......@@ -37,10 +37,10 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
${CMAKE_CURRENT_BINARY_DIR}/setup.py)
add_custom_command(OUTPUT ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so
COMMAND cmake -E copy $<TARGET_FILE:paddle_pybind> ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so
add_custom_command(OUTPUT ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so
COMMAND cmake -E copy $<TARGET_FILE:paddle_pybind> ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so
DEPENDS paddle_pybind)
add_custom_target(copy_paddle_pybind ALL DEPENDS ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so)
add_custom_target(copy_paddle_pybind ALL DEPENDS ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so)
add_custom_command(OUTPUT ${PADDLE_PYTHON_BUILD_DIR}/.timestamp
......@@ -66,7 +66,7 @@ if (WITH_TESTING)
add_subdirectory(paddle/v2/tests)
add_subdirectory(paddle/v2/reader/tests)
add_subdirectory(paddle/v2/plot/tests)
add_subdirectory(paddle/v2/framework/tests)
add_subdirectory(paddle/v2/fluid/tests)
endif()
endif()
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
......
......@@ -1200,8 +1200,14 @@ def TestData(data_config, async_load_data=None):
#caffe_mode: compute the output size using floor instead of ceil,
# which is consistent of caffe and CuDNN's convention.
def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
output = (2 * padding + img_size - filter_size) / float(stride)
def cnn_output_size(img_size,
filter_size,
padding,
stride,
caffe_mode,
dilation=1):
filter_s = (filter_size - 1) * dilation + 1
output = (2 * padding + img_size - filter_s) / float(stride)
if caffe_mode:
return 1 + int(math.floor(output))
else:
......@@ -1210,8 +1216,14 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
#calcualte image_size based on output_size for de-convolution (ConvTransLayer).
#It is the reverse function of cnn_output_size
def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode):
img_size = (output_size - 1) * stride + filter_size - 2 * padding
def cnn_image_size(output_size,
filter_size,
padding,
stride,
caffe_mode,
dilation=1):
filter_s = (filter_size - 1) * dilation + 1
img_size = (output_size - 1) * stride + filter_s - 2 * padding
if not caffe_mode:
img_size = img_size + 1
return img_size
......@@ -1253,9 +1265,9 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
def parse_pool(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf.pool_type = pool.pool_type
config_assert(pool.pool_type in [
'max-projection', 'avg-projection', 'cudnn-max-pool', 'cudnn-avg-pool'
], "pool-type %s is not in "
"['max-projection', 'avg-projection', "
'max-projection', 'avg-projection', 'max-pool-with-mask', 'cudnn-max-pool', 'cudnn-avg-pool'
], "pool-type %s is not in " \
"['max-projection', 'avg-projection', 'max-pool-with-mask'," \
"'cudnn-max-pool', 'cudnn-avg-pool']" % pool.pool_type)
pool_conf.channels = pool.channels
......@@ -1376,6 +1388,12 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.stride_y = conv.stride_y
conv_conf.groups = conv.groups
conv_conf.caffe_mode = conv.caffe_mode
if not conv.dilation:
conv.dilation = 1
conv.dilation_y = 1
else:
conv_conf.dilation = conv.dilation
conv_conf.dilation_y = conv.dilation_y
if not trans:
conv_conf.filter_channels = conv.channels / conv.groups
......@@ -1383,20 +1401,20 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
get_img_size(input_layer_name, conv.channels)
conv_conf.output_x = cnn_output_size(
conv_conf.img_size, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode)
conv_conf.stride, conv_conf.caffe_mode, conv.dilation)
conv_conf.output_y = cnn_output_size(
conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode)
conv_conf.stride_y, conv_conf.caffe_mode, conv.dilation_y)
else:
conv_conf.filter_channels = num_filters / conv.groups
conv_conf.output_x, conv_conf.output_y = \
get_img_size(input_layer_name, conv.channels)
conv_conf.img_size = cnn_image_size(
conv_conf.output_x, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode)
conv_conf.stride, conv_conf.caffe_mode, conv.dilation)
conv_conf.img_size_y = cnn_image_size(
conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode)
conv_conf.stride_y, conv_conf.caffe_mode, conv.dilation_y)
#caffe_mode: compute the output size using floor instead of ceil,
......
......@@ -20,7 +20,7 @@ from paddle.trainer.config_parser import *
from .activations import LinearActivation, SigmoidActivation, TanhActivation, \
ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation
from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType, \
from .poolings import MaxPooling, AvgPooling, MaxWithMaskPooling, BasePoolingType, \
CudnnAvgPooling, CudnnMaxPooling
from .attrs import *
from .default_decorators import *
......@@ -2571,7 +2571,9 @@ def img_conv_layer(input,
if layer_type:
if dilation > 1 or dilation_y > 1:
assert layer_type in ["cudnn_conv", "cudnn_convt"]
assert layer_type in [
"cudnn_conv", "cudnn_convt", "exconv", "exconvt"
]
if trans:
assert layer_type in ["exconvt", "cudnn_convt"]
else:
......@@ -2699,9 +2701,9 @@ def img_pool_layer(input,
elif isinstance(pool_type, AvgPooling):
pool_type.name = 'avg'
assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
assert type(pool_type) in [AvgPooling, MaxPooling, MaxWithMaskPooling, CudnnAvgPooling,
CudnnMaxPooling], \
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling, MaxWithMaskPooling are supported"
type_name = pool_type.name + '-projection' \
if (
......@@ -3592,10 +3594,9 @@ def lstm_step_layer(input,
:type gate_act: BaseActivation
:param state_act: State Activation Type. TanhActivation is the default.
:type state_act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param bias_attr: The parameter attribute for bias. If this parameter is
set to True or None, the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | True
:param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
......@@ -3650,9 +3651,10 @@ def gru_step_layer(input,
:param name: The name of this layer. It is optional.
:param gate_act: Activation type of this layer's two gates. Default is Sigmoid.
:type gate_act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute, no bias
is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: the parameter_attribute for transforming the output_mem
from previous step.
......@@ -3712,9 +3714,10 @@ def gru_step_naive_layer(input,
:type act: BaseActivation
:param gate_act: Activation type of this layer's two gates. Default is Sigmoid.
:type gate_act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute, no bias
is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr:
:param layer_attr:
......@@ -3844,9 +3847,10 @@ def recurrent_layer(input,
:type input: LayerOutput
:param act: Activation type. TanhActivation is the default.
:type act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If the parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: parameter attribute.
:type param_attr: ParameterAttribute
......@@ -4836,9 +4840,10 @@ def tensor_layer(a,
:type act: BaseActivation
:param param_attr: The Parameter Attribute.
:type param_attr: ParameterAttribute
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: Extra Layer config.
:type layer_attr: ExtraLayerAttribute | None
......@@ -4900,9 +4905,10 @@ def selective_fc_layer(input,
:type act: BaseActivation
:param param_attr: The Parameter Attribute.
:type param_attr: ParameterAttribute
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: Extra Layer config.
:type layer_attr: ExtraLayerAttribute | None
......@@ -5585,10 +5591,10 @@ def nce_layer(input,
to the num_classes. Each member of the list defines
the probability of a class given input x.
:type neg_distribution: list | tuple | collections.Sequence | None
:param bias_attr: The attribute for bias. If this parameter is set False or
any object whose type is not ParameterAttribute, no bias
is added. If this parameter is set True, the bias is
initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
......@@ -6498,9 +6504,9 @@ def gated_unit_layer(input,
:param gate_param_attr: The parameter attribute of the gate. See ParameterAttribute
for details.
:type gate_param_attr: ParameterAttribute
:param gate_bias_attr: The bias attribute of the gate. If the parameter is set to False or
:param gate_bias_attr: The bias attribute of the gate. If this parameter is set to False or
an object whose type is not ParameterAttribute, no bias is defined.
If the parameter is set to True, the bias is initialized to zero.
If this parameter is set to True, the bias is initialized to zero.
:type gate_bias_attr: ParameterAttribute | bool | None | Any
:param inproj_attr: Extra layer attributes of the projection. See ExtraLayerAttribute for
details.
......@@ -6508,9 +6514,9 @@ def gated_unit_layer(input,
:param inproj_param_attr: The parameter attribute of the projection. See ParameterAttribute
for details.
:type inproj_param_attr: ParameterAttribute
:param inproj_bias_attr: The bias attribute of the projection. If the parameter is set to False
:param inproj_bias_attr: The bias attribute of the projection. If this parameter is set to False
or an object whose type is not ParameterAttribute, no bias is defined.
If the parameter is set to True, the bias is initialized to zero.
If this parameter is set to True, the bias is initialized to zero.
:type inproj_bias_attr: ParameterAttribute | bool | None | Any
:param layer_attr: Extra layer attribute of the product. See ExtraLayerAttribute for
details.
......
......@@ -681,34 +681,42 @@ def lstmemory_unit(input,
state_act=TanhActivation())
:param input: input layer.
:param input: Input layer.
:type input: LayerOutput
:param out_memory: output of previous time step
:param out_memory: The output of previous time step.
:type out_memory: LayerOutput | None
:param name: lstmemory unit name.
:param name: The lstmemory unit name.
:type name: basestring
:param size: lstmemory unit size.
:param size: The lstmemory unit size.
:type size: int
:param param_attr: parameter attribute, None means default attribute.
:param param_attr: The parameter attribute for the weights in
input to hidden projection.
None means default attribute.
:type param_attr: ParameterAttribute
:param act: last activiation type of lstm.
:param act: The last activiation type of lstm.
:type act: BaseActivation
:param gate_act: gate activiation type of lstm.
:param gate_act: The gate activiation type of lstm.
:type gate_act: BaseActivation
:param state_act: state activiation type of lstm.
:param state_act: The state activiation type of lstm.
:type state_act: BaseActivation
:param input_proj_bias_attr: bias attribute for input to hidden projection.
False means no bias, None means default bias.
:type input_proj_bias_attr: ParameterAttribute|False|None
:param input_proj_layer_attr: extra layer attribute for input to hidden
projection of the LSTM unit, such as dropout, error clipping.
:param input_proj_bias_attr: The parameter attribute for the bias in
input to hidden projection.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type input_proj_bias_attr: ParameterAttribute|bool|None
:param input_proj_layer_attr: The extra layer attribute for
input to hidden projection of the LSTM unit,
such as dropout, error clipping.
:type input_proj_layer_attr: ExtraLayerAttribute
:param lstm_bias_attr: bias parameter attribute of lstm layer.
False means no bias, None means default bias.
:type lstm_bias_attr: ParameterAttribute|False|None
:param lstm_layer_attr: extra attribute of lstm layer.
:param lstm_bias_attr: The parameter attribute for the bias in lstm layer.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type lstm_bias_attr: ParameterAttribute|True|None
:param lstm_layer_attr: The extra attribute of lstm layer.
:type lstm_layer_attr: ExtraLayerAttribute
:return: lstmemory unit name.
:return: The lstmemory unit name.
:rtype: LayerOutput
"""
if size is None:
......@@ -786,34 +794,42 @@ def lstmemory_group(input,
gate_act=SigmoidActivation(),
state_act=TanhActivation())
:param input: input layer.
:param input: Input layer.
:type input: LayerOutput
:param size: lstmemory group size.
:param size: The lstmemory group size.
:type size: int
:param name: name of lstmemory group.
:param name: The name of lstmemory group.
:type name: basestring
:param out_memory: output of previous time step.
:param out_memory: The output of previous time step.
:type out_memory: LayerOutput | None
:param reverse: process the input in a reverse order or not.
:param reverse: Process the input in a reverse order or not.
:type reverse: bool
:param param_attr: parameter attribute, None means default attribute.
:param param_attr: The parameter attribute for the weights in
input to hidden projection.
None means default attribute.
:type param_attr: ParameterAttribute
:param act: last activiation type of lstm.
:param act: The last activiation type of lstm.
:type act: BaseActivation
:param gate_act: gate activiation type of lstm.
:param gate_act: The gate activiation type of lstm.
:type gate_act: BaseActivation
:param state_act: state activiation type of lstm.
:param state_act: The state activiation type of lstm.
:type state_act: BaseActivation
:param lstm_bias_attr: bias parameter attribute of lstm layer.
False means no bias, None means default bias.
:type lstm_bias_attr: ParameterAttribute|False|None
:param input_proj_bias_attr: bias attribute for input to hidden projection.
False means no bias, None means default bias.
:type input_proj_bias_attr: ParameterAttribute|False|None
:param input_proj_layer_attr: extra layer attribute for input to hidden
projection of the LSTM unit, such as dropout, error clipping.
:param input_proj_bias_attr: The parameter attribute for the bias in
input to hidden projection.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type input_proj_bias_attr: ParameterAttribute|bool|None
:param input_proj_layer_attr: The extra layer attribute for
input to hidden projection of the LSTM unit,
such as dropout, error clipping.
:type input_proj_layer_attr: ExtraLayerAttribute
:param lstm_layer_attr: lstm layer's extra attribute.
:param lstm_bias_attr: The parameter attribute for the bias in lstm layer.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type lstm_bias_attr: ParameterAttribute|True|None
:param lstm_layer_attr: The extra attribute of lstm layer.
:type lstm_layer_attr: ExtraLayerAttribute
:return: the lstmemory group.
:rtype: LayerOutput
......
......@@ -15,8 +15,8 @@
"""
__all__ = [
"BasePoolingType", "MaxPooling", "AvgPooling", "CudnnMaxPooling",
"CudnnAvgPooling", "SumPooling", "SquareRootNPooling"
"BasePoolingType", "MaxPooling", "AvgPooling", "MaxWithMaskPooling",
"CudnnMaxPooling", "CudnnAvgPooling", "SumPooling", "SquareRootNPooling"
]
......@@ -55,6 +55,19 @@ class MaxPooling(BasePoolingType):
self.output_max_index = output_max_index
class MaxWithMaskPooling(BasePoolingType):
"""
MaxWithMask pooling.
Not only return the very large values for each dimension in sequence or time steps,
but also the location indices of found maxinum values.
"""
def __init__(self):
BasePoolingType.__init__(self, "max-pool-with-mask")
class CudnnMaxPooling(BasePoolingType):
"""
Cudnn max pooling only support GPU. Return the maxinum value in the
......
......@@ -28,6 +28,8 @@ layers {
stride_y: 1
output_y: 227
img_size_y: 256
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
......
......@@ -28,6 +28,8 @@ layers {
stride_y: 1
output_y: 227
img_size_y: 256
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
......
......@@ -28,6 +28,8 @@ layers {
stride_y: 1
output_y: 48
img_size_y: 48
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
......
......@@ -30,6 +30,8 @@ layers {
stride_y: 1
output_y: 48
img_size_y: 48
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
......@@ -105,6 +107,8 @@ layers {
stride_y: 1
output_y: 24
img_size_y: 24
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_1__.wbias"
......
......@@ -30,6 +30,8 @@ layers {
stride_y: 1
output_y: 48
img_size_y: 48
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
......
......@@ -36,6 +36,8 @@ layers {
stride_y: 1
output_y: 14
img_size_y: 14
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
......
from paddle.v2.framework import framework as framework
from paddle.v2.fluid import framework as framework
__all__ = ['append_backward_ops']
......
......@@ -13,7 +13,7 @@ A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""
import paddle.v2.framework.core
import paddle.v2.fluid.core
import threading
__tl_scope__ = threading.local()
......@@ -27,13 +27,13 @@ __all__ = [
def get_cur_scope():
"""
Get current scope.
:rtype: paddle.v2.framework.core.Scope
:rtype: paddle.v2.fluid.core.Scope
"""
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
if cur_scope_stack is None:
__tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope())
__tl_scope__.cur_scope.append(paddle.v2.fluid.core.Scope())
return __tl_scope__.cur_scope[-1]
......
import numpy as np
from paddle.v2.framework.framework import Program, g_main_program, unique_name, Variable
import paddle.v2.framework.core as core
from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable
import paddle.v2.fluid.core as core
def _clone_var_in_block_(block, var):
......@@ -181,7 +181,6 @@ class Accuracy(Evaluator):
return np.array(out[0])
# FIXME(dzh): add a decorator to call _update_ops automatically
def accuracy(*args, **kwargs):
cls = Accuracy(*args, **kwargs)
out = cls._update_ops(*args, **kwargs)
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program, g_main_program
import paddle.v2.fluid.core as core
from paddle.v2.fluid.framework import Block, Program, g_main_program
g_scope = core.Scope()
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
import collections
import numpy as np
import copy
......
import paddle.v2.framework.framework as framework
import paddle.v2.fluid.framework as framework
import numpy as np
__all__ = [
......
import os
import cPickle as pickle
from paddle.v2.framework.framework import Program, Parameter, g_main_program, \
from paddle.v2.fluid.framework import Program, Parameter, g_main_program, \
Variable
__all__ = [
......
import copy
import itertools
from paddle.v2.framework.framework import Variable, g_main_program, \
from paddle.v2.fluid.framework import Variable, g_main_program, \
g_startup_program, unique_name, Program
from paddle.v2.framework.initializer import ConstantInitializer, \
from paddle.v2.fluid.initializer import ConstantInitializer, \
UniformInitializer, XavierInitializer
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
from paddle.v2.fluid.framework import OpProtoHolder, Variable, Program, \
Operator
from paddle.v2.framework.initializer import ConstantInitializer, \
from paddle.v2.fluid.initializer import ConstantInitializer, \
NormalInitializer
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
import re
import cStringIO
......@@ -845,6 +845,23 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out)
def beam_search_decode(ids, scores, main_program=None, startup_program=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.data_type)
sentence_scores = helper.create_tmp_variable(dtype=ids.data_type)
helper.append_op(
type="beam_search_decode",
inputs={"Ids": ids,
"Scores": scores},
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
})
return sentence_ids, sentence_scores
class BlockGuard(object):
"""
BlockGuard class.
......
......@@ -3,8 +3,8 @@ import json
import logging
from collections import defaultdict
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
......
import paddle.v2.framework.layers as layers
import paddle.v2.fluid.layers as layers
__all__ = ["simple_img_conv_pool", "sequence_conv_pool"]
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
def get_all_op_protos():
......
from collections import defaultdict
import paddle.v2.framework.framework as framework
from paddle.v2.framework.framework import unique_name, Program
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.initializer import ConstantInitializer
from paddle.v2.framework.regularizer import append_regularization_ops
from paddle.v2.framework.layer_helper import LayerHelper
import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.framework import unique_name, Program
from paddle.v2.fluid.backward import append_backward_ops
from paddle.v2.fluid.initializer import ConstantInitializer
from paddle.v2.fluid.regularizer import append_regularization_ops
from paddle.v2.fluid.layer_helper import LayerHelper
__all__ = [
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
......
import paddle.v2.framework.framework as framework
import paddle.v2.fluid.framework as framework
__all__ = [
'append_regularization_ops', 'L2DecayRegularizer', 'L1DecayRegularizer'
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.io import save_persistables, load_persistables
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.io import save_persistables, load_persistables
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import numpy as np
import paddle.v2 as paddle
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import g_startup_program, g_main_program
from paddle.v2.framework.initializer import XavierInitializer
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import g_startup_program, g_main_program
from paddle.v2.fluid.initializer import XavierInitializer
def resnet_cifar10(input, depth=32, main_program=None, startup_program=None):
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.framework.evaluator as evaluator
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.regularizer import L2DecayRegularizer
from paddle.v2.framework.initializer import UniformInitializer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.regularizer import L2DecayRegularizer
from paddle.v2.fluid.initializer import UniformInitializer
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_main_program, g_startup_program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program, g_main_program, g_startup_program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_main_program, g_startup_program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program, g_main_program, g_startup_program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import g_main_program, g_startup_program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import g_main_program, g_startup_program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
......@@ -2,12 +2,12 @@ import unittest
import numpy as np
import random
import itertools
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
import collections
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.op import Operator
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder
from paddle.v2.fluid.backward import append_backward_ops
from paddle.v2.fluid.op import Operator
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import Program, OpProtoHolder
def randomize_probability(batch_size, class_num, dtype='float32'):
......
import unittest
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.framework import g_main_program
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.backward import append_backward_ops
from paddle.v2.fluid.framework import g_main_program
import numpy
......
import unittest
import numpy as np
from op_test import OpTest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator
def grad_var_name(var_name):
......
import op_test
import unittest
import numpy as np
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
class TestCastOp(op_test.OpTest):
......
......@@ -23,6 +23,9 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
if __name__ == '__main__':
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册