/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Selects which transformation BoxCoderKernel::Compute dispatches to.
enum class BoxCodeType {
  kEncodeCenterSize = 0,  // dispatches to BoxCoderKernel::EncodeCenterSize
  kDecodeCenterSize = 1,  // dispatches to BoxCoderKernel::DecodeCenterSize
};

// Translates the operator's "code_type" attribute string into a BoxCodeType.
// Accepted values are exactly "encode_center_size" and "decode_center_size";
// any other string raises an error through PADDLE_THROW.
inline BoxCodeType GetBoxCodeType(const std::string& type) {
  if (type == "decode_center_size") return BoxCodeType::kDecodeCenterSize;
  if (type == "encode_center_size") return BoxCodeType::kEncodeCenterSize;
  PADDLE_THROW("Not support type %s.", type);
}

template <typename DeviceContext, typename T>
G
gaoyuan 已提交
33 34
class BoxCoderKernel : public framework::OpKernel<T> {
 public:
35 36 37
  void EncodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
38 39
                        const bool normalized,
                        const std::vector<float> variance, T* output) const {
40 41 42 43 44 45 46
    int64_t row = target_box->dims()[0];
    int64_t col = prior_box->dims()[0];
    int64_t len = prior_box->dims()[1];
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
G
gaoyuan 已提交
47

L
luotao1 已提交
48 49 50
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
G
gaoyuan 已提交
51 52
    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
53 54 55 56 57
        T prior_box_width = prior_box_data[j * len + 2] -
                            prior_box_data[j * len] + (normalized == false);
        T prior_box_height = prior_box_data[j * len + 3] -
                             prior_box_data[j * len + 1] +
                             (normalized == false);
J
jerrywgz 已提交
58
        T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2;
G
gaoyuan 已提交
59
        T prior_box_center_y =
J
jerrywgz 已提交
60
            prior_box_data[j * len + 1] + prior_box_height / 2;
G
gaoyuan 已提交
61 62

        T target_box_center_x =
G
gaoyuan 已提交
63
            (target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
G
gaoyuan 已提交
64
        T target_box_center_y =
G
gaoyuan 已提交
65
            (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
66 67 68 69 70
        T target_box_width = target_box_data[i * len + 2] -
                             target_box_data[i * len] + (normalized == false);
        T target_box_height = target_box_data[i * len + 3] -
                              target_box_data[i * len + 1] +
                              (normalized == false);
G
gaoyuan 已提交
71

G
gaoyuan 已提交
72
        size_t offset = i * col * len + j * len;
73 74 75 76
        output[offset] =
            (target_box_center_x - prior_box_center_x) / prior_box_width;
        output[offset + 1] =
            (target_box_center_y - prior_box_center_y) / prior_box_height;
G
gaoyuan 已提交
77
        output[offset + 2] =
78
            std::log(std::fabs(target_box_width / prior_box_width));
G
gaoyuan 已提交
79
        output[offset + 3] =
80 81
            std::log(std::fabs(target_box_height / prior_box_height));
        if (prior_box_var) {
J
jerrywgz 已提交
82 83 84 85 86 87 88 89
          int prior_var_offset = 0;
          if (prior_box_var->dims().size() == 2) {
            prior_var_offset = j * len;
          }
          output[offset] /= prior_box_var_data[prior_var_offset];
          output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
          output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
          output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
90 91 92 93
        } else if (!(variance.empty())) {
          for (int k = 0; k < 4; ++k) {
            output[offset + k] /= static_cast<T>(variance[k]);
          }
94
        }
G
gaoyuan 已提交
95 96 97
      }
    }
  }
98 99 100
  void DecodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
J
jerrywgz 已提交
101
                        const bool normalized, const int axis,
102
                        const std::vector<float> variance, T* output) const {
103
    int64_t row = target_box->dims()[0];
J
jerrywgz 已提交
104 105
    int64_t col = target_box->dims()[1];
    int64_t len = target_box->dims()[2];
G
gaoyuan 已提交
106

107 108 109 110
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
J
jerrywgz 已提交
111
    int prior_box_offset = 0;
L
luotao1 已提交
112 113 114
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
G
gaoyuan 已提交
115 116
    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
Y
Yuan Gao 已提交
117
        size_t offset = i * col * len + j * len;
J
jerrywgz 已提交
118 119 120 121 122 123 124 125 126 127
        if (axis == 0) {
          prior_box_offset = j * len;
        } else if (axis == 1) {
          prior_box_offset = i * len;
        }
        T prior_box_width = prior_box_data[prior_box_offset + 2] -
                            prior_box_data[prior_box_offset] +
                            (normalized == false);
        T prior_box_height = prior_box_data[prior_box_offset + 3] -
                             prior_box_data[prior_box_offset + 1] +
128
                             (normalized == false);
G
gaoyuan 已提交
129
        T prior_box_center_x =
J
jerrywgz 已提交
130
            prior_box_data[prior_box_offset] + prior_box_width / 2;
G
gaoyuan 已提交
131
        T prior_box_center_y =
J
jerrywgz 已提交
132
            prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
G
gaoyuan 已提交
133

134 135 136
        T target_box_center_x = 0, target_box_center_y = 0;
        T target_box_width = 0, target_box_height = 0;
        if (prior_box_var) {
J
jerrywgz 已提交
137 138 139 140 141 142 143 144
          int prior_var_offset = 0;
          if (prior_box_var->dims().size() == 2) {
            if (axis == 0)
              prior_var_offset = j * len;
            else if (axis == 1)
              prior_var_offset = i * len;
          }
          target_box_center_x = prior_box_var_data[prior_var_offset] *
Y
Yuan Gao 已提交
145
                                    target_box_data[offset] * prior_box_width +
G
gaoyuan 已提交
146
                                prior_box_center_x;
J
jerrywgz 已提交
147
          target_box_center_y = prior_box_var_data[prior_var_offset + 1] *
Y
Yuan Gao 已提交
148
                                    target_box_data[offset + 1] *
G
gaoyuan 已提交
149 150
                                    prior_box_height +
                                prior_box_center_y;
J
jerrywgz 已提交
151
          target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] *
Y
Yuan Gao 已提交
152
                                      target_box_data[offset + 2]) *
G
gaoyuan 已提交
153
                             prior_box_width;
J
jerrywgz 已提交
154 155 156 157
          target_box_height =
              std::exp(prior_box_var_data[prior_var_offset + 3] *
                       target_box_data[offset + 3]) *
              prior_box_height;
158 159 160 161 162 163 164 165 166 167 168 169 170 171
        } else if (!(variance.empty())) {
          target_box_center_x = static_cast<T>(variance[0]) *
                                    target_box_data[offset] * prior_box_width +
                                prior_box_center_x;
          target_box_center_y = static_cast<T>(variance[1]) *
                                    target_box_data[offset + 1] *
                                    prior_box_height +
                                prior_box_center_y;
          target_box_width = std::exp(static_cast<T>(variance[2]) *
                                      target_box_data[offset + 2]) *
                             prior_box_width;
          target_box_height = std::exp(static_cast<T>(variance[3]) *
                                       target_box_data[offset + 3]) *
                              prior_box_height;
172 173 174 175 176 177 178 179 180 181
        } else {
          target_box_center_x =
              target_box_data[offset] * prior_box_width + prior_box_center_x;
          target_box_center_y = target_box_data[offset + 1] * prior_box_height +
                                prior_box_center_y;
          target_box_width =
              std::exp(target_box_data[offset + 2]) * prior_box_width;
          target_box_height =
              std::exp(target_box_data[offset + 3]) * prior_box_height;
        }
G
gaoyuan 已提交
182 183 184

        output[offset] = target_box_center_x - target_box_width / 2;
        output[offset + 1] = target_box_center_y - target_box_height / 2;
185 186 187 188
        output[offset + 2] =
            target_box_center_x + target_box_width / 2 - (normalized == false);
        output[offset + 3] =
            target_box_center_y + target_box_height / 2 - (normalized == false);
G
gaoyuan 已提交
189 190 191 192 193 194 195 196
      }
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
G
gaoyuan 已提交
197
    auto* output_box = context.Output<framework::Tensor>("OutputBox");
198
    std::vector<float> variance = context.Attr<std::vector<float>>("variance");
J
jerrywgz 已提交
199
    const int axis = context.Attr<int>("axis");
G
gaoyuan 已提交
200 201 202 203
    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
                        "Only support 1 level of LoD.");
    }
204 205 206 207 208 209 210 211 212
    if (prior_box_var) {
      PADDLE_ENFORCE(variance.empty(),
                     "Input 'PriorBoxVar' and attribute 'variance' should not"
                     "be used at the same time.");
    }
    if (!(variance.empty())) {
      PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
                     "Size of attribute 'variance' should be 4");
    }
J
jerrywgz 已提交
213 214 215
    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    bool normalized = context.Attr<bool>("box_normalized");

G
gaoyuan 已提交
216 217
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
J
jerrywgz 已提交
218 219 220
    if (code_type == BoxCodeType::kDecodeCenterSize) {
      col = target_box->dims()[1];
    }
G
gaoyuan 已提交
221
    auto len = prior_box->dims()[1];
G
gaoyuan 已提交
222

G
gaoyuan 已提交
223
    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
G
gaoyuan 已提交
224 225 226

    T* output = output_box->data<T>();
    if (code_type == BoxCodeType::kEncodeCenterSize) {
227
      EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
228
                       variance, output);
G
gaoyuan 已提交
229
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
J
jerrywgz 已提交
230
      DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis,
231
                       variance, output);
G
gaoyuan 已提交
232 233 234 235 236 237
    }
  }
};

}  // namespace operators
}  // namespace paddle