“f6067e1b31ab5eebedb3b89bb339aa198365f228”上不存在“git@gitcode.net:qq_37101384/tdengine.git”
box_coder_op.h 9.7 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Direction of the transform performed by BoxCoderKernel; parsed from the
// "code_type" attribute by GetBoxCodeType.
enum class BoxCodeType {
  kEncodeCenterSize = 0,  // "encode_center_size"
  kDecodeCenterSize = 1,  // "decode_center_size"
};

// Translates the "code_type" attribute string into a BoxCodeType.
// Any string other than the two recognised names is reported via
// PADDLE_THROW("Not support type %s.", type).
inline BoxCodeType GetBoxCodeType(const std::string& type) {
  if (type == "encode_center_size") return BoxCodeType::kEncodeCenterSize;
  if (type == "decode_center_size") return BoxCodeType::kDecodeCenterSize;
  PADDLE_THROW("Not support type %s.", type);
}

32
template <typename DeviceContext, typename T>
G
gaoyuan 已提交
33 34
class BoxCoderKernel : public framework::OpKernel<T> {
 public:
35 36 37
  void EncodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
38 39
                        const bool normalized,
                        const std::vector<float> variance, T* output) const {
40 41 42 43 44 45 46
    int64_t row = target_box->dims()[0];
    int64_t col = prior_box->dims()[0];
    int64_t len = prior_box->dims()[1];
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
    if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
G
gaoyuan 已提交
47

L
luotao1 已提交
48 49 50
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
G
gaoyuan 已提交
51 52
    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
53 54 55 56 57
        T prior_box_width = prior_box_data[j * len + 2] -
                            prior_box_data[j * len] + (normalized == false);
        T prior_box_height = prior_box_data[j * len + 3] -
                             prior_box_data[j * len + 1] +
                             (normalized == false);
J
jerrywgz 已提交
58
        T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2;
G
gaoyuan 已提交
59
        T prior_box_center_y =
J
jerrywgz 已提交
60
            prior_box_data[j * len + 1] + prior_box_height / 2;
G
gaoyuan 已提交
61 62

        T target_box_center_x =
G
gaoyuan 已提交
63
            (target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
G
gaoyuan 已提交
64
        T target_box_center_y =
G
gaoyuan 已提交
65
            (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
66 67 68 69 70
        T target_box_width = target_box_data[i * len + 2] -
                             target_box_data[i * len] + (normalized == false);
        T target_box_height = target_box_data[i * len + 3] -
                              target_box_data[i * len + 1] +
                              (normalized == false);
G
gaoyuan 已提交
71

G
gaoyuan 已提交
72
        size_t offset = i * col * len + j * len;
73 74 75 76
        output[offset] =
            (target_box_center_x - prior_box_center_x) / prior_box_width;
        output[offset + 1] =
            (target_box_center_y - prior_box_center_y) / prior_box_height;
G
gaoyuan 已提交
77
        output[offset + 2] =
78
            std::log(std::fabs(target_box_width / prior_box_width));
G
gaoyuan 已提交
79
        output[offset + 3] =
80 81
            std::log(std::fabs(target_box_height / prior_box_height));
        if (prior_box_var) {
82
          int prior_var_offset = j * len;
J
jerrywgz 已提交
83 84 85 86
          output[offset] /= prior_box_var_data[prior_var_offset];
          output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
          output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
          output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
87 88 89 90
        } else if (!(variance.empty())) {
          for (int k = 0; k < 4; ++k) {
            output[offset + k] /= static_cast<T>(variance[k]);
          }
91
        }
G
gaoyuan 已提交
92 93 94
      }
    }
  }
95
  template <int axis, int var_size>
96 97 98
  void DecodeCenterSize(const framework::Tensor* target_box,
                        const framework::Tensor* prior_box,
                        const framework::Tensor* prior_box_var,
99 100
                        const bool normalized, std::vector<float> variance,
                        T* output) const {
101
    int64_t row = target_box->dims()[0];
J
jerrywgz 已提交
102 103
    int64_t col = target_box->dims()[1];
    int64_t len = target_box->dims()[2];
G
gaoyuan 已提交
104

105 106 107
    auto* target_box_data = target_box->data<T>();
    auto* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = nullptr;
108
    if (var_size == 2) prior_box_var_data = prior_box_var->data<T>();
J
jerrywgz 已提交
109
    int prior_box_offset = 0;
110 111
    T var_data[4] = {1., 1., 1., 1.};
    T* var_ptr = var_data;
L
luotao1 已提交
112 113 114
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
G
gaoyuan 已提交
115 116
    for (int64_t i = 0; i < row; ++i) {
      for (int64_t j = 0; j < col; ++j) {
Y
Yuan Gao 已提交
117
        size_t offset = i * col * len + j * len;
118
        prior_box_offset = axis == 0 ? j * len : i * len;
J
jerrywgz 已提交
119 120 121 122 123
        T prior_box_width = prior_box_data[prior_box_offset + 2] -
                            prior_box_data[prior_box_offset] +
                            (normalized == false);
        T prior_box_height = prior_box_data[prior_box_offset + 3] -
                             prior_box_data[prior_box_offset + 1] +
124
                             (normalized == false);
G
gaoyuan 已提交
125
        T prior_box_center_x =
J
jerrywgz 已提交
126
            prior_box_data[prior_box_offset] + prior_box_width / 2;
G
gaoyuan 已提交
127
        T prior_box_center_y =
J
jerrywgz 已提交
128
            prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
G
gaoyuan 已提交
129

130 131
        T target_box_center_x = 0, target_box_center_y = 0;
        T target_box_width = 0, target_box_height = 0;
132 133 134 135 136 137
        int prior_var_offset = axis == 0 ? j * len : i * len;
        if (var_size == 2) {
          std::memcpy(var_ptr, prior_box_var_data + prior_var_offset,
                      4 * sizeof(T));
        } else if (var_size == 1) {
          var_ptr = reinterpret_cast<T*>(variance.data());
138
        }
139 140 141 142 143
        T box_var_x = *var_ptr;
        T box_var_y = *(var_ptr + 1);
        T box_var_w = *(var_ptr + 2);
        T box_var_h = *(var_ptr + 3);

J
jerrywgz 已提交
144 145 146 147 148 149 150 151 152 153
        target_box_center_x =
            box_var_x * target_box_data[offset] * prior_box_width +
            prior_box_center_x;
        target_box_center_y =
            box_var_y * target_box_data[offset + 1] * prior_box_height +
            prior_box_center_y;
        target_box_width =
            std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width;
        target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) *
                            prior_box_height;
G
gaoyuan 已提交
154 155 156

        output[offset] = target_box_center_x - target_box_width / 2;
        output[offset + 1] = target_box_center_y - target_box_height / 2;
157 158 159 160
        output[offset + 2] =
            target_box_center_x + target_box_width / 2 - (normalized == false);
        output[offset + 3] =
            target_box_center_y + target_box_height / 2 - (normalized == false);
G
gaoyuan 已提交
161 162 163 164 165 166 167 168
      }
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
G
gaoyuan 已提交
169
    auto* output_box = context.Output<framework::Tensor>("OutputBox");
170
    std::vector<float> variance = context.Attr<std::vector<float>>("variance");
J
jerrywgz 已提交
171
    const int axis = context.Attr<int>("axis");
G
gaoyuan 已提交
172 173 174 175
    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
                        "Only support 1 level of LoD.");
    }
176 177 178 179 180 181 182 183 184
    if (prior_box_var) {
      PADDLE_ENFORCE(variance.empty(),
                     "Input 'PriorBoxVar' and attribute 'variance' should not"
                     "be used at the same time.");
    }
    if (!(variance.empty())) {
      PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
                     "Size of attribute 'variance' should be 4");
    }
J
jerrywgz 已提交
185 186 187
    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    bool normalized = context.Attr<bool>("box_normalized");

G
gaoyuan 已提交
188 189
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
J
jerrywgz 已提交
190 191 192
    if (code_type == BoxCodeType::kDecodeCenterSize) {
      col = target_box->dims()[1];
    }
G
gaoyuan 已提交
193
    auto len = prior_box->dims()[1];
G
gaoyuan 已提交
194

G
gaoyuan 已提交
195
    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
G
gaoyuan 已提交
196 197 198

    T* output = output_box->data<T>();
    if (code_type == BoxCodeType::kEncodeCenterSize) {
199
      EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
200
                       variance, output);
G
gaoyuan 已提交
201
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
      if (prior_box_var) {
        if (axis == 0) {
          DecodeCenterSize<0, 2>(target_box, prior_box, prior_box_var,
                                 normalized, variance, output);
        } else {
          DecodeCenterSize<1, 2>(target_box, prior_box, prior_box_var,
                                 normalized, variance, output);
        }
      } else if (!(variance.empty())) {
        if (axis == 0) {
          DecodeCenterSize<0, 1>(target_box, prior_box, prior_box_var,
                                 normalized, variance, output);
        } else {
          DecodeCenterSize<1, 1>(target_box, prior_box, prior_box_var,
                                 normalized, variance, output);
        }
      } else {
        if (axis == 0) {
          DecodeCenterSize<0, 0>(target_box, prior_box, prior_box_var,
                                 normalized, variance, output);
        } else {
          DecodeCenterSize<1, 0>(target_box, prior_box, prior_box_var,
                                 normalized, variance, output);
        }
      }
G
gaoyuan 已提交
227 228 229 230 231 232
    }
  }
};

}  // namespace operators
}  // namespace paddle