transforms.cpp 8.8 KB
Newer Older
C
Channingss 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

S
syyxsxx 已提交
15 16 17 18
#include "include/paddlex/transforms.h"

#include <math.h>

C
Channingss 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31
#include <iostream>
#include <string>
#include <vector>

namespace PaddleX {

std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
                                             {"NEAREST", cv::INTER_NEAREST},
                                             {"AREA", cv::INTER_AREA},
                                             {"CUBIC", cv::INTER_CUBIC},
                                             {"LANCZOS4", cv::INTER_LANCZOS4}};

// Normalizes the image in place, one channel at a time.
//
// Every channel c is mapped through
//   ((x - min_val_[c]) / (max_val_[c] - min_val_[c]) - mean_[c]) / std_[c]
//
// im:   image to normalize; overwritten with the normalized result.
// data: not used by this transform.
// Returns true unconditionally.
bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
  const int num_channels = im->channels();

  // Per-channel value range (max - min), computed before touching the pixels.
  std::vector<float> channel_range(num_channels);
  for (int ch = 0; ch < num_channels; ++ch) {
    channel_range[ch] = max_val_[ch] - min_val_[ch];
  }

  // Work on single-channel planes so each channel gets its own constants.
  std::vector<cv::Mat> planes;
  cv::split(*im, planes);
  for (int ch = 0; ch < num_channels; ++ch) {
    cv::Mat& plane = planes[ch];
    // Rescale by the configured min/max ...
    cv::subtract(plane, cv::Scalar(min_val_[ch]), plane);
    cv::divide(plane, cv::Scalar(channel_range[ch]), plane);
    // ... then standardize with mean / std.
    cv::subtract(plane, cv::Scalar(mean_[ch]), plane);
    cv::divide(plane, cv::Scalar(std_[ch]), plane);
  }
  cv::merge(planes, *im);
  return true;
}

// Computes the resize factor that brings the shorter image side to
// short_size_. When max_size_ is positive and the longer side would then
// round to more than max_size_, the factor is instead chosen so that the
// longer side equals max_size_.
//
// im: image whose dimensions (cols/rows) are inspected; it is not modified.
// Returns the scale factor to apply to both dimensions.
float ResizeByShort::GenerateScale(const cv::Mat& im) {
  const int longer_side = std::max(im.cols, im.rows);
  const int shorter_side = std::min(im.cols, im.rows);

  const float short_scale =
      static_cast<float>(short_size_) / static_cast<float>(shorter_side);

  const bool exceeds_cap =
      max_size_ > 0 && round(short_scale * longer_side) > max_size_;
  if (!exceeds_cap) {
    return short_scale;
  }
  // Cap the longer side at max_size_.
  return static_cast<float>(max_size_) / static_cast<float>(longer_side);
}

// Resizes the image in place with the factor returned by GenerateScale,
// using bilinear interpolation.
//
// im:   image to resize; overwritten with the resized image.
// data: receives the pre-resize shape (appended to im_size_before_resize_),
//       a "resize" entry in reshape_order_, the new image size and the scale.
// Returns true unconditionally.
bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
  // Log the shape before this step together with the kind of step applied.
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const float factor = GenerateScale(*im);
  const cv::Size target_size(static_cast<int>(round(factor * im->cols)),
                             static_cast<int>(round(factor * im->rows)));
  cv::resize(*im, *im, target_size, 0, 0, cv::INTER_LINEAR);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  data->scale = factor;
  return true;
}

// Crops a width_ x height_ window from the center of the image.
//
// im:   image to crop; replaced by the cropped region of interest.
// data: new_im_size_ is updated with the cropped rows/cols.
// Returns false (after logging to stderr) when the image is smaller than the
// crop window in either dimension, true otherwise.
bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
  const int src_h = static_cast<int>(im->rows);
  const int src_w = static_cast<int>(im->cols);

  const bool window_fits = !(src_h < height_ || src_w < width_);
  if (!window_fits) {
    std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
    return false;
  }

  // Top-left corner that centers the window inside the image.
  const int left = static_cast<int>((src_w - width_) / 2);
  const int top = static_cast<int>((src_h - height_) / 2);
  const cv::Rect window(left, top, width_, height_);
  *im = (*im)(window);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}

// Pads the image on the bottom and right so that it reaches a target size.
//
// The target size is chosen as follows:
//   * if both width_ and height_ are greater than 1, pad to width_ x height_;
//   * otherwise, if coarsest_stride_ >= 1, pad each dimension up to the next
//     multiple of coarsest_stride_;
//   * otherwise no padding is added.
// The padded area of channel i is filled with im_value_[i]; the original
// pixels are kept in the top-left corner.
//
// The padded planes are created as CV_32FC1, so the input is expected to be
// a 32-bit float image (Transforms::Run converts the image to CV_32F before
// running the transforms).
//
// im:   image to pad; replaced by the padded image.
// data: receives the pre-padding shape (appended to im_size_before_resize_),
//       a "padding" entry in reshape_order_ and the new image size.
// Returns false (after logging to stderr) when the image is already larger
// than the target size, true otherwise.
bool Padding::Run(cv::Mat* im, ImageBlob* data) {
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("padding");

  int padding_w = 0;
  int padding_h = 0;
  // Logical AND: both target dimensions must be configured.
  if (width_ > 1 && height_ > 1) {
    padding_w = width_ - im->cols;
    padding_h = height_ - im->rows;
  } else if (coarsest_stride_ >= 1) {
    // Round each dimension up to the next multiple of the stride.
    const int h = im->rows;
    const int w = im->cols;
    padding_h = static_cast<int>(
        ceil(h * 1.0 / coarsest_stride_) * coarsest_stride_ - h);
    padding_w = static_cast<int>(
        ceil(w * 1.0 / coarsest_stride_) * coarsest_stride_ - w);
  }

  // Zero padding is valid (image already has the target size); only a
  // negative amount, i.e. an image larger than the target, is an error.
  if (padding_h < 0 || padding_w < 0) {
    std::cerr << "[Padding] Computed padding_h=" << padding_h
              << ", padding_w=" << padding_w
              << ", but they should be greater than or equal to 0."
              << std::endl;
    return false;
  }

  // Build one constant-filled plane per channel so every channel can have its
  // own fill value, then merge them into the padded canvas.
  const int num_channels = im->channels();
  std::vector<cv::Mat> padded_im_per_channel;
  padded_im_per_channel.reserve(num_channels);
  for (int i = 0; i < num_channels; i++) {
    padded_im_per_channel.emplace_back(im->rows + padding_h,
                                       im->cols + padding_w,
                                       CV_32FC1,
                                       cv::Scalar(im_value_[i]));
  }
  cv::Mat padded_im;
  cv::merge(padded_im_per_channel, padded_im);

  // Copy the original pixels into the top-left corner of the canvas.
  const cv::Rect im_roi(0, 0, im->cols, im->rows);
  im->copyTo(padded_im(im_roi));
  *im = padded_im;

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;

  return true;
}

// Resizes the image in place so that its longer side is scaled by
// long_size_ / max(cols, rows), using nearest-neighbor interpolation.
//
// im:   image to resize; overwritten with the resized image.
// data: receives the pre-resize shape (appended to im_size_before_resize_),
//       a "resize" entry in reshape_order_, the new image size and the scale.
// Returns false (after logging to stderr) when long_size_ is not positive,
// true otherwise.
bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
  if (long_size_ <= 0) {
    std::cerr << "[ResizeByLong] long_size should be greater than 0"
              << std::endl;
    return false;
  }

  // Log the shape before this step together with the kind of step applied.
  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const int longer_side = std::max(im->cols, im->rows);
  const float factor =
      static_cast<float>(long_size_) / static_cast<float>(longer_side);

  // An empty cv::Size makes OpenCV derive the output size from the factors.
  cv::resize(*im, *im, cv::Size(), factor, factor, cv::INTER_NEAREST);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  data->scale = factor;
  return true;
}

// Resizes the image in place to exactly width_ x height_ using the
// interpolation method named by interp_.
//
// im:   image to resize; overwritten with the resized image.
// data: receives the pre-resize shape (appended to im_size_before_resize_),
//       a "resize" entry in reshape_order_ and the new image size.
// Returns false (after logging to stderr) when width_/height_ are not
// positive or interp_ is not a key of the interpolations table; true
// otherwise.
bool Resize::Run(cv::Mat* im, ImageBlob* data) {
  const bool size_is_valid = !(width_ <= 0 || height_ <= 0);
  if (!size_is_valid) {
    std::cerr << "[Resize] width and height should be greater than 0"
              << std::endl;
    return false;
  }

  // Resolve the interpolation flag once; a miss means an unknown method name.
  const auto interp_entry = interpolations.find(interp_);
  if (interp_entry == interpolations.end()) {
    std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
              << std::endl;
    return false;
  }

  data->im_size_before_resize_.push_back({im->rows, im->cols});
  data->reshape_order_.push_back("resize");

  const cv::Size target_size(width_, height_);
  cv::resize(*im, *im, target_size, 0, 0, interp_entry->second);

  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;
  return true;
}

180
// Clamps every channel c of the image, in place, to the range
// [min_val_[c], max_val_[c]].
//
// cv::THRESH_TRUNC only provides an upper clamp (min(x, thresh)), so the
// lower clamp is obtained through the identity
//   max(x, lo) == -min(-x, -lo)
// i.e. negate the plane, truncate at -lo, and negate back.
//
// im:   image to clip; overwritten with the clipped result.
// data: not used by this transform.
// Returns true unconditionally.
bool Clip::Run(cv::Mat* im, ImageBlob* data) {
  std::vector<cv::Mat> split_im;
  cv::split(*im, split_im);
  for (int c = 0; c < im->channels(); c++) {
    // Upper bound: x <- min(x, max_val_[c]).
    cv::threshold(split_im[c], split_im[c], max_val_[c], max_val_[c],
                  cv::THRESH_TRUNC);
    // Lower bound on the negated plane. The threshold must be negated as
    // well; truncating -x at +min_val_[c] would clamp x to -min_val_[c],
    // which is only correct when min_val_[c] is 0.
    cv::subtract(cv::Scalar(0), split_im[c], split_im[c]);
    cv::threshold(split_im[c], split_im[c], -min_val_[c], -min_val_[c],
                  cv::THRESH_TRUNC);
    // Undo the negation.
    cv::divide(split_im[c], cv::Scalar(-1), split_im[c]);
  }
  cv::merge(split_im, *im);
  return true;
}

C
Channingss 已提交
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
// Rebuilds the transform pipeline from a YAML description.
//
// Each element of transforms_node is read as a mapping whose first entry has
// the transform name as key and that transform's configuration as value. The
// transform is created through CreateTransform, initialized with its
// configuration node and appended to the pipeline in iteration order.
//
// transforms_node: YAML sequence describing the transforms.
// to_rgb:          stored in to_rgb_; Run uses it to decide whether to apply
//                  a BGR -> RGB conversion.
void Transforms::Init(const YAML::Node& transforms_node, bool to_rgb) {
  to_rgb_ = to_rgb;
  transforms_.clear();

  for (const auto& item : transforms_node) {
    // Only the first key/value pair of the element is consulted.
    const auto entry = item.begin();
    const std::string name = entry->first.as<std::string>();

    std::shared_ptr<Transform> transform = CreateTransform(name);
    transform->Init(entry->second);
    transforms_.push_back(transform);
  }
}

// Factory that maps a transform name to a freshly constructed transform.
//
// transform_name: one of "Normalize", "ResizeByShort", "CenterCrop",
//                 "Resize", "Padding", "ResizeByLong" or "Clip".
// Returns a shared pointer to the new (not yet initialized) transform.
// For any other name an error is written to stderr and the process is
// terminated with exit(-1).
std::shared_ptr<Transform> Transforms::CreateTransform(
    const std::string& transform_name) {
  if (transform_name == "Normalize") {
    return std::make_shared<Normalize>();
  }
  if (transform_name == "ResizeByShort") {
    return std::make_shared<ResizeByShort>();
  }
  if (transform_name == "CenterCrop") {
    return std::make_shared<CenterCrop>();
  }
  if (transform_name == "Resize") {
    return std::make_shared<Resize>();
  }
  if (transform_name == "Padding") {
    return std::make_shared<Padding>();
  }
  if (transform_name == "ResizeByLong") {
    return std::make_shared<ResizeByLong>();
  }
  if (transform_name == "Clip") {
    return std::make_shared<Clip>();
  }

  // No known transform matched: report and abort.
  std::cerr << "There's unexpected transform(name='" << transform_name
            << "')." << std::endl;
  exit(-1);
}

// Runs the whole preprocessing pipeline on one image and stores the result
// in the blob.
//
// Steps:
//   1. optionally convert BGR -> RGB (when to_rgb_ is set);
//   2. convert the pixels to 32-bit float;
//   3. record the original size and apply every transform in order;
//   4. copy the image into data->im_data_ channel by channel, so that each
//      channel occupies one contiguous h * w block.
//
// im:   image to process; modified in place by the conversions and transforms.
// data: receives the original/new image sizes, whatever the individual
//       transforms record, and the final pixel buffer.
// Returns false (after logging to stderr) as soon as a transform fails,
// true otherwise.
bool Transforms::Run(cv::Mat* im, ImageBlob* data) {
  // do all preprocess ops by order
  if (to_rgb_) {
    cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
  }
  im->convertTo(*im, CV_32FC(im->channels()));

  // Before any transform runs, original and current size coincide.
  data->ori_im_size_[0] = im->rows;
  data->ori_im_size_[1] = im->cols;
  data->new_im_size_[0] = im->rows;
  data->new_im_size_[1] = im->cols;

  for (const auto& transform : transforms_) {
    if (transform->Run(im, data)) {
      continue;
    }
    std::cerr << "Apply transforms to image failed!" << std::endl;
    return false;
  }

  // data format NHWC to NCHW
  // img data save to ImageBlob
  const int rows = im->rows;
  const int cols = im->cols;
  const int channels = im->channels();
  const int plane_size = rows * cols;

  (data->im_data_).resize(channels * plane_size);
  float* const buffer = (data->im_data_).data();
  for (int ch = 0; ch < channels; ++ch) {
    // Wrap the destination slice in a Mat header so OpenCV writes the
    // channel straight into the blob's buffer.
    float* const plane_begin = buffer + ch * plane_size;
    cv::extractChannel(*im, cv::Mat(rows, cols, CV_32FC1, plane_begin), ch);
  }
  return true;
}
J
jack 已提交
258

C
Channingss 已提交
259
}  // namespace PaddleX