提交 ab5dc9fe 编写于 作者: Q qijun

remove SelectedRows functors to selected_rows_functor.h

上级 690b0412
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor selected_rows)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor selected_rows)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace operators {
......@@ -134,97 +132,6 @@ void matmul<platform::CPUPlace, double>(
template struct SetConstant<platform::CPUPlace, float>;
template <typename T>
struct SelectedRowsAdd<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* output) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2.height());
output->set_height(in1_height);
auto& in1_rows = input1.rows();
auto& in2_rows = input2.rows();
std::vector<int64_t> out_rows;
out_rows.reserve(in1_rows.size() + in2_rows.size());
// concat rows
out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
output->set_rows(out_rows);
auto* out_value = output->mutable_value();
auto& in1_value = input1.value();
auto& in2_value = input2.value();
auto in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());
auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
auto in2_place = input2.place();
PADDLE_ENFORCE(platform::is_cpu_place(in2_place));
auto out_place = context.GetPlace();
PADDLE_ENFORCE(platform::is_cpu_place(out_place));
auto* out_data = out_value->data<T>();
auto* in1_data = in1_value.data<T>();
memory::Copy(boost::get<platform::CPUPlace>(out_place), out_data,
boost::get<platform::CPUPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T));
auto* in2_data = in2_value.data<T>();
memory::Copy(boost::get<platform::CPUPlace>(out_place),
out_data + in1_value.numel(),
boost::get<platform::CPUPlace>(in2_place), in2_data,
in2_value.numel() * sizeof(T));
}
};
template struct SelectedRowsAdd<platform::CPUPlace, float>;
template <typename T>
struct SelectedRowsAddTensor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output) {
auto in1_height = input1.height();
auto in2_dims = input2.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
SetConstant<platform::CPUPlace, T> functor;
functor(context, output, 0.0);
auto* in1_data = in1_value.data<T>();
auto* out_data = output->data<T>();
for (size_t i = 0; i < in1_rows.size(); i++) {
for (int64_t j = 0; j < in1_row_numel; j++) {
out_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
}
}
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
out_eigen.device(*context.GetEigenDevice<platform::CPUPlace>()) =
out_eigen + in2_eigen;
}
};
template struct SelectedRowsAddTensor<platform::CPUPlace, float>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
......@@ -156,123 +155,7 @@ void matmul<platform::GPUPlace, double>(
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
template <typename T>
struct SelectedRowsAdd<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* output) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2.height());
output->set_height(in1_height);
auto& in1_rows = input1.rows();
auto& in2_rows = input2.rows();
std::vector<int64_t> out_rows;
out_rows.reserve(in1_rows.size() + in2_rows.size());
// concat rows
out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
output->set_rows(out_rows);
auto* out_value = output->mutable_value();
auto& in1_value = input1.value();
auto& in2_value = input2.value();
auto in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());
auto* out_data = out_value->data<T>();
auto* in1_data = in1_value.data<T>();
auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
auto in2_place = input2.place();
PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
auto out_place = context.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(out_place));
memory::Copy(
boost::get<platform::GPUPlace>(out_place), out_data,
boost::get<platform::GPUPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
auto* in2_data = in2_value.data<T>();
memory::Copy(
boost::get<platform::GPUPlace>(out_place), out_data + in1_value.numel(),
boost::get<platform::GPUPlace>(in2_place), in2_data,
in2_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
}
};
template struct SelectedRowsAdd<platform::GPUPlace, float>;
namespace {
template <typename T>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
const int64_t* rows, T* tensor_out,
int64_t row_numel, int block_size) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we can not use
// tensor_out[index] += selected_rows[index]; Instead, we have to use
// AtomicAdd to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]);
}
}
} // namespace
template <typename T>
struct SelectedRowsAddTensor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output) {
auto in1_height = input1.height();
auto in2_dims = input2.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* in2_data = input2.data<T>();
auto* out_data = output->data<T>();
SetConstant<platform::GPUPlace, T> functor;
functor(context, output, 0.0);
int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in1_height);
SelectedRowsAddTensorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in1_data, in1_rows.data(), out_data,
in1_row_numel, block_size);
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
out_eigen.device(*context.GetEigenDevice<platform::GPUPlace>()) =
out_eigen + in2_eigen;
}
};
template struct SelectedRowsAddTensor<platform::GPUPlace, float>;
template struct SetConstant<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
......
......@@ -53,7 +53,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <cmath>
#include "paddle/framework/eigen.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
......@@ -96,23 +95,6 @@ struct SetConstant {
}
};
// SelectedRows + SelectedRows will simplely concat value and rows.
// The real computation happens in dealing with LoDTensor.
template <typename Place, typename T>
struct SelectedRowsAdd {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* output);
};
template <typename Place, typename T>
struct SelectedRowsAddTensor {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -89,93 +89,3 @@ TEST(math_function, zero) {
EXPECT_EQ(t[2], 1);
EXPECT_EQ(t[3], 1);
}
TEST(math_function, selected_rows_add) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators::math;
CPUPlace cpu_place;
CPUDeviceContext ctx(cpu_place);
SetConstant<CPUPlace, float> functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
functor(ctx, in2_value, 2.0);
std::unique_ptr<SelectedRows> output{new SelectedRows()};
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);
SelectedRowsAdd<CPUPlace, float> add_functor;
add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
auto* out_data = output->value().data<float>();
// input1 value
EXPECT_EQ(out_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
std::unique_ptr<Tensor> tensor1{new Tensor()};
tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
SetConstant<CPUPlace, float> constant_functor;
constant_functor(ctx, tensor1.get(), 3.0);
std::unique_ptr<Tensor> tensor2{new Tensor()};
tensor2->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
SelectedRowsAddTensor<CPUPlace, float> add_tensor_functor;
add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
auto* tensor2_data = tensor2->data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor2_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor2_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor2_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor2_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0);
}
......@@ -177,102 +177,3 @@ TEST(math_function, gemm_trans_cublas) {
EXPECT_EQ(input3_ptr[7], 99);
delete gpu_place;
}
TEST(math_function, selected_rows_add) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators::math;
GPUPlace gpu_place(0);
CPUPlace cpu_place;
CUDADeviceContext ctx(gpu_place);
SetConstant<GPUPlace, float> functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), gpu_place);
functor(ctx, in2_value, 2.0);
std::unique_ptr<SelectedRows> output{new SelectedRows()};
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(make_ddim({7, 10}), gpu_place);
SelectedRowsAdd<GPUPlace, float> add_functor;
add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu;
out_cpu.CopyFrom<float>(*out_value, cpu_place, ctx);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>();
// input1 value
EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0);
std::unique_ptr<Tensor> tensor1{new Tensor()};
tensor1->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
SetConstant<GPUPlace, float> constant_functor;
constant_functor(ctx, tensor1.get(), 3.0);
std::unique_ptr<Tensor> tensor2{new Tensor()};
tensor2->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
SelectedRowsAddTensor<GPUPlace, float> add_tensor_functor;
add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
Tensor tensor2_cpu;
tensor2_cpu.CopyFrom<float>(*tensor2, cpu_place, ctx);
ctx.Wait();
auto* tensor2_cpu_data = tensor2_cpu.data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor2_cpu_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor2_cpu_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct SelectedRowsAdd<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* output) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2.height());
output->set_height(in1_height);
auto& in1_rows = input1.rows();
auto& in2_rows = input2.rows();
std::vector<int64_t> out_rows;
out_rows.reserve(in1_rows.size() + in2_rows.size());
// concat rows
out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
output->set_rows(out_rows);
auto* out_value = output->mutable_value();
auto& in1_value = input1.value();
auto& in2_value = input2.value();
auto in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());
auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
auto in2_place = input2.place();
PADDLE_ENFORCE(platform::is_cpu_place(in2_place));
auto out_place = context.GetPlace();
PADDLE_ENFORCE(platform::is_cpu_place(out_place));
auto* out_data = out_value->data<T>();
auto* in1_data = in1_value.data<T>();
memory::Copy(boost::get<platform::CPUPlace>(out_place), out_data,
boost::get<platform::CPUPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T));
auto* in2_data = in2_value.data<T>();
memory::Copy(boost::get<platform::CPUPlace>(out_place),
out_data + in1_value.numel(),
boost::get<platform::CPUPlace>(in2_place), in2_data,
in2_value.numel() * sizeof(T));
}
};
template struct SelectedRowsAdd<platform::CPUPlace, float>;
template <typename T>
struct SelectedRowsAddTensor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output) {
auto in1_height = input1.height();
auto in2_dims = input2.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
SetConstant<platform::CPUPlace, T> functor;
functor(context, output, 0.0);
auto* in1_data = in1_value.data<T>();
auto* out_data = output->data<T>();
for (size_t i = 0; i < in1_rows.size(); i++) {
for (int64_t j = 0; j < in1_row_numel; j++) {
out_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
}
}
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
out_eigen.device(*context.GetEigenDevice<platform::CPUPlace>()) =
out_eigen + in2_eigen;
}
};
template struct SelectedRowsAddTensor<platform::CPUPlace, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct SelectedRowsAdd<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* output) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2.height());
output->set_height(in1_height);
auto& in1_rows = input1.rows();
auto& in2_rows = input2.rows();
std::vector<int64_t> out_rows;
out_rows.reserve(in1_rows.size() + in2_rows.size());
// concat rows
out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
output->set_rows(out_rows);
auto* out_value = output->mutable_value();
auto& in1_value = input1.value();
auto& in2_value = input2.value();
auto in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());
auto* out_data = out_value->data<T>();
auto* in1_data = in1_value.data<T>();
auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
auto in2_place = input2.place();
PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
auto out_place = context.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(out_place));
memory::Copy(
boost::get<platform::GPUPlace>(out_place), out_data,
boost::get<platform::GPUPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
auto* in2_data = in2_value.data<T>();
memory::Copy(
boost::get<platform::GPUPlace>(out_place), out_data + in1_value.numel(),
boost::get<platform::GPUPlace>(in2_place), in2_data,
in2_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
}
};
template struct SelectedRowsAdd<platform::GPUPlace, float>;
namespace {
template <typename T>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
const int64_t* rows, T* tensor_out,
int64_t row_numel, int block_size) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we can not use
// tensor_out[index] += selected_rows[index]; Instead, we have to use
// AtomicAdd to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]);
}
}
} // namespace
template <typename T>
struct SelectedRowsAddTensor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output) {
auto in1_height = input1.height();
auto in2_dims = input2.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* in2_data = input2.data<T>();
auto* out_data = output->data<T>();
SetConstant<platform::GPUPlace, T> functor;
functor(context, output, 0.0);
int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in1_height);
SelectedRowsAddTensorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in1_data, in1_rows.data(), out_data,
in1_row_numel, block_size);
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
out_eigen.device(*context.GetEigenDevice<platform::GPUPlace>()) =
out_eigen + in2_eigen;
}
};
template struct SelectedRowsAddTensor<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
// SelectedRows + SelectedRows will simplely concat value and rows.
// The real computation happens in dealing with LoDTensor.
template <typename Place, typename T>
struct SelectedRowsAdd {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* output);
};
template <typename Place, typename T>
struct SelectedRowsAddTensor {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h"
#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"
TEST(selected_rows_functor, cpu_add) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators::math;
CPUPlace cpu_place;
CPUDeviceContext ctx(cpu_place);
SetConstant<CPUPlace, float> functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
functor(ctx, in2_value, 2.0);
std::unique_ptr<SelectedRows> output{new SelectedRows()};
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);
SelectedRowsAdd<CPUPlace, float> add_functor;
add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
auto* out_data = output->value().data<float>();
// input1 value
EXPECT_EQ(out_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
std::unique_ptr<Tensor> tensor1{new Tensor()};
tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
functor(ctx, tensor1.get(), 3.0);
std::unique_ptr<Tensor> tensor2{new Tensor()};
tensor2->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
SelectedRowsAddTensor<CPUPlace, float> add_tensor_functor;
add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
auto* tensor2_data = tensor2->data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor2_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor2_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor2_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor2_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
TEST(selected_rows_functor, gpu_add) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators::math;
GPUPlace gpu_place(0);
CPUPlace cpu_place;
CUDADeviceContext ctx(gpu_place);
SetConstant<GPUPlace, float> functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), gpu_place);
functor(ctx, in2_value, 2.0);
std::unique_ptr<SelectedRows> output{new SelectedRows()};
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(make_ddim({7, 10}), gpu_place);
SelectedRowsAdd<GPUPlace, float> add_functor;
add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu;
out_cpu.CopyFrom<float>(*out_value, cpu_place, ctx);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>();
// input1 value
EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0);
std::unique_ptr<Tensor> tensor1{new Tensor()};
tensor1->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
functor(ctx, tensor1.get(), 3.0);
std::unique_ptr<Tensor> tensor2{new Tensor()};
tensor2->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
SelectedRowsAddTensor<GPUPlace, float> add_tensor_functor;
add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
Tensor tensor2_cpu;
tensor2_cpu.CopyFrom<float>(*tensor2, cpu_place, ctx);
ctx.Wait();
auto* tensor2_cpu_data = tensor2_cpu.data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor2_cpu_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor2_cpu_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册