#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import sys
import math
from op_test import OpTest


def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
              box_normalized):
    """NumPy reference for center-size box encoding and decoding.

    The result is written into ``output_box`` in place; nothing is returned.

    Args:
        target_box: For ``"EncodeCenterSize"`` an array indexed as
            ``[n, 4]`` whose columns 0/1 and 2/3 are two opposite corners.
            For ``"DecodeCenterSize"`` an array indexed as ``[n, m, 4]``
            holding encoded (x, y, width, height) offsets.
        prior_box: Array indexed as ``[m, 4]``; columns 0/1 and 2/3 are
            used as two opposite corners of each prior box.
        prior_box_var: 2-D array, reshaped to ``(1, rows, cols)`` and read
            on channels 0-3 as per-prior scaling factors.
        output_box: Array indexed as ``[n, m, 4]`` that receives the result.
        code_type: ``"EncodeCenterSize"`` or ``"DecodeCenterSize"``.
        box_normalized: When false, 1 is added to every width and height,
            and 1 is subtracted from the decoded channels 2 and 3.

    Raises:
        ValueError: If ``code_type`` is not one of the two supported modes.
    """
    num_priors = prior_box.shape[0]

    # Prior centres and sizes are kept as (1, m) row vectors so that they
    # broadcast against (n, 1) target columns or (n, m) encoded offsets.
    prior_box_x = ((prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(
        1, num_priors)
    prior_box_y = ((prior_box[:, 3] + prior_box[:, 1]) / 2).reshape(
        1, num_priors)
    prior_box_width = (prior_box[:, 2] - prior_box[:, 0]).reshape(
        1, num_priors)
    prior_box_height = (prior_box[:, 3] - prior_box[:, 1]).reshape(
        1, num_priors)
    prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0],
                                          prior_box_var.shape[1])
    if not box_normalized:
        prior_box_height = prior_box_height + 1
        prior_box_width = prior_box_width + 1

    if code_type == "EncodeCenterSize":
        num_targets = target_box.shape[0]
        # Target quantities are (n, 1) column vectors.
        target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape(
            num_targets, 1)
        target_box_y = ((target_box[:, 3] + target_box[:, 1]) / 2).reshape(
            num_targets, 1)
        target_box_width = (target_box[:, 2] - target_box[:, 0]).reshape(
            num_targets, 1)
        target_box_height = (target_box[:, 3] - target_box[:, 1]).reshape(
            num_targets, 1)
        if not box_normalized:
            target_box_height = target_box_height + 1
            target_box_width = target_box_width + 1

        output_box[:, :, 0] = ((target_box_x - prior_box_x) /
                               prior_box_width / prior_box_var[:, :, 0])
        output_box[:, :, 1] = ((target_box_y - prior_box_y) /
                               prior_box_height / prior_box_var[:, :, 1])
        # fabs keeps the logarithm defined when a size ratio is negative.
        output_box[:, :, 2] = (
            np.log(np.fabs(target_box_width / prior_box_width)) /
            prior_box_var[:, :, 2])
        output_box[:, :, 3] = (
            np.log(np.fabs(target_box_height / prior_box_height)) /
            prior_box_var[:, :, 3])

    elif code_type == "DecodeCenterSize":
        target_box_x = (prior_box_var[:, :, 0] * target_box[:, :, 0] *
                        prior_box_width + prior_box_x)
        target_box_y = (prior_box_var[:, :, 1] * target_box[:, :, 1] *
                        prior_box_height + prior_box_y)
        target_box_width = (np.exp(prior_box_var[:, :, 2] *
                                   target_box[:, :, 2]) * prior_box_width)
        target_box_height = (np.exp(prior_box_var[:, :, 3] *
                                    target_box[:, :, 3]) * prior_box_height)

        # Convert centre/size back to the two corners.
        output_box[:, :, 0] = target_box_x - target_box_width / 2
        output_box[:, :, 1] = target_box_y - target_box_height / 2
        output_box[:, :, 2] = target_box_x + target_box_width / 2
        output_box[:, :, 3] = target_box_y + target_box_height / 2
        if not box_normalized:
            # Undo the +1 that was added to the prior sizes above.
            output_box[:, :, 2] = output_box[:, :, 2] - 1
            output_box[:, :, 3] = output_box[:, :, 3] - 1

    else:
        raise ValueError(
            "Unsupported code_type %r; expected 'EncodeCenterSize' or "
            "'DecodeCenterSize'." % (code_type, ))


def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type,
                    box_normalized):
    """Apply ``box_coder`` to consecutive row groups of ``target_box``.

    Args:
        prior_box: Array whose first dimension is the number of priors ``m``.
        prior_box_var: Per-prior scaling factors, forwarded to ``box_coder``.
        target_box: Array whose first dimension ``n`` is split into groups.
        lod: Sequence of group lengths; each entry advances the running row
            offset by that many rows.
        code_type: Mode string forwarded unchanged to ``box_coder``.
        box_normalized: Flag forwarded unchanged to ``box_coder``.

    Returns:
        A float32 array of shape ``(n, m, 4)``. Rows not covered by any group
        in ``lod`` stay zero.
    """
    n = target_box.shape[0]
    m = prior_box.shape[0]
    output_box = np.zeros((n, m, 4), dtype=np.float32)

    cur_offset = 0
    for seq_len in lod:
        end = cur_offset + seq_len
        # Basic slices are views, so box_coder's in-place writes land
        # directly in output_box. Slicing only the first axis works for both
        # the 2-D (encode) and the 3-D (decode) target layout.
        box_coder(target_box[cur_offset:end], prior_box, prior_box_var,
                  output_box[cur_offset:end], code_type, box_normalized)
        cur_offset = end
    return output_box


class TestBoxCoderOp(OpTest):
    """Checks the ``box_coder`` op in decode mode with explicit variances."""

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        self.op_type = "box_coder"
        # batch_box_coder consumes group lengths: five groups of one row.
        lod = [[1, 1, 1, 1, 1]]
        prior_box = np.random.random((10, 4)).astype('float32')
        prior_box_var = np.random.random((10, 4)).astype('float32')
        target_box = np.random.random((5, 10, 4)).astype('float32')
        code_type = "DecodeCenterSize"
        box_normalized = False
        # Expected result from the NumPy reference implementation.
        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                                     lod[0], code_type, box_normalized)

        self.inputs = {
            'PriorBox': prior_box,
            'PriorBoxVar': prior_box_var,
            'TargetBox': target_box,
        }
        self.attrs = {
            'code_type': 'decode_center_size',
            # Reuse the local so the op and the reference cannot drift apart.
            'box_normalized': box_normalized,
        }
        self.outputs = {'OutputBox': output_box}


class TestBoxCoderOpWithoutBoxVar(OpTest):
    """Checks the ``box_coder`` op in decode mode without ``PriorBoxVar``.

    The reference is computed with an all-ones variance array, while the op
    itself receives no ``PriorBoxVar`` input.
    """

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        self.op_type = "box_coder"
        # batch_box_coder consumes group lengths, not offsets: five groups
        # of one row each cover the five target rows exactly once.
        lod = [[1, 1, 1, 1, 1]]
        prior_box = np.random.random((10, 4)).astype('float32')
        prior_box_var = np.ones((10, 4)).astype('float32')
        target_box = np.random.random((5, 10, 4)).astype('float32')
        code_type = "DecodeCenterSize"
        box_normalized = False
        # Expected result from the NumPy reference implementation.
        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                                     lod[0], code_type, box_normalized)

        self.inputs = {
            'PriorBox': prior_box,
            'TargetBox': target_box,
        }
        self.attrs = {
            'code_type': 'decode_center_size',
            # Reuse the local so the op and the reference cannot drift apart.
            'box_normalized': box_normalized,
        }
        self.outputs = {'OutputBox': output_box}


class TestBoxCoderOpWithLoD(OpTest):
    """Checks the ``box_coder`` op in encode mode with a LoD target input."""

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        self.op_type = "box_coder"
        # Group lengths 4 + 8 + 8 cover all 20 target rows.
        lod = [[4, 8, 8]]
        prior_box = np.random.random((10, 4)).astype('float32')
        prior_box_var = np.random.random((10, 4)).astype('float32')
        target_box = np.random.random((20, 4)).astype('float32')
        code_type = "EncodeCenterSize"
        box_normalized = True
        # Expected result from the NumPy reference implementation.
        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                                     lod[0], code_type, box_normalized)

        self.inputs = {
            'PriorBox': prior_box,
            'PriorBoxVar': prior_box_var,
            # The target is passed together with its LoD.
            'TargetBox': (target_box, lod),
        }
        self.attrs = {
            'code_type': 'encode_center_size',
            # Reuse the local so the op and the reference cannot drift apart.
            'box_normalized': box_normalized,
        }
        self.outputs = {'OutputBox': output_box}


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()