# test_conv3d_transpose_op.py
# Unit tests for the conv3d_transpose operator (committed by chengduoZH).
import unittest
import numpy as np
from op_test import OpTest


def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param):
    """Compute a 3-D transposed convolution with plain NumPy loops.

    Every input voxel scatters a filter-sized patch, weighted by the
    voxel's channel values, into the output at an offset given by the
    stride. The padding is then cropped from both ends of each spatial axis.

    Args:
        input_: Array of shape (N, C_in, D, H, W).
        filter_: Array of shape (C_in, C_out, f_d, f_h, f_w).
        conv3dtranspose_param: Dict with keys 'stride' and 'pad'. Each value
            is indexed 0..2 for the depth, height and width axes.

    Returns:
        Array of shape (N, C_out, D_out, H_out, W_out) where, per spatial
        axis, size_out = (size_in - 1) * stride + f_size - 2 * pad. The
        accumulator is created with np.zeros' default dtype (float64).

    Raises:
        AssertionError: If the channel counts of input_ and filter_ differ.
    """
    in_n, in_c, in_d, in_h, in_w = input_.shape
    f_c, out_c, f_d, f_h, f_w = filter_.shape
    assert in_c == f_c

    stride, pad = conv3dtranspose_param['stride'], conv3dtranspose_param['pad']
    # Un-cropped output extent along each spatial axis.
    out_d = (in_d - 1) * stride[0] + f_d
    out_h = (in_h - 1) * stride[1] + f_h
    out_w = (in_w - 1) * stride[2] + f_w
    out = np.zeros((in_n, out_c, out_d, out_h, out_w))

    for n in range(in_n):
        for d in range(in_d):
            d1 = d * stride[0]
            for i in range(in_h):
                i1 = i * stride[1]
                for j in range(in_w):
                    j1 = j * stride[2]
                    # Shape (C_in, 1, 1, 1, 1): broadcasting against the
                    # filter replaces an explicit np.tile and covers all
                    # output channels in one multiply.
                    voxel = np.reshape(input_[n, :, d, i, j],
                                       (in_c, 1, 1, 1, 1))
                    # Reduce over input channels -> (C_out, f_d, f_h, f_w).
                    patch = np.sum(voxel * filter_, axis=0)
                    out[n, :, d1:d1 + f_d, i1:i1 + f_h,
                        j1:j1 + f_w] += patch

    # Crop `pad` elements from both ends of every spatial axis.
    return out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1],
               pad[2]:out_w - pad[2]]


class TestConv3dTransposeOp(OpTest):
    """Base test case for the conv3d_transpose operator.

    setUp builds random float32 input/filter tensors, computes the expected
    output with conv3dtranspose_forward_naive, and stores inputs, attrs and
    outputs on the instance. Subclasses vary the configuration by
    overriding init_test_case (and, if needed, init_op_type).
    """

    def setUp(self):
        """Prepare op type, inputs, attributes and the reference output."""
        # init as conv transpose
        self.init_op_type()
        self.init_test_case()

        conv3dtranspose_param = {'stride': self.stride, 'pad': self.pad}
        input_ = np.random.random(self.input_size).astype("float32")
        filter_ = np.random.random(self.filter_size).astype("float32")
        # The reference accumulates in float64; cast back so the expected
        # output has the same dtype as the inputs.
        output = conv3dtranspose_forward_naive(
            input_, filter_, conv3dtranspose_param).astype("float32")

        self.inputs = {'Input': input_, 'Filter': filter_}
        # NOTE(review): self.dilations is set by init_test_case but is not
        # forwarded as a 'dilations' attr; presumably the op did not accept
        # it when this test was written -- confirm before enabling.
        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
        }
        self.outputs = {'Output': output}

    def test_check_output(self):
        """Run the inherited forward-output check."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients w.r.t. both 'Input' and 'Filter'."""
        self.check_grad(
            {'Input', 'Filter'}, 'Output', max_relative_error=0.02)

    def test_check_grad_no_filter(self):
        """Check the 'Input' gradient with 'Filter' in the no-grad set."""
        self.check_grad(
            ['Input'],
            'Output',
            max_relative_error=0.02,
            no_grad_set={'Filter'})

    def test_check_grad_no_input(self):
        """Check the 'Filter' gradient with 'Input' in the no-grad set."""
        self.check_grad(
            ['Filter'],
            'Output',
            max_relative_error=0.02,
            no_grad_set={'Input'})

    def init_test_case(self):
        """Default configuration: no padding, unit stride."""
        self.pad = [0, 0, 0]
        self.stride = [1, 1, 1]
        self.dilations = [1, 1, 1]
        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
        # The filter's leading dimension must equal the input channel count.
        f_c = self.input_size[1]
        self.filter_size = [f_c, 6, 3, 3, 3]

    def init_op_type(self):
        """Select the operator under test."""
        self.op_type = "conv3d_transpose"


class TestWithPad(TestConv3dTransposeOp):
    """Variant of the base case with a padding of 1 on every spatial axis."""

    def init_test_case(self):
        """Configuration: pad 1, unit stride; same tensor sizes as the base."""
        self.pad = [1, 1, 1]
        self.stride = [1, 1, 1]
        self.dilations = [1, 1, 1]
        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
        # The filter's leading dimension must equal the input channel count.
        f_c = self.input_size[1]
        self.filter_size = [f_c, 6, 3, 3, 3]


class TestWithStride(TestConv3dTransposeOp):
    """Variant of the base case with stride 2 and padding 1 on every axis."""

    def init_test_case(self):
        """Configuration: pad 1, stride 2; same tensor sizes as the base."""
        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
        in_channels = self.input_size[1]
        # Filter layout is (C_in, C_out, f_d, f_h, f_w).
        self.filter_size = [in_channels, 6, 3, 3, 3]
        self.dilations = [1, 1, 1]
        self.stride = [2, 2, 2]
        self.pad = [1, 1, 1]


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()