提交 d48a0e4e 编写于 作者: T typhoonzero

WIP: adding generic scatter functors

上级 dd21ae6c
...@@ -179,6 +179,53 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>; ...@@ -179,6 +179,53 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>; template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>; template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
// Another group of functors is called "scatter updates", which means
// use SelectedRows to update a dense tensor with different Ops, like
// add or mul.
namespace scatter {
// Returns the index of `value` inside `rows`, or rows.size() when absent.
//
// Precondition: `rows` is sorted ascending. This holds for every caller in
// this file -- MergeAdd builds its row list from a std::set -- and lets us
// binary-search in O(log n) instead of the previous linear std::find scan.
// For a sorted vector the result is identical to std::find: the first
// occurrence when present, rows.size() when not.
size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
  auto it = std::lower_bound(rows.begin(), rows.end(), value);
  if (it == rows.end() || *it != value) {
    return rows.size();
  }
  return static_cast<size_t>(it - rows.begin());
}
// CPU implementation of MergeAdd: deduplicate the rows of `input` and sum
// the value rows that shared a row index.
//
// `out` receives the sorted, unique row list (std::set ordering), the same
// height as `input`, and a freshly allocated [num_unique_rows, width] value
// tensor holding the per-row sums.
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::SelectedRows& input,
                  framework::SelectedRows* out) {
    // Const reference: the original code copied the whole row vector here.
    const auto& input_rows = input.rows();
    // std::set both deduplicates and sorts the row indices, so merge_rows
    // is sorted ascending -- FindPos below relies on the presence of every
    // input row in this list.
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
    auto input_width = input.value().dims()[1];
    out->set_rows(merge_rows);
    out->set_height(input.height());
    out->mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());
    // Zero-fill so the += accumulation below starts from 0. Cast the fill
    // value to T instead of passing a double literal to an integer
    // instantiation.
    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
    constant_functor(context, out->mutable_value(), static_cast<T>(0));
    auto* out_data = out->mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();
    for (size_t i = 0; i < input_rows.size(); i++) {
      // Every input row is guaranteed to be in merge_rows (built from
      // input_rows above), so FindPos always yields a valid index.
      size_t out_i = FindPos(merge_rows, input_rows[i]);
      for (int64_t j = 0; j < input_width; j++) {
        out_data[out_i * input_width + j] += input_data[i * input_width + j];
      }
    }
  }
};
} // namespace scatter
}  // namespace math
}  // namespace operators
}  // namespace paddle
...@@ -222,6 +222,73 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>; ...@@ -222,6 +222,73 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>; template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>; template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>; template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
namespace scatter {
// Accumulates each row of `input` into the matching row of `out`.
//
// Launch layout expected by the host-side MergeAdd functor in this file:
//   gridDim.y  == number of input rows (blockIdx.y selects one input row)
//   blockDim.x == block_size (threads cooperatively stride over the row's
//                 row_numel columns)
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
  const int ty = blockIdx.y;  // index of the input row this block handles
  int tid = threadIdx.x;
  // Position of input_rows[ty] inside out_rows, located once by thread 0 and
  // broadcast to the whole block through shared memory.
  __shared__ size_t out_idx;
  if (tid == 0) {
    // Linear scan over the merged row list.
    // NOTE(review): out_idx is left uninitialized if input_rows[ty] does not
    // appear in out_rows -- the host functor guarantees it does (out_rows is
    // built from input_rows), but confirm before reusing this kernel.
    for (size_t i = 0; i < out_rows_size; i++) {
      if (input_rows[ty] == out_rows[i]) {
        out_idx = i;
      }
    }
  }
  __syncthreads();  // all threads must see out_idx before using it
  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int index = tid; index < row_numel; index += block_size) {
    // Atomic add: duplicated input rows run in different blocks but target
    // the same output row concurrently.
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}
// GPU implementation of MergeAdd: deduplicate the rows of `input` and sum
// the value rows that shared a row index.
//
// Fix: the specialization was declared on platform::GPUDeviceContext while
// the rest of this file (and the reinterpret_cast the body itself performed)
// uses platform::CUDADeviceContext -- the type is now consistent and the
// cast is unnecessary.
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input,
                  framework::SelectedRows* out) {
    // Const reference: the original code copied the whole row vector here.
    const auto& input_rows = input.rows();
    // std::set deduplicates and sorts the row indices; every input row is
    // therefore present in merge_rows, as MergeAddKernel requires.
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
    auto input_width = input.value().dims()[1];
    out->set_rows(merge_rows);
    out->set_height(input.height());
    out->mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());
    // Zero-fill so the kernel's atomic adds accumulate from 0; cast the fill
    // value to T instead of passing a double literal.
    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out->mutable_value(), static_cast<T>(0));
    auto* out_data = out->mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();
    // One block per input row; block_size threads stride over the columns.
    // Pass block_size as the template argument instead of a second hard-coded
    // 256 so the stride always matches blockDim.x.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(1, input_rows.size());
    // NOTE(review): rows().data() must yield device-visible memory for the
    // kernel -- confirm SelectedRows' row vector is CUDA-accessible here.
    MergeAddKernel<T, block_size><<<grid1, threads, 0, context.stream()>>>(
        input_data, input.rows().data(), out_data, out->rows().data(),
        out->rows().size(), input_width);
  }
};
} // namespace scatter
}  // namespace math
}  // namespace operators
}  // namespace paddle
...@@ -52,6 +52,53 @@ struct SelectedRowsAddToTensor { ...@@ -52,6 +52,53 @@ struct SelectedRowsAddToTensor {
framework::Tensor* input2); framework::Tensor* input2);
}; };
namespace scatter {
// Functors for manipulating SelectedRows data ("scatter updates"): use
// SelectedRows to update data with different element-wise operations.
template <typename DeviceContext, typename T>
struct MergeAdd {
  // Unary functor: merge duplicated rows in the input SelectedRows object by
  // adding their value rows together; `out` receives the deduplicated row
  // list and the summed values. Implemented per device in the corresponding
  // .cc/.cu file.
  void operator()(const DeviceContext& context,
                  const framework::SelectedRows& input,
                  framework::SelectedRows* out);
};
// Element-wise addition of two SelectedRows: out = input1 + input2 on the
// value tensors; out takes input1's rows and height.
//
// Fix: input1/input2 are references, not pointers -- the original body used
// `->` on them (`input1->rows()` etc.), which does not compile. Member
// access now uses `.`.
//
// NOTE(review): this assumes input1 and input2 hold identical row lists and
// value shapes (only input1's metadata is copied to out) -- confirm with
// callers.
template <typename DeviceContext, typename T>
struct Add {
  void operator()(const DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* out) {
    out->set_rows(input1.rows());
    out->set_height(input1.height());
    out->mutable_value()->mutable_data<T>(input1.value().dims(),
                                          context.GetPlace());
    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
    auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
    e_out.device(*context.eigen_device()) = e_in1 + e_in2;
  }
};
// Element-wise multiplication of two SelectedRows: out = input1 * input2 on
// the value tensors; out takes input1's rows and height.
//
// Fix: input1/input2 are references, not pointers -- the original body used
// `->` on them (`input1->rows()` etc.), which does not compile. Member
// access now uses `.`.
//
// NOTE(review): this assumes input1 and input2 hold identical row lists and
// value shapes (only input1's metadata is copied to out) -- confirm with
// callers.
template <typename DeviceContext, typename T>
struct Mul {
  void operator()(const DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* out) {
    out->set_rows(input1.rows());
    out->set_height(input1.height());
    out->mutable_value()->mutable_data<T>(input1.value().dims(),
                                          context.GetPlace());
    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
    auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
    e_out.device(*context.eigen_device()) = e_in1 * e_in2;
  }
};
} // namespace scatter
}  // namespace math
}  // namespace operators
}  // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册