/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"

namespace tflite {
namespace gpu {
namespace cl {
ConvolutionTransposed3x3::ConvolutionTransposed3x3(
    const OperationDef& definition, const DeviceInfo& device_info, int2 padding)
    : GPUOperation(definition), padding_(padding) {
  work_group_size_ = int3(8, 4, 1);
  work_group_launch_order_ = int3(2, 0, 1);
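  // Choose how filter weights are staged for the kernel based on the GPU
  // vendor: async copy into local memory on PowerVR, cooperative per-thread
  // loads into local memory on Nvidia/Intel, constant memory on AMD, and
  // direct global-memory reads otherwise.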
  if (device_info.IsPowerVR()) {
    weights_upload_type_ = WeightsUploadType::LOCAL_MEM_ASYNC;
  } else if (device_info.IsNvidia() || device_info.IsIntel()) {
    weights_upload_type_ = WeightsUploadType::LOCAL_MEM_BY_THREADS;
  } else if (device_info.IsAMD()) {
    weights_upload_type_ = WeightsUploadType::CONSTANT_MEM;
  } else {
    weights_upload_type_ = WeightsUploadType::GLOBAL_MEM;
  }
  code_ = GenerateConvolutionTransposedCode(definition_, weights_upload_type_,
                                            padding_, work_group_launch_order_);
  if (definition_.precision == CalculationsPrecision::F16 &&
      device_info.IsPowerVR()) {
    compiler_options_.push_back(CompilerOptions::POWERVR_FP16);
  }
}

ConvolutionTransposed3x3::ConvolutionTransposed3x3(
    ConvolutionTransposed3x3&& operation)
    : GPUOperation(std::move(operation)),
      padding_(operation.padding_),
      weights_upload_type_(operation.weights_upload_type_) {}

ConvolutionTransposed3x3& ConvolutionTransposed3x3::operator=(
    ConvolutionTransposed3x3&& operation) {
  if (this != &operation) {
    std::swap(padding_, operation.padding_);
    std::swap(weights_upload_type_, operation.weights_upload_type_);
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
    const OperationDef& op_def,
    ConvolutionTransposed3x3::WeightsUploadType weights_upload_type,
    int2 padding, int3 work_group_launch_order) {
  auto src_desc = op_def.src_tensors[0];
  src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
  if (op_def.IsBatchSupported()) {
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  AddSrcTensor("src_tensor", src_desc);

  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  AddDstTensor("dst_tensor", dst_desc);

  args_.AddInt("filter_offset");
  args_.AddInt("padding_x");
  args_.AddInt("padding_y");

  const bool need_local_mem =
      weights_upload_type ==
          ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS ||
      weights_upload_type ==
          ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_ASYNC;

  std::string c = GetCommonDefines(op_def.precision);
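  // CONV accumulates a single 3x3 tap: the four channels of SRC each scale one
  // of four consecutive weight vectors, which are summed into accumulator R.
  // In F32_F16 mode the multiply-add runs in FLT (half) precision and the
  // result is converted to float once per tap.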
  switch (op_def.precision) {
    case CalculationsPrecision::F32:
    case CalculationsPrecision::F16:
      c += "#define CONV(R, SRC, F) \\\n";
      c += "  R += SRC.x * weights_cache[F]; \\\n";
      c += "  R += SRC.y * weights_cache[F + 1]; \\\n";
      c += "  R += SRC.z * weights_cache[F + 2]; \\\n";
      c += "  R += SRC.w * weights_cache[F + 3];   \n";
      break;
    case CalculationsPrecision::F32_F16:
      c += "#define CONV(R, SRC, F) \\\n";
      c += "  R += convert_float4(SRC.x * weights_cache[F] + SRC.y * "
           "weights_cache[F + 1] + SRC.z * weights_cache[F + 2] + SRC.w * "
           "weights_cache[F + 3]);\n";
      break;
  }

  const std::string weights_space =
      weights_upload_type ==
              ConvolutionTransposed3x3::WeightsUploadType::CONSTANT_MEM
          ? "__constant"
          : "__global";

  const std::string pixel_stride =
      op_def.IsBatchSupported() ? "args.dst_tensor.Batch()" : "1";
  c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
  c += "__kernel void main_function(\n";
  c += "$0) {\n";
  int3 launch_remap;
  launch_remap[work_group_launch_order.x] = 0;
  launch_remap[work_group_launch_order.y] = 1;
  launch_remap[work_group_launch_order.z] = 2;
  auto GetGlobalID = [&](int id) {
    const std::string sid = std::to_string(id);
    if (work_group_launch_order[id] == id) {
      return "get_global_id(" + sid + ")";
    } else {
      return "get_group_id(" + std::to_string(launch_remap[id]) +
             ") * get_local_size(" + sid + ") + get_local_id(" + sid + ")";
    }
  };
  if (op_def.IsBatchSupported()) {
    c += "  int linear_id = " + GetGlobalID(0) + ";\n";
    c += "  int X0 = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  int DST_X = X0 * 2 * args.dst_tensor.Batch() + B;\n";
    c += "  int SRC_X = linear_id + args.padding_x;\n";
  } else {
    c += "  int X = " + GetGlobalID(0) + ";\n";
    c += "  int DST_X = X * 2;\n";
    c += "  int SRC_X = X + args.padding_x;\n";
  }
  c += "  int Y = " + GetGlobalID(1) + ";\n";
  c += "  int DST_Y = Y * 2;\n";
  c += "  int SRC_Y = Y + args.padding_y;\n";
  c += "  int Z = " + GetGlobalID(2) + ";\n";
  if (!need_local_mem) {
    c += "  if (DST_X >= args.dst_tensor.Width() || DST_Y >= "
         "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
  }
  c += "  ACCUM_FLT4 r0 = (ACCUM_FLT4)(0.0f);\n";
  c += "  ACCUM_FLT4 r1 = (ACCUM_FLT4)(0.0f);\n";
  c += "  ACCUM_FLT4 r2 = (ACCUM_FLT4)(0.0f);\n";
  c += "  ACCUM_FLT4 r3 = (ACCUM_FLT4)(0.0f);\n";
  c += "  int f_offset = Z * args.filter_offset;\n";
  if (need_local_mem) {
    c += "  __local FLT4 weights_cache[36];\n";
  }
  if (weights_upload_type ==
      ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
    c += "  int local_id = (int)(get_local_id(1) * 8 + get_local_id(0));\n";
  }
  const std::string next_x = "SRC_X + " + pixel_stride;
  if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
    c += "  bool in_x0 = SRC_X >= 0 && SRC_X < args.src_tensor.Width();\n";
    c += "  bool in_x1 = " + next_x + " >= 0 && " + next_x +
         " < args.src_tensor.Width();\n";
  }
  if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
    c += "  bool in_y0 = SRC_Y >= 0 && SRC_Y < args.src_tensor.Height();\n";
    c += "  bool in_y1 = SRC_Y + 1 >= 0 && SRC_Y + 1 < "
         "args.src_tensor.Height();\n";
  }
  auto generate_check = [&](int x, int y) {
    std::string check;
    const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT};
    const std::vector<std::string> names{"in_x" + std::to_string(x),
                                         "in_y" + std::to_string(y)};
    for (int i = 0; i < axes.size(); ++i) {
      const auto& axis = axes[i];
      if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
        if (!check.empty()) {
          check += " && ";
        }
        check += names[i];
      }
    }
    return check;
  };
  if (src_desc.IsLinear()) {
    if (src_desc.ReturnsZeroForNegOneRead()) {
      c += "  args.src_tensor.GetAddress(addr_0, SRC_X, SRC_Y, 0);\n";
      c += "  args.src_tensor.GetAddress(addr_1," + next_x + ", SRC_Y, 0);\n";
      c += "  args.src_tensor.GetAddress(addr_2, SRC_X, SRC_Y + 1, 0);\n";
      c += "  args.src_tensor.GetAddress(addr_3," + next_x + ", SRC_Y+1, 0);\n";
      c += "  addr_0 = select(-1, addr_0, (in_x0 && in_y0));\n";
      c += "  addr_1 = select(-1, addr_1, (in_x1 && in_y0));\n";
      c += "  addr_2 = select(-1, addr_2, (in_x0 && in_y1));\n";
      c += "  addr_3 = select(-1, addr_3, (in_x1 && in_y1));\n";
      c += "  int dz_0 = select(0, args.src_tensor.SliceStride(), (in_x0 && "
           "in_y0));\n";
      c += "  int dz_1 = select(0, args.src_tensor.SliceStride(), (in_x1 && "
           "in_y0));\n";
      c += "  int dz_2 = select(0, args.src_tensor.SliceStride(), (in_x0 && "
           "in_y1));\n";
      c += "  int dz_3 = select(0, args.src_tensor.SliceStride(), (in_x1 && "
           "in_y1));\n";
    } else {
      c += "  int xc0 = clamp(SRC_X, 0, args.src_tensor.Width() - 1);\n";
      c += "  int xc1 = clamp(" + next_x +
           ", 0, args.src_tensor.Width() - 1);\n";
      c += "  int yc0 = clamp(SRC_Y, 0, args.src_tensor.Height() - 1);\n";
      c += "  int yc1 = clamp(SRC_Y + 1, 0, args.src_tensor.Height() - 1);\n";
      c += "  args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n";
      c += "  args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n";
      c += "  args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n";
      c += "  args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n";
      c += "  int dz = args.src_tensor.SliceStride();\n";
    }
  }
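  // read_src emits a read of one pixel of the 2x2 source neighborhood.
  // Out-of-bounds pixels are zeroed either through the addresses precomputed
  // above (linear storage) or by multiplying the value with the in_x*/in_y*
  // predicates.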
  auto read_src = [&](int x, int y) {
    if (src_desc.IsLinear()) {
      const std::string id = std::to_string(y * 2 + x);
      const std::string addr = "addr_" + std::to_string(y * 2 + x);
      if (src_desc.ReturnsZeroForNegOneRead()) {
        return "args.src_tensor.Read(" + addr + "); " + addr + " += dz_" + id +
               ";\n";
      } else {
        return "args.src_tensor.Read(" + addr + ") * (FLT)(in_x" +
               std::to_string(x) + " && in_y" + std::to_string(y) + "); " +
               addr + " += dz;\n";
      }
    } else {
      std::string check = generate_check(x, y);
      if (!check.empty()) {
        check = " * (FLT)(" + check + ")";
      }
      return "args.src_tensor.Read(SRC_X + " + std::to_string(x) + "*" +
             pixel_stride + ", SRC_Y + " + std::to_string(y) + ", s)" + check +
             ";\n";
    }
  };
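  // Each work-item produces a 2x2 block of output pixels (r0..r3). The parity
  // of the padding selects which (accumulator, source pixel) pair each of the
  // nine filter taps contributes to; the pairs below encode that mapping.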
  const int padding_x_rem = abs(padding.x) % 2;
  const int padding_y_rem = abs(padding.y) % 2;
  std::vector<std::pair<int, int>> permutation;
  if (padding_x_rem == 1 && padding_y_rem == 1) {
    permutation = {{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2},
                   {3, 0}, {3, 1}, {3, 2}, {3, 3}};
  } else if (padding_x_rem == 0 && padding_y_rem == 1) {
    permutation = {{0, 0}, {0, 1}, {1, 1}, {2, 0}, {2, 1},
                   {2, 2}, {2, 3}, {3, 1}, {3, 3}};
  } else if (padding_x_rem == 1 && padding_y_rem == 0) {
    permutation = {{0, 0}, {0, 2}, {1, 0}, {1, 1}, {1, 2},
                   {1, 3}, {2, 2}, {3, 2}, {3, 3}};
  } else {  // padding_x_rem == 0 && padding_y_rem == 0
    permutation = {{0, 0}, {0, 1}, {0, 2}, {0, 3}, {1, 1},
                   {1, 3}, {2, 2}, {2, 3}, {3, 3}};
  }
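  // Main loop over input slices: each iteration stages 36 FLT4 weight vectors
  // (9 taps * 4, as consumed by CONV), reads the 2x2 source neighborhood, and
  // accumulates into r0..r3.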
  c += "  for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
  if (need_local_mem) {
    c += "    barrier(CLK_LOCAL_MEM_FENCE);\n";
  }
  if (weights_upload_type ==
      ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_ASYNC) {
    c += "    async_work_group_copy(weights_cache, "
         "args.weights.GetPtr(f_offset), 36, "
         "0);\n";
  } else if (weights_upload_type ==
             ConvolutionTransposed3x3::WeightsUploadType::
                 LOCAL_MEM_BY_THREADS) {
    c += "    weights_cache[local_id] = args.weights.Read(f_offset + "
         "local_id);\n";
    c += "    if (local_id < 4) {\n";
    c += "      weights_cache[local_id + 32] = args.weights.Read(f_offset + "
         "local_id + "
         "32);\n";
    c += "    };\n";
  } else {  // GLOBAL_MEM/CONSTANT_MEM
    c += "    " + weights_space +
         " FLT4* weights_cache = args.weights.GetPtr(f_offset);\n";
  }
  c += "    FLT4 src0 = " + read_src(0, 0);
  c += "    FLT4 src1 = " + read_src(1, 0);
  c += "    FLT4 src2 = " + read_src(0, 1);
  c += "    FLT4 src3 = " + read_src(1, 1);
  c += "    f_offset += 36;\n";
  if (need_local_mem) {
    c += "    barrier(CLK_LOCAL_MEM_FENCE);\n";
  }
  for (int i = 0; i < 9; ++i) {
    const std::string r_name = "r" + std::to_string(permutation[i].first);
    const std::string s_name = "src" + std::to_string(permutation[i].second);
    const std::string w_name = std::to_string(i * 4);
    c += "    CONV(" + r_name + ", " + s_name + ", " + w_name + ");\n";
  }
  c += "  }\n";
  if (need_local_mem) {
    c += "  if (DST_X >= args.dst_tensor.Width() || DST_Y >= "
         "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
  }
  c += "  FLT4 bias_val = args.biases.Read(Z);\n";
  for (int y = 0; y < 2; ++y) {
    for (int x = 0; x < 2; ++x) {
      const std::string s_x = std::to_string(x);
      const std::string s_y = std::to_string(y);
      const std::string id = std::to_string(y * 2 + x);
      const std::string x_c = "DST_X + " + s_x + " * " + pixel_stride;
      const std::string y_c = "DST_Y + " + s_y;
      c += "  if (" + x_c + " < args.dst_tensor.Width() && " + y_c +
           " < args.dst_tensor.Height()) {\n";
      c += "    FLT4 res0 = TO_FLT4(r" + id + ") + bias_val;\n";
      c += "    args.dst_tensor.Write(res0, " + x_c + ", " + y_c + ", Z);\n";
      c += "  }\n";
    }
  }
  c += "}\n";
  return c;
}

absl::Status ConvolutionTransposed3x3::BindArguments(ArgumentsBinder* args) {
  RETURN_IF_ERROR(args->SetInt("filter_offset", 4 * 9 * src_[0]->Slices()));
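  // Convert the prepended padding into the source-space offset added to the
  // work-item coordinates in the kernel; the rounding differs for positive and
  // non-positive padding, and padding_x is scaled by the batch size because
  // batch is folded into the X dimension.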
  const int padding_x =
      padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2;
  const int padding_y =
      padding_.y >= 1 ? (padding_.y - 1) / 2 : (padding_.y - 2) / 2;
  RETURN_IF_ERROR(args->SetInt("padding_x", padding_x * src_[0]->Batch()));
  return args->SetInt("padding_y", padding_y);
}

void ConvolutionTransposed3x3::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const DeviceInfo& device_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
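  // The local-memory upload paths index weights_cache with the fixed 8x4
  // thread layout, so only the default work-group size is offered there.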
  if (weights_upload_type_ == WeightsUploadType::LOCAL_MEM_ASYNC ||
      weights_upload_type_ == WeightsUploadType::LOCAL_MEM_BY_THREADS) {
    work_groups->push_back(work_group_size_);
    return;
  }
  GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_,
                            work_groups);
}

int3 ConvolutionTransposed3x3::GetGridSize() const {
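  // Each work-item writes a 2x2 block of output pixels, so the grid covers
  // half the output width and height (rounded up), with batch folded into X.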
  const int grid_x = DivideRoundUp(dst_[0]->Width(), 2) * dst_[0]->Batch();
  const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

bool IsConvolutionTransposed3x3Supported(
    const OperationDef& definition,
    const ConvolutionTransposedAttributes& attr) {
  return attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
         attr.stride.w == 2 && attr.stride.h == 2;
}

ConvolutionTransposed3x3 CreateConvolutionTransposed3x3(
    const DeviceInfo& device_info, const OperationDef& definition,
    const ConvolutionTransposedAttributes& attr) {
  const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h);
  ConvolutionTransposed3x3 result(definition, device_info, padding);
  result.UploadWeights(attr.weights);

  TensorLinearDescriptor desc;
  desc.storage_type = LinearStorageType::TEXTURE_2D;
  desc.element_type = definition.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite