Unverified commit 6b84688b, authored by Yiqun Liu and committed by GitHub

Optimize the cuda implementation of sum_op (#17283)

* Optimize the CUDA implementation of sum_op, which adds two LoDTensors in place.
test=develop

* Use Eigen to add two tensors.
test=develop
Parent commit: db5e74ab
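The core of the change is in the third hunk below: the hand-written Sum2CUDAKernel launch is replaced by an Eigen expression, result.device(place) = in_0_e + in_1_e, so that Eigen generates the elementwise-add kernel. The following standalone sketch (plain Eigen, outside Paddle; the buffer and variable names are illustrative only) shows the same pattern on the default CPU device:

#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  float a_buf[4] = {1.f, 2.f, 3.f, 4.f};
  float b_buf[4] = {10.f, 20.f, 30.f, 40.f};
  float c_buf[4] = {0.f, 0.f, 0.f, 0.f};

  // TensorMap is a 1-D view over raw memory; it plays the role of
  // EigenVector<T>::Flatten in the patch.
  Eigen::TensorMap<Eigen::Tensor<float, 1>> a(a_buf, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> b(b_buf, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> c(c_buf, 4);

  // Same shape as the patch: assign through a device. Here it is the default
  // CPU device; in sum_op.cu the device is *dev_ctx.eigen_device(), i.e. the
  // GPU, and Eigen emits the CUDA kernel for the elementwise expression.
  Eigen::DefaultDevice device;
  c.device(device) = a + b;

  for (int i = 0; i < 4; ++i) std::cout << c_buf[i] << " ";  // 11 22 33 44
  std::cout << "\n";
  return 0;
}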
@@ -87,7 +87,7 @@ __global__ void SumAlign4CUDAKernel(const T *in_0, const T *in_1, T *out,
 }
 template <class T>
-void FuseLodTensorSumCompute(const framework::ExecutionContext &context) {
+void SumToLoDTensor(const framework::ExecutionContext &context) {
   auto in_vars = context.MultiInputVar("X");
   const size_t in_num = in_vars.size();
@@ -114,16 +114,12 @@ void FuseLodTensorSumCompute(const framework::ExecutionContext &context) {
   };
   auto *out = context.Output<LoDTensor>("Out");
-  bool in_place = in_vars[0] == context.OutputVar("Out");
+  auto out_var = context.OutputVar("Out");
+  bool in_place = in_vars[0] == out_var;
   if (!in_place) {
     out->mutable_data<T>(context.GetPlace());
   }
-  int start = in_place ? 1 : 0;
-  if (!in_place) {
-    // seperate path for a+b,maybe not fast than eigen
-    if (in_num == 2 && in_vars[0]->IsType<framework::LoDTensor>() &&
-        in_vars[1]->IsType<framework::LoDTensor>()) {
+  // Sum of two tensors
+  if (in_num == 2 && in_vars[0]->IsType<framework::LoDTensor>() &&
+      in_vars[1]->IsType<framework::LoDTensor>()) {
   auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
@@ -131,13 +127,16 @@ void FuseLodTensorSumCompute(const framework::ExecutionContext &context) {
     auto length = in_0.numel();
     if (length) {
-      ComputeKernelParameter(length);
-      Sum2CUDAKernel<T><<<grids, blocks, 0, stream>>>(
-          in_0.data<T>(), in_1.data<T>(), out->data<T>(), length);
+      auto result = EigenVector<T>::Flatten(*out);
+      auto &place = *dev_ctx.eigen_device();
+      auto in_0_e = EigenVector<T>::Flatten(in_0);
+      auto in_1_e = EigenVector<T>::Flatten(in_1);
+      result.device(place) = in_0_e + in_1_e;
     }
     return;
   }
-  }
+  int start = in_place ? 1 : 0;
   if (!in_place) {
     math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
     constant_functor(
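The general path that this hunk re-enables after the fast path works as follows: when Out does not alias the first input, it is zero-filled with math::SetConstant and accumulation starts at input 0; when it does alias (in_place), input 0 already lives in Out, so accumulation starts at index 1. A hypothetical scalar CPU sketch of that control flow (SumGeneralPath is not a Paddle function; names and types are illustrative only):

#include <cstddef>
#include <vector>

// Hypothetical scalar analogue of the general accumulation path in
// SumToLoDTensor; for illustration, not Paddle code.
void SumGeneralPath(const std::vector<const float *> &ins, float *out,
                    std::size_t length, bool in_place) {
  // When out aliases ins[0], its current contents are the first summand.
  std::size_t start = in_place ? 1 : 0;
  if (!in_place) {
    // Mirrors math::SetConstant(dev_ctx, out, 0): start from zeros.
    for (std::size_t i = 0; i < length; ++i) out[i] = 0.f;
  }
  for (std::size_t k = start; k < ins.size(); ++k) {
    for (std::size_t i = 0; i < length; ++i) out[i] += ins[k][i];
  }
}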
@@ -228,13 +227,10 @@ class SumKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto in_vars = context.MultiInputVar("X");
-    const size_t in_num = in_vars.size();
     auto out_var = context.OutputVar("Out");
-    bool in_place = out_var == in_vars[0];
     if (out_var->IsType<framework::LoDTensor>()) {
-      FuseLodTensorSumCompute<T>(context);
+      SumToLoDTensor<T>(context);
     } else if (out_var->IsType<framework::SelectedRows>()) {
       SelectedRowsCompute<platform::CUDADeviceContext, T>(context);
     } else if (out_var->IsType<framework::LoDTensorArray>()) {
...
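For contrast with the Eigen version, a removed-style kernel would look roughly like the grid-stride sketch below. This is a generic reconstruction, not the original Sum2CUDAKernel source, and TwoTensorAddKernel is a hypothetical name. The practical gain of the Eigen route is maintenance: the elementwise expression is evaluated in one pass by Eigen-generated code, and launch sizing (the old ComputeKernelParameter in the diff) no longer has to be tuned by hand.

#include <cstdint>

// Generic reconstruction of a hand-written two-tensor add; illustrative only.
template <class T>
__global__ void TwoTensorAddKernel(const T *in_0, const T *in_1, T *out,
                                   int64_t n) {
  // Grid-stride loop so any launch configuration covers all n elements.
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    out[i] = in_0[i] + in_1[i];
  }
}

// A launch in the style of the removed code would then be, e.g.:
//   TwoTensorAddKernel<float><<<grids, blocks, 0, stream>>>(a, b, c, n);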