/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <cstdint>
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Encoding schemes understood by BoxCoderKernel.
enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 };

// Translates the "code_type" attribute string into a BoxCodeType.
// Recognized spellings: "encode_center_size" and "decode_center_size".
// Any other value is rejected through PADDLE_THROW.
inline BoxCodeType GetBoxCodeType(const std::string& type) {
  if (type == "decode_center_size") return BoxCodeType::kDecodeCenterSize;
  if (type == "encode_center_size") return BoxCodeType::kEncodeCenterSize;
  PADDLE_THROW("Not support type %s.", type);
}

template <typename DeviceContext, typename T>
G
gaoyuan 已提交
32 33
class BoxCoderKernel : public framework::OpKernel<T> {
 public:
34 35 36
  void EncodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
37
                        const bool normalized, T* output) const {
38 39 40 41 42 43 44
    int64_t row = target_box->dims()[0];
    int64_t col = prior_box->dims()[0];
    int64_t len = prior_box->dims()[1];
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
G
gaoyuan 已提交
45

L
luotao1 已提交
46 47 48
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
G
gaoyuan 已提交
49 50
    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
51 52 53 54 55
        T prior_box_width = prior_box_data[j * len + 2] -
                            prior_box_data[j * len] + (normalized == false);
        T prior_box_height = prior_box_data[j * len + 3] -
                             prior_box_data[j * len + 1] +
                             (normalized == false);
J
jerrywgz 已提交
56
        T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2;
G
gaoyuan 已提交
57
        T prior_box_center_y =
J
jerrywgz 已提交
58
            prior_box_data[j * len + 1] + prior_box_height / 2;
G
gaoyuan 已提交
59 60

        T target_box_center_x =
G
gaoyuan 已提交
61
            (target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
G
gaoyuan 已提交
62
        T target_box_center_y =
G
gaoyuan 已提交
63
            (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
64 65 66 67 68
        T target_box_width = target_box_data[i * len + 2] -
                             target_box_data[i * len] + (normalized == false);
        T target_box_height = target_box_data[i * len + 3] -
                              target_box_data[i * len + 1] +
                              (normalized == false);
G
gaoyuan 已提交
69

G
gaoyuan 已提交
70
        size_t offset = i * col * len + j * len;
71 72 73 74
        output[offset] =
            (target_box_center_x - prior_box_center_x) / prior_box_width;
        output[offset + 1] =
            (target_box_center_y - prior_box_center_y) / prior_box_height;
G
gaoyuan 已提交
75
        output[offset + 2] =
76
            std::log(std::fabs(target_box_width / prior_box_width));
G
gaoyuan 已提交
77
        output[offset + 3] =
78 79
            std::log(std::fabs(target_box_height / prior_box_height));
        if (prior_box_var) {
J
jerrywgz 已提交
80 81 82 83 84 85 86 87
          int prior_var_offset = 0;
          if (prior_box_var->dims().size() == 2) {
            prior_var_offset = j * len;
          }
          output[offset] /= prior_box_var_data[prior_var_offset];
          output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
          output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
          output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
88
        }
G
gaoyuan 已提交
89 90 91
      }
    }
  }
92 93 94
  void DecodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
J
jerrywgz 已提交
95 96
                        const bool normalized, const int axis,
                        T* output) const {
97
    int64_t row = target_box->dims()[0];
J
jerrywgz 已提交
98 99
    int64_t col = target_box->dims()[1];
    int64_t len = target_box->dims()[2];
G
gaoyuan 已提交
100

101 102 103 104
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
J
jerrywgz 已提交
105
    int prior_box_offset = 0;
L
luotao1 已提交
106 107 108
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
G
gaoyuan 已提交
109 110
    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
Y
Yuan Gao 已提交
111
        size_t offset = i * col * len + j * len;
J
jerrywgz 已提交
112 113 114 115 116 117 118 119 120 121
        if (axis == 0) {
          prior_box_offset = j * len;
        } else if (axis == 1) {
          prior_box_offset = i * len;
        }
        T prior_box_width = prior_box_data[prior_box_offset + 2] -
                            prior_box_data[prior_box_offset] +
                            (normalized == false);
        T prior_box_height = prior_box_data[prior_box_offset + 3] -
                             prior_box_data[prior_box_offset + 1] +
122
                             (normalized == false);
G
gaoyuan 已提交
123
        T prior_box_center_x =
J
jerrywgz 已提交
124
            prior_box_data[prior_box_offset] + prior_box_width / 2;
G
gaoyuan 已提交
125
        T prior_box_center_y =
J
jerrywgz 已提交
126
            prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
G
gaoyuan 已提交
127

128 129 130
        T target_box_center_x = 0, target_box_center_y = 0;
        T target_box_width = 0, target_box_height = 0;
        if (prior_box_var) {
J
jerrywgz 已提交
131 132 133 134 135 136 137 138
          int prior_var_offset = 0;
          if (prior_box_var->dims().size() == 2) {
            if (axis == 0)
              prior_var_offset = j * len;
            else if (axis == 1)
              prior_var_offset = i * len;
          }
          target_box_center_x = prior_box_var_data[prior_var_offset] *
Y
Yuan Gao 已提交
139
                                    target_box_data[offset] * prior_box_width +
G
gaoyuan 已提交
140
                                prior_box_center_x;
J
jerrywgz 已提交
141
          target_box_center_y = prior_box_var_data[prior_var_offset + 1] *
Y
Yuan Gao 已提交
142
                                    target_box_data[offset + 1] *
G
gaoyuan 已提交
143 144
                                    prior_box_height +
                                prior_box_center_y;
J
jerrywgz 已提交
145
          target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] *
Y
Yuan Gao 已提交
146
                                      target_box_data[offset + 2]) *
G
gaoyuan 已提交
147
                             prior_box_width;
J
jerrywgz 已提交
148 149 150 151
          target_box_height =
              std::exp(prior_box_var_data[prior_var_offset + 3] *
                       target_box_data[offset + 3]) *
              prior_box_height;
152 153 154 155 156 157 158 159 160 161
        } else {
          target_box_center_x =
              target_box_data[offset] * prior_box_width + prior_box_center_x;
          target_box_center_y = target_box_data[offset + 1] * prior_box_height +
                                prior_box_center_y;
          target_box_width =
              std::exp(target_box_data[offset + 2]) * prior_box_width;
          target_box_height =
              std::exp(target_box_data[offset + 3]) * prior_box_height;
        }
G
gaoyuan 已提交
162 163 164

        output[offset] = target_box_center_x - target_box_width / 2;
        output[offset + 1] = target_box_center_y - target_box_height / 2;
165 166 167 168
        output[offset + 2] =
            target_box_center_x + target_box_width / 2 - (normalized == false);
        output[offset + 3] =
            target_box_center_y + target_box_height / 2 - (normalized == false);
G
gaoyuan 已提交
169 170 171 172 173 174 175 176
      }
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
G
gaoyuan 已提交
177
    auto* output_box = context.Output<framework::Tensor>("OutputBox");
J
jerrywgz 已提交
178
    const int axis = context.Attr<int>("axis");
G
gaoyuan 已提交
179 180 181 182
    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
                        "Only support 1 level of LoD.");
    }
J
jerrywgz 已提交
183 184 185
    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    bool normalized = context.Attr<bool>("box_normalized");

G
gaoyuan 已提交
186 187
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
J
jerrywgz 已提交
188 189 190
    if (code_type == BoxCodeType::kDecodeCenterSize) {
      col = target_box->dims()[1];
    }
G
gaoyuan 已提交
191
    auto len = prior_box->dims()[1];
G
gaoyuan 已提交
192

G
gaoyuan 已提交
193
    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
G
gaoyuan 已提交
194 195 196

    T* output = output_box->data<T>();
    if (code_type == BoxCodeType::kEncodeCenterSize) {
197
      EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
198
                       output);
G
gaoyuan 已提交
199
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
J
jerrywgz 已提交
200
      DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis,
201
                       output);
G
gaoyuan 已提交
202 203 204 205 206 207
    }
  }
};

}  // namespace operators
}  // namespace paddle