selected_rows_functor.cu 14.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
typhoonzero 已提交
15 16
#include <set>

17 18 19 20 21 22 23 24
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {
// GPU specialization of SelectedRowsAdd: concatenates two SelectedRows.
//
// The output row list is input1's rows followed by input2's rows, and the
// output value buffer receives input1's values followed by input2's values.
// Duplicate row ids are kept as they are; nothing is merged here.
//
// Checks performed below:
//   - both inputs have the same height and at least one row each;
//   - input1, input2 and the context's place are GPU places;
//   - input1, input2 and output->value() have the same elements-per-row.
// output->value() is accessed through data<T>() only; this functor does not
// resize or allocate it.
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());
    output->set_height(in1_height);

    // A private copy of input1's row ids, so the values read below do not
    // change when output's row list is overwritten.
    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();

    // Both row counts are used as divisors further down; reject empty inputs
    // with a clear message instead of dividing by zero.
    PADDLE_ENFORCE(in1_rows.size() > 0,
                   "input1 of SelectedRowsAdd must contain at least one row.");
    PADDLE_ENFORCE(in2_rows.size() > 0,
                   "input2 of SelectedRowsAdd must contain at least one row.");

    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    // All three value tensors must agree on the number of elements per row.
    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_gpu_place(out_place));

    // First half of the output buffer: input1's values.
    memory::Copy(boost::get<platform::CUDAPlace>(out_place), out_data,
                 boost::get<platform::CUDAPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T), context.stream());

    // Second half of the output buffer: input2's values, placed right after
    // the in1_value.numel() elements written above.
    memory::Copy(boost::get<platform::CUDAPlace>(out_place),
                 out_data + in1_value.numel(),
                 boost::get<platform::CUDAPlace>(in2_place), in2_data,
                 in2_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;
78 79

namespace {
// Adds each row of a SelectedRows value buffer into the matching row of a
// dense tensor.
//
// Launch layout: one thread block per selected row, indexed by blockIdx.x,
// with block_size threads per block (blockDim.x must equal block_size). Each
// thread handles elements tid, tid + block_size, ... of its row.
//
//   selected_rows: row-major values, row_numel elements per selected row.
//   rows:          destination row id in tensor_out for each selected row.
//   tensor_out:    dense output, row_numel elements per row.
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  // The row index is carried in the grid's x-dimension because the y-dimension
  // of a grid is limited to 65535 blocks.
  const int64_t row = blockIdx.x;
  const int tid = threadIdx.x;

  selected_rows += row * row_numel;
  tensor_out += rows[row] * row_numel;

  // 64-bit index: row_numel is int64_t, so an int counter could overflow.
  for (int64_t index = tid; index < row_numel; index += block_size) {
    // Since index in rows of SelectedRows can be duplicate, we can not use
    // tensor_out[index] += selected_rows[index]; Instead, we have to use
    // AtomicAdd to avoid concurrent write error.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

// GPU specialization of SelectedRowsAddTensor:
//   output = input2 + (input1 scattered to its row ids).
//
// output is first filled with zeros, then the selected rows of input1 are
// accumulated into it by SelectedRowsAddTensorKernel, and finally input2 is
// added element-wise. output is accessed through data<T>() and dims() only;
// this functor does not allocate it.
template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
    framework::Vector<int64_t> in1_rows(input1.rows());

    // The row count is used both as a divisor and as the grid size; an empty
    // row list would divide by zero and request a zero-sized grid.
    PADDLE_ENFORCE(
        in1_rows.size() > 0,
        "input1 of SelectedRowsAddTensor must contain at least one row.");

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
    PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    // Start from zeros so the kernel's atomic adds produce the scattered sum.
    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    const int block_size = 256;
    dim3 threads(block_size, 1);
    // One block per selected row, laid out along x (see kernel comment).
    dim3 grid(in1_rows.size(), 1);
    SelectedRowsAddTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
        in1_row_numel);

    // output += input2, element-wise over the flattened tensors.
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
Q
QI JUN 已提交
140 141

// GPU specialization of SelectedRowsAddTo: appends input1 to input2 in place.
//
// input1's row ids are appended to input2's row list, and input1's values are
// copied into input2's value buffer starting at element index input2_offset.
// input2's value buffer is accessed through data<T>() only; it is not resized
// here, so the caller is responsible for having sized it.
//
// Checks performed below, all before input2 is modified:
//   - both SelectedRows have the same height and live on a GPU place;
//   - [input2_offset, input2_offset + input1.value().numel()) lies inside
//     input2's value tensor.
template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const int64_t input2_offset,
                  framework::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2->height());

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2->place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // Reject a destination window that falls outside input2's value tensor;
    // the device copy below would otherwise write out of bounds.
    PADDLE_ENFORCE(input2_offset >= 0,
                   "input2_offset of SelectedRowsAddTo must be non-negative.");
    PADDLE_ENFORCE(
        input2_offset + in1_value.numel() <= in2_value->numel(),
        "input1's values do not fit into input2's value tensor at the given "
        "offset.");

    // Row ids are copied before input2's row list is extended.
    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = *(input2->mutable_rows());

    // concat rows
    in2_rows.Extend(in1_rows.begin(), in1_rows.end());

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory::Copy(boost::get<platform::CUDAPlace>(in2_place),
                 in2_data + input2_offset,
                 boost::get<platform::CUDAPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
Q
QI JUN 已提交
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198

namespace {
// Accumulates each row of a SelectedRows value buffer into the matching row
// of an existing dense tensor (tensor_out is updated in place).
//
// Launch layout: one thread block per selected row, indexed by blockIdx.x,
// with block_size threads per block (blockDim.x must equal block_size). Each
// thread handles elements tid, tid + block_size, ... of its row.
//
//   selected_rows: row-major values, row_numel elements per selected row.
//   rows:          destination row id in tensor_out for each selected row.
//   tensor_out:    dense tensor, row_numel elements per row.
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  // The row index is carried in the grid's x-dimension because the y-dimension
  // of a grid is limited to 65535 blocks.
  const int64_t row = blockIdx.x;
  const int tid = threadIdx.x;

  selected_rows += row * row_numel;
  tensor_out += rows[row] * row_numel;

  // 64-bit index: row_numel is int64_t, so an int counter could overflow.
  for (int64_t index = tid; index < row_numel; index += block_size) {
    // Since index in rows of SelectedRows can be duplicate, we have to use
    // Atomic Operation to avoid concurrent write error.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

// GPU specialization of SelectedRowsAddToTensor:
//   input2 += (input1 scattered to its row ids), in place.
//
// input2 is accessed through data<T>() and dims() only; it is neither zeroed
// nor allocated here, so its current contents are what gets accumulated into.
template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = input1.value();
    framework::Vector<int64_t> in1_rows(input1.rows());

    // The row count is used both as a divisor and as the grid size; an empty
    // row list would divide by zero and request a zero-sized grid.
    PADDLE_ENFORCE(
        in1_rows.size() > 0,
        "input1 of SelectedRowsAddToTensor must contain at least one row.");

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    // One block per selected row, laid out along x (see kernel comment).
    dim3 grid(in1_rows.size(), 1);
    SelectedRowsAddToTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
        in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
T
typhoonzero 已提交
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257

namespace scatter {

// Sums input rows that share a row id into a single output row.
//
// Launch layout: one thread block per input row, indexed by blockIdx.x, with
// block_size threads per block (blockDim.x must equal block_size).
//
//   input / input_rows: row-major values and the row id of each input row.
//   out / out_rows:     output values and the row id of each output row;
//                       out must be zero-filled by the caller because rows
//                       are accumulated with atomic adds.
//   out_rows_size:      number of entries in out_rows.
//   row_numel:          elements per row in both input and out.
//
// Thread 0 of each block looks up the output slot whose id equals this
// block's input row id and publishes it through shared memory. The lookup
// stops at the first match, so out_rows is expected to hold unique ids (the
// MergeAdd functor below builds it from a std::set). If no id matches, the
// shared index is left unset.
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
  // The row index is carried in the grid's x-dimension because the y-dimension
  // of a grid is limited to 65535 blocks.
  const int64_t row = blockIdx.x;
  const int tid = threadIdx.x;
  __shared__ size_t out_idx;

  if (tid == 0) {
    const int64_t row_id = input_rows[row];
    for (size_t i = 0; i < out_rows_size; i++) {
      if (row_id == out_rows[i]) {
        out_idx = i;
        break;
      }
    }
  }

  // Make out_idx visible to every thread of the block before it is used.
  __syncthreads();

  input += row * row_numel;
  out += out_idx * row_numel;
  // 64-bit index: row_numel is int64_t, so an int counter could overflow.
  for (int64_t index = tid; index < row_numel; index += block_size) {
    // Several input rows may target the same output row, hence the atomic.
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

// GPU specialization of MergeAdd.
//
// Returns a new SelectedRows with the same height as `input`, whose rows are
// the distinct row ids of `input` in ascending order and whose value tensor
// has shape {number of distinct ids, input.value().dims()[1]}; each output
// row is the sum of all input rows carrying that id.
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                     const framework::SelectedRows& input) {
    framework::SelectedRows out;
    framework::Vector<int64_t> input_rows(input.rows());
    // std::set removes duplicate ids and iterates them in ascending order.
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    // The kernel accumulates with atomic adds, so start from zeros.
    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    // Nothing to merge; a zero-sized grid is not a valid launch configuration.
    if (input_rows.size() == 0) {
      return out;
    }

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    // One block per input row, laid out along x (see kernel comment).
    dim3 grid1(input_rows.size(), 1);

    MergeAddKernel<T, block_size><<<grid1, threads, 0, context.stream()>>>(
        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
        out.mutable_rows()->CUDAMutableData(context.GetPlace()),
        out.rows().size(), input_width);
    return out;
  }
};

template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
T
wip  
typhoonzero 已提交
300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350

// Applies a single ScatterOps update of `src` onto `*dst`.
// SUBBY and DIVBY are the reversed forms: src - dst and src / dst.
template <typename T>
__device__ inline void ApplyScatterOp(ScatterOps op, T* dst, const T& src) {
  switch (op) {
    case ScatterOps::ASSIGN:
      *dst = src;
      break;
    case ScatterOps::ADD:
      *dst += src;
      break;
    case ScatterOps::SUB:
      *dst -= src;
      break;
    case ScatterOps::SUBBY:
      *dst = src - *dst;
      break;
    case ScatterOps::MUL:
      *dst *= src;
      break;
    case ScatterOps::DIV:
      *dst /= src;
      break;
    case ScatterOps::DIVBY:
      *dst = src / *dst;
      break;
  }
}

// Updates rows of a dense tensor from the rows of a SelectedRows value buffer
// using the given ScatterOps operation.
//
// Launch layout: one thread block per selected row, indexed by blockIdx.x,
// with block_size threads per block (blockDim.x must equal block_size).
//
// The updates are plain (non-atomic) read-modify-writes, so every entry of
// `rows` must be distinct; the UpdateToTensor functor below guarantees this
// by running MergeAdd first.
//
// `op` is taken by value: a reference parameter of a __global__ function
// would refer to host memory and cannot be dereferenced on the device.
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, const ScatterOps op,
                                     T* tensor_out, int64_t row_numel) {
  // The row index is carried in the grid's x-dimension because the y-dimension
  // of a grid is limited to 65535 blocks.
  const int64_t row = blockIdx.x;
  const int tid = threadIdx.x;

  selected_rows += row * row_numel;
  tensor_out += rows[row] * row_numel;

  // 64-bit index: row_numel is int64_t, so an int counter could overflow.
  for (int64_t index = tid; index < row_numel; index += block_size) {
    ApplyScatterOp(op, tensor_out + index, selected_rows[index]);
  }
}

// GPU specialization of UpdateToTensor: input2[row] = op(input2[row],
// input1[row]) for every row id present in input1.
//
// Duplicate row ids in input1 are summed by MergeAdd before the update, so
// each destination row is touched by exactly one thread block.
template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    // NOTE: For plain accumulation prefer SelectedRowsAddToTensor, which
    //       does not need the extra MergeAdd pass performed here.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = merged_in1.value();
    // Local copy of the row ids, fetched on the device through
    // CUDAData(place) like every other functor in this file.
    framework::Vector<int64_t> in1_rows(merged_in1.rows());

    // The row count is used both as a divisor and as the grid size.
    PADDLE_ENFORCE(in1_rows.size() > 0,
                   "input1 of UpdateToTensor must contain at least one row.");

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.template data<T>();
    auto* in2_data = input2->data<T>();

    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    // One block per merged row, laid out along x (see kernel comment).
    dim3 grid(in1_rows.size(), 1);
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), op, in2_data,
        in1_row_numel);
  }
};
T
typhoonzero 已提交
380
}  // namespace scatter
381 382 383
}  // namespace math
}  // namespace operators
}  // namespace paddle