/* concat_kernel.cl */
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cl_common.h>

/*
 * concat2 - concatenate two image-backed tensors into one output image.
 *
 * One work-item produces one output pixel (four packed channel values).
 * Pixels are addressed as x = channel_block * W + w, y = row, which is what
 * every index computation below uses.
 *
 * Parameters:
 *   input0, input1  source images; input0 comes first along the concat axis.
 *   output          destination image.
 *   flag            concat mode: 1 = by channel, 2 = by image row ("height"),
 *                   3 = by width. Any other value writes nothing.
 *   C_0             extent of input0 along the concat axis (channels for
 *                   flag 1, row groups for flag 2, columns for flag 3).
 *   out_C           number of channels in the output (used by flag 1 only).
 *   out_W           width (columns per channel block) of the output.
 *   width           number of image rows per unit of the concat axis
 *                   (used by flag 2 only).
 */
__kernel void concat2(__read_only image2d_t input0,
                      __read_only image2d_t input1,
                      __write_only image2d_t output,
                      int flag, int C_0, int out_C, int out_W, int width) {
  const int out_w = get_global_id(0);   // column inside a channel block
  const int out_c = get_global_id(1);   // channel block (4 channels each)
  const int out_nh = get_global_id(2);  // image row (n x h)

  // All reads below use integer pixel coordinates. The OpenCL C spec leaves
  // the result undefined unless such reads go through a sampler with
  // non-normalized coordinates and nearest filtering.
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  // The destination pixel is the same in every mode.
  const int2 output_pos = (int2)(out_c * out_W + out_w, out_nh);

  if (flag == 1) {  // by channel
    // Zero-fill so lanes past out_C (ragged last block) are well defined.
    CL_DTYPE4 output_data = (CL_DTYPE4)(0);
    for (int lane = 0; lane < 4; lane++) {
      const int c = out_c * 4 + lane;  // absolute output channel
      if (c >= out_C) {
        break;
      }

      // Pick the source image and the channel index inside it. Both inputs
      // share the output's width when concatenating by channel.
      int c_in;
      CL_DTYPE4 input_data;
      if (c < C_0) {
        c_in = c;
        const int2 input_pos = (int2)((c_in / 4) * out_W + out_w, out_nh);
        input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, input_pos);
      } else {
        c_in = c - C_0;
        const int2 input_pos = (int2)((c_in / 4) * out_W + out_w, out_nh);
        input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, sampler, input_pos);
      }

      // Extract the lane of the source pixel that holds channel c_in.
      CL_DTYPE value;
      switch (c_in % 4) {
        case 0:
          value = input_data.x;
          break;
        case 1:
          value = input_data.y;
          break;
        case 2:
          value = input_data.z;
          break;
        default:
          value = input_data.w;
          break;
      }

      // Store it in the matching lane of the output pixel.
      switch (lane) {
        case 0:
          output_data.x = value;
          break;
        case 1:
          output_data.y = value;
          break;
        case 2:
          output_data.z = value;
          break;
        default:
          output_data.w = value;
          break;
      }
    }
    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, output_data);
  } else if (flag == 2) {  // by image row; `width` rows per concat unit
    const int unit = out_nh / width;  // index along the concat axis
    int2 input_pos;
    input_pos.x = out_c * out_W + out_w;
    CL_DTYPE4 pixel;
    if (unit < C_0) {
      input_pos.y = out_nh;
      pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, input_pos);
    } else {
      // Shift past input0's C_0 * width rows while keeping the row offset
      // inside the unit (out_nh % width); this is the inverse of the
      // mapping concat_mul uses (y_out = y_in + C_0 * width).
      input_pos.y = out_nh - C_0 * width;
      pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, sampler, input_pos);
    }
    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel);
  } else if (flag == 3) {  // by width; C_0 is input0's width
    int2 input_pos;
    input_pos.y = out_nh;
    CL_DTYPE4 pixel;
    if (out_w < C_0) {
      // input0 rows hold C_0 columns per channel block, not out_W.
      input_pos.x = out_c * C_0 + out_w;
      pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, input_pos);
    } else {
      // input1 supplies the remaining out_W - C_0 columns per channel block.
      input_pos.x = out_c * (out_W - C_0) + (out_w - C_0);
      pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, sampler, input_pos);
    }
    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel);
  }
}

107
__kernel void concat_mul(__read_only image2d_t input,
108 109
                         __write_only image2d_t output,
                         int flag, int C_0, int out_C, int out_W, int in_W, int width) {
110 111 112
  const int in_w = get_global_id(0); // image_width cxw/4
  const int in_c = get_global_id(1); // image_width cxw/4
  const int in_nh = get_global_id(2); // image_height nxh
113 114 115 116

  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;
117 118 119 120 121 122 123 124
  int2 input_pos;
  int2 output_pos;
  input_pos.x = in_c * in_W + in_w;
  input_pos.y = in_nh;
  CL_DTYPE4 input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, input_pos);
  if (flag == 1){ // by channel
    CL_DTYPE4 output_data;
    for (int i = 0; i < 4; i++) {
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
      int c_out = C_0 + in_c * 4 + i;
      if (c_out >= out_C) {
        break;
      }
      int2 output_pos;
      output_pos.x = (c_out / 4) * in_W + in_w;
      output_pos.y = in_nh;
      CL_DTYPE val;
      if (i == 0) {
        val = input_data.x;
      } else if (i == 1) {
        val = input_data.y;
      } else if (i == 2) {
        val = input_data.z;
      } else if (i == 3) {
        val = input_data.w;
141 142
      }
      if (c_out % 4 == 0){
143
        output_data.x = val;
144
      }else if (c_out % 4 == 1){
145 146 147 148 149 150
        output_data.y = val;
      }else if (c_out % 4 == 2){
        output_data.z = val;
      }else if (c_out % 4 == 3){
        output_data.w = val;
      }
151 152 153 154 155 156 157 158 159 160 161 162
      WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, output_data);
    }
  }else if (flag == 2){ // by height, width == n
    int2 output_pos;
    output_pos.x = in_c * in_W + in_w;
    output_pos.y = in_nh + C_0 * width;
    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, input_data);
  }else if (flag == 3){ // by width, width == C
    int2 output_pos;
    output_pos.y = in_nh;
    output_pos.x = in_c * out_W + (in_w + C_0);
    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, input_data);
163
  }
164
}