未验证 提交 34605d26 编写于 作者: D dzhwinter 提交者: GitHub

accelerate the cuda concat op, avoid many times copy (#8585)

* "try enhance concat op"

* "enhance the concat operator"
上级 087d8e7f
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
...@@ -34,6 +35,39 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -34,6 +35,39 @@ class ConcatKernel : public framework::OpKernel<T> {
auto out_stride = framework::stride_numel(out->dims()); auto out_stride = framework::stride_numel(out->dims());
size_t output_offset = 0; size_t output_offset = 0;
// If axis >=1, copy to out immediately need to call many times
// of cuda memcpy. Copy the input to cpu and do the stride copy,
// then copy to gpu output.
if (platform::is_gpu_place(place) && axis >= 1) {
platform::CPUPlace copy_place;
auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
framework::Tensor cpu_out;
cpu_out.Resize(out->dims());
cpu_out.mutable_data<T>(copy_place);
auto& dev_ctx = ctx.device_context();
std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
for (auto* in : ins) {
std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
cpu_ins.emplace_back(std::move(cpu_in));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx.Wait();
for (auto& in : cpu_ins) {
auto& cpu_in = *in.get();
auto in_stride = framework::stride_numel(cpu_in.dims());
StridedNumelCopyWithAxis<T>(
cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
cpu_in.data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
framework::TensorCopy(cpu_out, place, dev_ctx, out);
} else {
for (auto* in : ins) { for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims()); auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
...@@ -42,6 +76,7 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -42,6 +76,7 @@ class ConcatKernel : public framework::OpKernel<T> {
output_offset += in_stride[axis]; output_offset += in_stride[axis];
} }
} }
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册