/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
typhoonzero 已提交
15
#include <set>
16
#include <vector>
T
typhoonzero 已提交
17

Y
Yi Wang 已提交
18 19
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
D
dzhwinter 已提交
20
#include "paddle/fluid/platform/cuda_primitives.h"
21
#include "paddle/fluid/platform/float16.h"
22 23 24 25 26

namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  // Concatenates two SelectedRows on the GPU.
  //
  // The output's rows are input1's rows followed by input2's rows (duplicate
  // indices are kept; nothing is merged). The output's value holds input1's
  // value followed by input2's value, copied device-to-device on `context`'s
  // stream.
  //
  // Preconditions (enforced): equal heights, equal row widths, non-empty row
  // lists on both inputs, and every tensor on a CUDA place.
  // output->mutable_value() is accessed through data<T>() rather than
  // mutable_data<T>(), so the caller must already have allocated it with room
  // for both inputs' values. Only its row width is checked here.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());
    output->set_height(in1_height);

    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();
    // Row widths below are computed as numel / row count. Reject empty row
    // lists up front instead of dividing by zero.
    PADDLE_ENFORCE(in1_rows.size() > 0,
                   "input1 of SelectedRowsAdd must have at least one row");
    PADDLE_ENFORCE(in2_rows.size() > 0,
                   "input2 of SelectedRowsAdd must have at least one row");

    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_gpu_place(out_place));

    // `context` is already a CUDADeviceContext, so its stream is used
    // directly for both copies.
    // input1's values fill the front of the output buffer...
    memory::Copy(boost::get<platform::CUDAPlace>(out_place), out_data,
                 boost::get<platform::CUDAPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
    // ...and input2's values are appended immediately after them.
    memory::Copy(boost::get<platform::CUDAPlace>(out_place),
                 out_data + in1_value.numel(),
                 boost::get<platform::CUDAPlace>(in2_place), in2_data,
                 in2_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;

namespace {
// Scatter-adds every row of a SelectedRows value into a dense tensor:
//   tensor_out[rows[r]][j] += selected_rows[r][j]
//
// Launch layout expected: blockIdx.y selects the selected row r (one block
// per row), and the block's block_size threads in x stride across that row's
// row_numel elements.
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  const int row = blockIdx.y;
  const T* src_row = selected_rows + row * row_numel;
  T* dst_row = tensor_out + rows[row] * row_numel;

  // rows[] may repeat an index, so several blocks can hit the same
  // destination row concurrently. A plain `dst_row[col] += src_row[col]`
  // would race, so the accumulation goes through an atomic add.
  for (int col = threadIdx.x; col < row_numel; col += block_size) {
    paddle::platform::CudaAtomicAdd(dst_row + col, src_row[col]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  // Computes output = input2 + dense(input1) on the GPU.
  //
  // Steps:
  //   1. output is zeroed.
  //   2. input1's rows are scatter-added into output (duplicate row indices
  //      accumulate).
  //   3. input2 is added element-wise.
  //
  // Preconditions (enforced):
  //   - the heights of input1, input2 and output agree;
  //   - their row widths agree;
  //   - input1 has at least one row.
  // output must already be allocated, because it is accessed through
  // data<T>().
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
    framework::Vector<int64_t> in1_rows(input1.rows());
    // The row width divides by the row count, and the kernel launch below
    // uses one block per row. An empty row list would otherwise divide by
    // zero and request a zero-sized grid.
    PADDLE_ENFORCE(in1_rows.size() > 0,
                   "input1 of SelectedRowsAddTensor must have at least one row");

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
    PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    // One block per selected row along grid y; block_size threads stride
    // across each row. NOTE(review): CUDA caps gridDim.y at 65535, so more
    // selected rows than that would fail to launch.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(1, in1_rows.size());
    SelectedRowsAddTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
        in1_row_numel);

    // Add the dense operand on the same device context.
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
                                      platform::float16>;

template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  // Appends input1 to input2 in place.
  //
  // input1's row indices are appended to input2's rows. input1's values are
  // copied device-to-device into input2's value buffer, starting at element
  // `input2_offset`. input2's value must already be allocated large enough
  // for that range; this functor does not resize it.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const int64_t input2_offset,
                  framework::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2->height());

    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // Validate everything before mutating input2, so a failed check leaves
    // its rows untouched.
    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2->place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
    // The copy below writes in1_value.numel() elements starting at
    // input2_offset. That whole range must lie inside input2's value.
    PADDLE_ENFORCE(input2_offset >= 0 &&
                       input2_offset + in1_value.numel() <= in2_value->numel(),
                   "SelectedRowsAddTo: input1's values do not fit into "
                   "input2's value at the given offset");

    // concat rows
    if (in1_rows.size()) {
      in2_rows.Extend(in1_rows.begin(), in1_rows.end());
    }

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory::Copy(boost::get<platform::CUDAPlace>(in2_place),
                 in2_data + input2_offset,
                 boost::get<platform::CUDAPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;

namespace {
// Accumulates each row of a SelectedRows value into the dense tensor row it
// names:
//   tensor_out[rows[r]][j] += selected_rows[r][j]
//
// Expected launch: one block per selected row along grid y (blockIdx.y picks
// r), with block_size threads in x. Each thread handles every
// block_size-th element of the row.
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int row_id = blockIdx.y;
  const int64_t src_offset = row_id * row_numel;
  const int64_t dst_offset = rows[row_id] * row_numel;

  // rows[] may repeat an index, so several blocks can share a destination
  // row. The accumulation therefore has to be atomic.
  int col = threadIdx.x;
  while (col < row_numel) {
    paddle::platform::CudaAtomicAdd(&tensor_out[dst_offset + col],
                                    selected_rows[src_offset + col]);
    col += block_size;
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  // In-place scatter-add into a dense tensor: for every row r of input1,
  // input2[rows[r]] += input1.value[r]. Duplicate row indices accumulate.
  //
  // Preconditions (enforced): input1's height equals input2's first
  // dimension, the row widths agree, and input1 has at least one row.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = input1.value();
    framework::Vector<int64_t> in1_rows(input1.rows());
    // The row width divides by the row count, and the launch below uses one
    // block per row. An empty row list would otherwise divide by zero and
    // request a zero-sized grid.
    PADDLE_ENFORCE(
        in1_rows.size() > 0,
        "input1 of SelectedRowsAddToTensor must have at least one row");

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();

    // One block per selected row along grid y; block_size threads stride
    // across each row. NOTE(review): CUDA caps gridDim.y at 65535.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(1, in1_rows.size());
    SelectedRowsAddToTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
        in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
                                        platform::float16>;

namespace scatter {

template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
  // Accumulates input row blockIdx.y into the slot of `out` whose row index
  // (in out_rows) equals input_rows[blockIdx.y].
  //
  // Launch layout expected: one block per input row along grid y, block_size
  // threads in x. Thread 0 finds the destination slot and publishes it
  // through shared memory. After the barrier, all threads add the row
  // atomically, because several input rows can map to the same slot.

  // Destination slot for this block's input row. It is written by thread 0
  // only and read by everyone after the barrier.
  __shared__ size_t out_idx;

  const int in_row = blockIdx.y;
  if (threadIdx.x == 0) {
    // Scanning from the back and stopping at the first hit selects the same
    // slot as a forward scan that keeps the last match.
    const int64_t key = input_rows[in_row];
    for (size_t i = out_rows_size; i > 0; --i) {
      if (out_rows[i - 1] == key) {
        out_idx = i - 1;
        break;
      }
    }
  }

  // Reached by every thread of the block (outside the divergent branch).
  __syncthreads();

  const T* src = input + in_row * row_numel;
  T* dst = out + out_idx * row_numel;
  for (int j = threadIdx.x; j < row_numel; j += block_size) {
    paddle::platform::CudaAtomicAdd(dst + j, src[j]);
  }
}

template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  // Returns a SelectedRows in which each distinct row index of `input`
  // appears exactly once, in ascending order as produced by std::set. Each
  // value row is the sum of all of input's value rows with that index.
  // The result keeps input's height. Its value has shape
  // {number of distinct rows, input.value().dims()[1]} and is allocated on
  // context.GetPlace().
  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                     const framework::SelectedRows& input) {
    framework::SelectedRows out;
    framework::Vector<int64_t> input_rows(input.rows());
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    // Nothing to accumulate. Returning here also avoids a launch with
    // gridDim.y == 0, which CUDA rejects as an invalid configuration.
    if (input_rows.size() == 0) {
      return out;
    }

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    // One block per input row along grid y; block_size threads per block.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(1, input_rows.size());

    // The kernel takes the merged row indices as `const int64_t*` and only
    // reads them, so the const CUDAData() accessor is sufficient here.
    MergeAddKernel<T, block_size><<<grid1, threads, 0, context.stream()>>>(
        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
        out.rows().CUDAData(context.GetPlace()), out.rows().size(),
        input_width);
    return out;
  }
};

template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;

// Applies `op` between each selected row r and the dense row it targets
// (dst = tensor_out[rows[r]], src = selected_rows[r], element-wise):
//   ASSIGN: dst = src        ADD: dst += src        SUB: dst -= src
//   SUBBY:  dst = src - dst  MUL: dst *= src        DIV: dst /= src
//   DIVBY:  dst = src / dst
//
// Launch layout expected: one block per selected row along grid y, with
// block_size threads striding across row_numel elements. Updates are not
// atomic, so `rows` must not contain duplicates. UpdateToTensor in this file
// ensures this by merging rows first.
//
// `op` is taken by value. A reference parameter of a __global__ function
// would carry the address of the caller's host-side object into device
// code, and dereferencing it there is an illegal memory access.
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, ScatterOps op,
                                     T* tensor_out, int64_t row_numel) {
  const int ty = blockIdx.y;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // The switch sits outside the loops, so the op is dispatched once per
  // thread instead of once per element.
  // FIXME(typhoonzero): use macro fix the below messy code.
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  // Applies `op` in place between input1's rows and the matching rows of the
  // dense tensor input2. UpdateToTensorKernel documents the per-op formula.
  //
  // input1 is first merged with MergeAdd, so duplicate row indices are summed
  // and every target row of input2 is written by exactly one block.
  //
  // Preconditions (enforced): input1 has at least one row, input1's height
  // equals input2's first dimension, and the row widths agree.
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    // An empty input1 would make the row-width computation below divide by
    // zero. Reject it before doing any GPU work.
    PADDLE_ENFORCE(input1.rows().size() > 0,
                   "input1 of UpdateToTensor must have at least one row");

    // NOTE: Use SelectedRowsAddToTensor for better performance
    //       no additional MergeAdd called.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();

    // One block per merged row along grid y. Row indices are fetched with
    // the same CUDAData(place) accessor used by the other functors in this
    // file.
    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(1, in1_rows.size());
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), op, in2_data,
        in1_row_numel);
  }
};
}  // namespace scatter
393 394 395
}  // namespace math
}  // namespace operators
}  // namespace paddle