未验证 提交 6b84688b 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the cuda implementation of sum_op (#17283)

* Optimize the CUDA implementation of sum_op, which adds two lod_tensors in place.
test=develop

* Use eigen to add to tensors.
test=develop
上级 db5e74ab
......@@ -87,7 +87,7 @@ __global__ void SumAlign4CUDAKernel(const T *in_0, const T *in_1, T *out,
}
template <class T>
void FuseLodTensorSumCompute(const framework::ExecutionContext &context) {
void SumToLoDTensor(const framework::ExecutionContext &context) {
auto in_vars = context.MultiInputVar("X");
const size_t in_num = in_vars.size();
......@@ -114,30 +114,29 @@ void FuseLodTensorSumCompute(const framework::ExecutionContext &context) {
};
auto *out = context.Output<LoDTensor>("Out");
auto out_var = context.OutputVar("Out");
bool in_place = in_vars[0] == out_var;
bool in_place = in_vars[0] == context.OutputVar("Out");
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
}
int start = in_place ? 1 : 0;
if (!in_place) {
// seperate path for a+b,maybe not fast than eigen
if (in_num == 2 && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[1]->IsType<framework::LoDTensor>()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
auto length = in_0.numel();
if (length) {
ComputeKernelParameter(length);
Sum2CUDAKernel<T><<<grids, blocks, 0, stream>>>(
in_0.data<T>(), in_1.data<T>(), out->data<T>(), length);
}
return;
// Sum of two tensors
if (in_num == 2 && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[1]->IsType<framework::LoDTensor>()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
auto length = in_0.numel();
if (length) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
}
return;
}
int start = in_place ? 1 : 0;
if (!in_place) {
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(
......@@ -228,13 +227,10 @@ class SumKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto in_vars = context.MultiInputVar("X");
const size_t in_num = in_vars.size();
auto out_var = context.OutputVar("Out");
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
FuseLodTensorSumCompute<T>(context);
SumToLoDTensor<T>(context);
} else if (out_var->IsType<framework::SelectedRows>()) {
SelectedRowsCompute<platform::CUDADeviceContext, T>(context);
} else if (out_var->IsType<framework::LoDTensorArray>()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册