提交 a83d5215 编写于 作者: H hedaoyuan

Add unit test for Col2ImFunctor.

上级 c7610106
...@@ -20,7 +20,8 @@ limitations under the License. */ ...@@ -20,7 +20,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
TEST(Im2ColFunctor, real) { template <DeviceType Device, class T>
void TestIm2ColFunctor() {
for (size_t channels : {1, 5, 32}) { for (size_t channels : {1, 5, 32}) {
for (size_t inputHeight : {5, 33, 100}) { for (size_t inputHeight : {5, 33, 100}) {
for (size_t inputWidth : {5, 32, 96}) { for (size_t inputWidth : {5, 32, 96}) {
...@@ -50,16 +51,18 @@ TEST(Im2ColFunctor, real) { ...@@ -50,16 +51,18 @@ TEST(Im2ColFunctor, real) {
filterHeight, filterHeight,
filterWidth}); filterWidth});
VectorPtr input = Vector::create(imShape.getElements(), false);
size_t height = channels * filterHeight * filterWidth; size_t height = channels * filterHeight * filterWidth;
size_t width = outputHeight * outputWidth; size_t width = outputHeight * outputWidth;
VectorPtr input1 = Vector::create(imShape.getElements(), false);
VectorPtr input2 = Vector::create(imShape.getElements(), false);
MatrixPtr output1 = Matrix::create(height, width, false, false); MatrixPtr output1 = Matrix::create(height, width, false, false);
MatrixPtr output2 = Matrix::create(width, height, false, false); MatrixPtr output2 = Matrix::create(width, height, false, false);
Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, real> im2col1; input1->uniform(0.001, 1);
Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, real> im2col2; input2->copyFrom(*input1);
input->uniform(0.001, 1); Im2ColFunctor<kCFO, Device, T> im2Col1;
im2col1(input->getData(), Im2ColFunctor<kOCF, Device, T> im2Col2;
im2Col1(input1->getData(),
imShape, imShape,
output1->getData(), output1->getData(),
colShape1, colShape1,
...@@ -67,7 +70,7 @@ TEST(Im2ColFunctor, real) { ...@@ -67,7 +70,7 @@ TEST(Im2ColFunctor, real) {
stride, stride,
padding, padding,
padding); padding);
im2col2(input->getData(), im2Col2(input2->getData(),
imShape, imShape,
output2->getData(), output2->getData(),
colShape2, colShape2,
...@@ -76,27 +79,32 @@ TEST(Im2ColFunctor, real) { ...@@ -76,27 +79,32 @@ TEST(Im2ColFunctor, real) {
padding, padding,
padding); padding);
// The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF.
MatrixPtr test; MatrixPtr test;
output2->transpose(test, true); output2->transpose(test, true);
autotest::TensorCheckErr(*output1, *test); autotest::TensorCheckErr(*output1, *test);
}
}
}
}
}
}
}
}
#if 0 Col2ImFunctor<kCFO, Device, T> col2Im1;
TEST(Col2ImFunctor, real) { Col2ImFunctor<kOCF, Device, T> col2Im2;
for (size_t channels : {1, 5, 32}) { col2Im1(input1->getData(),
for (size_t inputHeight : {5, 33, 100}) { imShape,
for (size_t inputWidth : {5, 32, 96}) { output1->getData(),
for (size_t filterHeight : {1, 5}) { colShape1,
for (size_t filterWidth : {3, 7}) { stride,
for (size_t stride : {1, 2}) { stride,
for (size_t padding : {0, 1}) { padding,
padding);
col2Im2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding);
autotest::TensorCheckErr(*input1, *input2);
} }
} }
} }
...@@ -105,6 +113,13 @@ TEST(Col2ImFunctor, real) { ...@@ -105,6 +113,13 @@ TEST(Col2ImFunctor, real) {
} }
} }
} }
TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); }
#ifndef PADDLE_ONLY_CPU
TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
#endif #endif
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册