提交 f0c3c498 编写于 作者: X xzl

test exconv layerGrad and im2col

上级 328169a9
...@@ -29,14 +29,16 @@ void TestIm2ColFunctor() { ...@@ -29,14 +29,16 @@ void TestIm2ColFunctor() {
for (size_t filterWidth : {3, 7}) { for (size_t filterWidth : {3, 7}) {
for (size_t stride : {1, 2}) { for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) { for (size_t padding : {0, 1}) {
if (inputHeight <= filterHeight || inputWidth <= filterWidth) for (size_t dilation : {1, 3}) {
size_t filterSizeH = (filterHeight - 1) * dilation + 1;
size_t filterSizeW = (filterWidth - 1) * dilation + 1;
if (inputHeight <= filterSizeH || inputWidth <= filterSizeW)
break; break;
if (padding >= filterHeight || padding >= filterWidth) break; if (padding >= filterSizeH || padding >= filterSizeW) break;
size_t outputHeight = size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) / (inputHeight - filterSizeH + 2 * padding) / stride + 1;
stride;
size_t outputWidth = size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) / stride; (inputWidth - filterSizeW + 2 * padding) / stride + 1;
TensorShape imShape = TensorShape imShape =
TensorShape({channels, inputHeight, inputWidth}); TensorShape({channels, inputHeight, inputWidth});
...@@ -53,10 +55,14 @@ void TestIm2ColFunctor() { ...@@ -53,10 +55,14 @@ void TestIm2ColFunctor() {
size_t height = channels * filterHeight * filterWidth; size_t height = channels * filterHeight * filterWidth;
size_t width = outputHeight * outputWidth; size_t width = outputHeight * outputWidth;
VectorPtr input1 = Vector::create(imShape.getElements(), false); VectorPtr input1 =
VectorPtr input2 = Vector::create(imShape.getElements(), false); Vector::create(imShape.getElements(), false);
MatrixPtr output1 = Matrix::create(height, width, false, false); VectorPtr input2 =
MatrixPtr output2 = Matrix::create(width, height, false, false); Vector::create(imShape.getElements(), false);
MatrixPtr output1 =
Matrix::create(height, width, false, false);
MatrixPtr output2 =
Matrix::create(width, height, false, false);
input1->uniform(0.001, 1); input1->uniform(0.001, 1);
input2->copyFrom(*input1); input2->copyFrom(*input1);
...@@ -69,7 +75,9 @@ void TestIm2ColFunctor() { ...@@ -69,7 +75,9 @@ void TestIm2ColFunctor() {
stride, stride,
stride, stride,
padding, padding,
padding); padding,
dilation,
dilation);
im2Col2(input2->getData(), im2Col2(input2->getData(),
imShape, imShape,
output2->getData(), output2->getData(),
...@@ -77,7 +85,9 @@ void TestIm2ColFunctor() { ...@@ -77,7 +85,9 @@ void TestIm2ColFunctor() {
stride, stride,
stride, stride,
padding, padding,
padding); padding,
dilation,
dilation);
// The transposition of the result of ColFormat == kCFO // The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF. // is equal to the result of ColFormat == kOCF.
...@@ -87,6 +97,7 @@ void TestIm2ColFunctor() { ...@@ -87,6 +97,7 @@ void TestIm2ColFunctor() {
Col2ImFunctor<kCFO, Device, T> col2Im1; Col2ImFunctor<kCFO, Device, T> col2Im1;
Col2ImFunctor<kOCF, Device, T> col2Im2; Col2ImFunctor<kOCF, Device, T> col2Im2;
col2Im1(input1->getData(), col2Im1(input1->getData(),
imShape, imShape,
output1->getData(), output1->getData(),
...@@ -94,7 +105,9 @@ void TestIm2ColFunctor() { ...@@ -94,7 +105,9 @@ void TestIm2ColFunctor() {
stride, stride,
stride, stride,
padding, padding,
padding); padding,
dilation,
dilation);
col2Im2(input2->getData(), col2Im2(input2->getData(),
imShape, imShape,
output2->getData(), output2->getData(),
...@@ -102,8 +115,9 @@ void TestIm2ColFunctor() { ...@@ -102,8 +115,9 @@ void TestIm2ColFunctor() {
stride, stride,
stride, stride,
padding, padding,
padding); padding,
dilation,
dilation);
autotest::TensorCheckErr(*input1, *input2); autotest::TensorCheckErr(*input1, *input2);
} }
} }
...@@ -112,6 +126,7 @@ void TestIm2ColFunctor() { ...@@ -112,6 +126,7 @@ void TestIm2ColFunctor() {
} }
} }
} }
}
} }
TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); } TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); }
......
...@@ -434,7 +434,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -434,7 +434,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config.layerConfig.set_partial_sum(1); config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true); config.layerConfig.set_shared_biases(true);
int dilation = 1; int dilation = 2;
if (type == "cudnn_conv") { if (type == "cudnn_conv") {
#if CUDNN_VERSION >= 6000 #if CUDNN_VERSION >= 6000
dilation = 2; dilation = 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册