未验证 提交 34605d26 编写于 作者: D dzhwinter 提交者: GitHub

accelerate the cuda concat op, avoid many times copy (#8585)

* "try enhance concat op"

* "enhance the concat operator"
上级 087d8e7f
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
......@@ -34,12 +35,46 @@ class ConcatKernel : public framework::OpKernel<T> {
auto out_stride = framework::stride_numel(out->dims());
size_t output_offset = 0;
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
// If axis >=1, copy to out immediately need to call many times
// of cuda memcpy. Copy the input to cpu and do the stride copy,
// then copy to gpu output.
if (platform::is_gpu_place(place) && axis >= 1) {
platform::CPUPlace copy_place;
auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
framework::Tensor cpu_out;
cpu_out.Resize(out->dims());
cpu_out.mutable_data<T>(copy_place);
auto& dev_ctx = ctx.device_context();
std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
for (auto* in : ins) {
std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
cpu_ins.emplace_back(std::move(cpu_in));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx.Wait();
for (auto& in : cpu_ins) {
auto& cpu_in = *in.get();
auto in_stride = framework::stride_numel(cpu_in.dims());
StridedNumelCopyWithAxis<T>(
cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
cpu_in.data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
framework::TensorCopy(cpu_out, place, dev_ctx, out);
} else {
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册