/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cmath>
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Direction of the box transform, selected by the strings that
// GetBoxCodeType() accepts.
enum class BoxCodeType {
  kEncodeCenterSize = 0,  // "encode_center_size"
  kDecodeCenterSize = 1   // "decode_center_size"
};

// Translates a code-type string into the matching BoxCodeType value.
// Recognised spellings are "encode_center_size" and "decode_center_size";
// any other string is reported through PADDLE_THROW.
inline BoxCodeType GetBoxCodeType(const std::string& type) {
  if (type == "encode_center_size") return BoxCodeType::kEncodeCenterSize;
  if (type == "decode_center_size") return BoxCodeType::kDecodeCenterSize;
  PADDLE_THROW("Not support type %s.", type);
}

31
template <typename DeviceContext, typename T>
G
gaoyuan 已提交
32 33
class BoxCoderKernel : public framework::OpKernel<T> {
 public:
34 35 36
  void EncodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
37
                        const bool normalized, T* output) const {
38 39 40 41 42 43 44
    int64_t row = target_box->dims()[0];
    int64_t col = prior_box->dims()[0];
    int64_t len = prior_box->dims()[1];
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
G
gaoyuan 已提交
45 46 47

    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
48 49 50 51 52
        T prior_box_width = prior_box_data[j * len + 2] -
                            prior_box_data[j * len] + (normalized == false);
        T prior_box_height = prior_box_data[j * len + 3] -
                             prior_box_data[j * len + 1] +
                             (normalized == false);
G
gaoyuan 已提交
53
        T prior_box_center_x =
G
gaoyuan 已提交
54
            (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
G
gaoyuan 已提交
55
        T prior_box_center_y =
G
gaoyuan 已提交
56
            (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
G
gaoyuan 已提交
57 58

        T target_box_center_x =
G
gaoyuan 已提交
59
            (target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
G
gaoyuan 已提交
60
        T target_box_center_y =
G
gaoyuan 已提交
61
            (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
62 63 64 65 66
        T target_box_width = target_box_data[i * len + 2] -
                             target_box_data[i * len] + (normalized == false);
        T target_box_height = target_box_data[i * len + 3] -
                              target_box_data[i * len + 1] +
                              (normalized == false);
G
gaoyuan 已提交
67

G
gaoyuan 已提交
68
        size_t offset = i * col * len + j * len;
69 70 71 72
        output[offset] =
            (target_box_center_x - prior_box_center_x) / prior_box_width;
        output[offset + 1] =
            (target_box_center_y - prior_box_center_y) / prior_box_height;
G
gaoyuan 已提交
73
        output[offset + 2] =
74
            std::log(std::fabs(target_box_width / prior_box_width));
G
gaoyuan 已提交
75
        output[offset + 3] =
76 77 78 79 80 81 82
            std::log(std::fabs(target_box_height / prior_box_height));
        if (prior_box_var) {
          output[offset] /= prior_box_var_data[j * len];
          output[offset + 1] /= prior_box_var_data[j * len + 1];
          output[offset + 2] /= prior_box_var_data[j * len + 2];
          output[offset + 3] /= prior_box_var_data[j * len + 3];
        }
G
gaoyuan 已提交
83 84 85
      }
    }
  }
86 87 88
  void DecodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
89
                        const bool normalized, T* output) const {
90 91 92
    int64_t row = target_box->dims()[0];
    int64_t col = prior_box->dims()[0];
    int64_t len = prior_box->dims()[1];
G
gaoyuan 已提交
93

94 95 96 97
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
G
gaoyuan 已提交
98 99 100

    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
Y
Yuan Gao 已提交
101
        size_t offset = i * col * len + j * len;
102 103 104 105 106
        T prior_box_width = prior_box_data[j * len + 2] -
                            prior_box_data[j * len] + (normalized == false);
        T prior_box_height = prior_box_data[j * len + 3] -
                             prior_box_data[j * len + 1] +
                             (normalized == false);
G
gaoyuan 已提交
107
        T prior_box_center_x =
G
gaoyuan 已提交
108
            (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
G
gaoyuan 已提交
109
        T prior_box_center_y =
G
gaoyuan 已提交
110
            (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
G
gaoyuan 已提交
111

112 113 114 115
        T target_box_center_x = 0, target_box_center_y = 0;
        T target_box_width = 0, target_box_height = 0;
        if (prior_box_var) {
          target_box_center_x = prior_box_var_data[j * len] *
Y
Yuan Gao 已提交
116
                                    target_box_data[offset] * prior_box_width +
G
gaoyuan 已提交
117
                                prior_box_center_x;
118
          target_box_center_y = prior_box_var_data[j * len + 1] *
Y
Yuan Gao 已提交
119
                                    target_box_data[offset + 1] *
G
gaoyuan 已提交
120 121
                                    prior_box_height +
                                prior_box_center_y;
122
          target_box_width = std::exp(prior_box_var_data[j * len + 2] *
Y
Yuan Gao 已提交
123
                                      target_box_data[offset + 2]) *
G
gaoyuan 已提交
124
                             prior_box_width;
125
          target_box_height = std::exp(prior_box_var_data[j * len + 3] *
Y
Yuan Gao 已提交
126
                                       target_box_data[offset + 3]) *
G
gaoyuan 已提交
127
                              prior_box_height;
128 129 130 131 132 133 134 135 136 137
        } else {
          target_box_center_x =
              target_box_data[offset] * prior_box_width + prior_box_center_x;
          target_box_center_y = target_box_data[offset + 1] * prior_box_height +
                                prior_box_center_y;
          target_box_width =
              std::exp(target_box_data[offset + 2]) * prior_box_width;
          target_box_height =
              std::exp(target_box_data[offset + 3]) * prior_box_height;
        }
G
gaoyuan 已提交
138 139 140

        output[offset] = target_box_center_x - target_box_width / 2;
        output[offset + 1] = target_box_center_y - target_box_height / 2;
141 142 143 144
        output[offset + 2] =
            target_box_center_x + target_box_width / 2 - (normalized == false);
        output[offset + 3] =
            target_box_center_y + target_box_height / 2 - (normalized == false);
G
gaoyuan 已提交
145 146 147 148 149 150 151 152
      }
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
G
gaoyuan 已提交
153
    auto* output_box = context.Output<framework::Tensor>("OutputBox");
G
gaoyuan 已提交
154 155 156 157 158 159 160

    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
                        "Only support 1 level of LoD.");
    }
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
G
gaoyuan 已提交
161
    auto len = prior_box->dims()[1];
G
gaoyuan 已提交
162

G
gaoyuan 已提交
163
    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
G
gaoyuan 已提交
164 165

    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
166
    bool normalized = context.Attr<bool>("box_normalized");
G
gaoyuan 已提交
167 168
    T* output = output_box->data<T>();
    if (code_type == BoxCodeType::kEncodeCenterSize) {
169
      EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
170
                       output);
G
gaoyuan 已提交
171
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
172
      DecodeCenterSize(target_box, prior_box, prior_box_var, normalized,
173
                       output);
G
gaoyuan 已提交
174 175 176 177 178 179
    }
  }
};

}  // namespace operators
}  // namespace paddle