selected_rows_functor.cu 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct SelectedRowsAdd<platform::GPUPlace, T> {
  // Concatenates two SelectedRows of equal height into `output`.
  //
  //   output.rows  = input1.rows followed by input2.rows (duplicate indices
  //                  are kept, not merged);
  //   output.value = input1.value followed by input2.value, copied device to
  //                  device on the CUDA stream of `context`.
  //
  // Preconditions checked below: equal heights, at least one row in each
  // input, the same number of elements per row in every value tensor, and
  // GPU places for input1, input2 and `context`.
  // The output value tensor is read through data<T>() rather than allocated
  // here, so the caller is expected to have allocated it already
  // (presumably with in1_rows + in2_rows rows — confirm against callers).
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    // The per-row width is derived by dividing numel by the row count, so an
    // empty input would be an integer division by zero; fail loudly instead.
    PADDLE_ENFORCE(!in1_rows.empty(), "SelectedRowsAdd: input1 has no rows.");
    PADDLE_ENFORCE(!in2_rows.empty(), "SelectedRowsAdd: input2 has no rows.");

    output->set_height(in1_height);

    // concat rows: input1's indices first, then input2's
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_gpu_place(out_place));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value.data<T>();

    // CUDADeviceContext derives from DeviceContext, so static_cast is the
    // well-defined downcast; reinterpret_cast would ignore the hierarchy.
    auto stream =
        static_cast<const platform::CUDADeviceContext&>(context).stream();

    // input1's values fill the front of the output buffer ...
    memory::Copy(boost::get<platform::GPUPlace>(out_place), out_data,
                 boost::get<platform::GPUPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T), stream);
    // ... and input2's values are placed directly after them.
    memory::Copy(boost::get<platform::GPUPlace>(out_place),
                 out_data + in1_value.numel(),
                 boost::get<platform::GPUPlace>(in2_place), in2_data,
                 in2_value.numel() * sizeof(T), stream);
  }
};

template struct SelectedRowsAdd<platform::GPUPlace, float>;

namespace {
// Scatter-adds the rows of a SelectedRows value into a dense tensor:
//   tensor_out[rows[r] * row_numel + i] += selected_rows[r * row_numel + i]
// for every selected row r and every column i < row_numel.
//
// Launch layout: one block per selected row along the x dimension of the
// grid (gridDim.x == number of selected rows); the threads of a block stride
// over that row's columns with step blockDim.x, so any block size is valid.
// The row index comes from blockIdx.x rather than blockIdx.y because
// gridDim.y is capped at 65535, whereas gridDim.x allows up to 2^31 - 1
// blocks (compute capability 3.0+). No shared memory is used.
template <typename T>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  const int64_t row = blockIdx.x;

  selected_rows += row * row_numel;
  tensor_out += rows[row] * row_numel;

  // 64-bit column index so that it matches row_numel and cannot overflow.
  for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
    // Since index in rows of SelectedRows can be duplicate, we can not use
    // tensor_out[index] += selected_rows[index]; Instead, we have to use
    // AtomicAdd to avoid concurrent write error.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddTensor<platform::GPUPlace, T> {
  // Computes output = dense(input1) + input2 on the GPU, where dense(input1)
  // is a zero tensor of input1's height with each selected row of input1
  // added at its row index (rows listed several times are summed).
  //
  // Requires input2 and output to have dims[0] == input1.height() and the
  // same number of elements per row as input1's value. When input1 has no
  // rows there is nothing to scatter and output ends up equal to input2.
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();
    const bool has_rows = !in1_rows.empty();

    // The row width (numel / row count) is only defined when there is at
    // least one selected row; validate it before touching the output.
    int64_t in1_row_numel = 0;
    if (has_rows) {
      in1_row_numel = in1_value.numel() / in1_rows.size();
      PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
      PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
    }

    SetConstant<platform::GPUPlace, T> functor;
    functor(context, output, static_cast<T>(0));

    if (has_rows) {
      auto* in1_data = in1_value.data<T>();
      auto* out_data = output->data<T>();

      // CUDADeviceContext derives from DeviceContext, so static_cast is the
      // well-defined downcast here.
      auto stream =
          static_cast<const platform::CUDADeviceContext&>(context).stream();

      const int block_size = 256;
      dim3 threads(block_size, 1);
      // One block per selected row along x (see the kernel's comment). A
      // zero-sized grid is an invalid launch, hence the has_rows guard.
      dim3 grid(static_cast<unsigned int>(in1_rows.size()), 1);
      // NOTE(review): in1_rows.data() is dereferenced inside the kernel, so
      // the rows container must be device-accessible (e.g. pinned, mapped
      // host memory) — confirm against the SelectedRows rows() type.
      SelectedRowsAddTensorKernel<T><<<grid, threads, 0, stream>>>(
          in1_data, in1_rows.data(), out_data, in1_row_numel);
    }

    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.GetEigenDevice<platform::GPUPlace>()) =
        out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::GPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle