Unverified commit 384372f5, authored by wuhuachaocoding, committed by GitHub

fix concat bug (#34319) (#34396)

Parent: 862e81ef
@@ -40,18 +40,18 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
                   const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
-    int num = input.size();
-    int rows = 1;
+    size_t num = input.size();
+    int64_t rows = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
-    int out_rows = rows, out_cols = 0;
+    int64_t out_rows = rows, out_cols = 0;
     std::vector<int64_t> input_cols(input.size());
-    for (int i = 0; i < num; ++i) {
-      int t_cols = input[i].numel() / rows;
+    for (size_t i = 0; i < num; ++i) {
+      int64_t t_cols = input[i].numel() / rows;
       out_cols += t_cols;
       input_cols[i] = t_cols;
     }
@@ -59,11 +59,11 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
     // computation
     auto output_data = output->data<T>();
-    int col_idx = 0;
-    for (int j = 0; j < num; ++j) {
-      int col_len = input_cols[j];
+    int64_t col_idx = 0;
+    for (size_t j = 0; j < num; ++j) {
+      int64_t col_len = input_cols[j];
       auto input_data = input[j].data<T>();
-      for (int k = 0; k < out_rows; ++k) {
+      for (int64_t k = 0; k < out_rows; ++k) {
         memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
                      input_data + k * col_len, sizeof(T) * col_len);
       }
...
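The CPU hunk above widens the row/column bookkeeping from int to int64_t (and the loop counters over the inputs to size_t). The motivation is presumably the offset arithmetic output_data + k * out_cols + col_idx: for large enough tensors the element offset no longer fits in a 32-bit int. The standalone sketch below, with made-up sizes that are not taken from the commit, shows how quickly that offset outgrows INT32_MAX:

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    int main() {
      // Hypothetical sizes, chosen only for illustration.
      const int64_t out_rows = 110000;  // product of the dims before the concat axis
      const int64_t out_cols = 40000;   // total columns of the concatenated output
      const int64_t col_idx = 20000;    // column offset of one of the inputs
      const int64_t k = out_rows - 1;   // last output row

      const int64_t offset = k * out_cols + col_idx;  // element offset into the output
      std::printf("offset = %lld, INT32_MAX = %d\n",
                  static_cast<long long>(offset),
                  std::numeric_limits<int32_t>::max());
      // offset is about 4.4e9, so the same expression evaluated with 32-bit ints
      // would overflow and memory::Copy would be handed a wrong address.
      return 0;
    }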
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
+#include "gflags/gflags.h"
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
@@ -25,9 +26,9 @@ namespace operators {
 namespace math {
 template <typename T>
-__global__ void ConcatKernel(const T** inputs, const int* input_cols,
-                             int col_size, const int output_rows,
-                             const int output_cols, T* output) {
+__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols,
+                             int col_size, const int64_t output_rows,
+                             const int64_t output_cols, T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   int curr_segment = 0;
   int curr_offset = input_cols[0];
@@ -69,8 +70,8 @@ __device__ void ConcatKernelDetail(const T** inputs_data,
 template <typename T>
 __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
-                             const int fixed_in_col, const int out_rows,
-                             const int out_cols, T* output_data) {
+                             const int64_t fixed_in_col, const int64_t out_rows,
+                             const int64_t out_cols, T* output_data) {
   const T* inputs_data[2];
   inputs_data[0] = input_addr0;
   inputs_data[1] = input_addr1;
@@ -80,8 +81,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
 template <typename T>
 __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
-                             const T* input_addr2, const int fixed_in_col,
-                             const int out_rows, const int out_cols,
+                             const T* input_addr2, const int64_t fixed_in_col,
+                             const int64_t out_rows, const int64_t out_cols,
                              T* output_data) {
   const T* inputs_data[3];
   inputs_data[0] = input_addr0;
@@ -94,8 +95,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
 template <typename T>
 __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                              const T* input_addr2, const T* input_addr3,
-                             const int fixed_in_col, const int out_rows,
-                             const int out_cols, T* output_data) {
+                             const int64_t fixed_in_col, const int64_t out_rows,
+                             const int64_t out_cols, T* output_data) {
   const T* inputs_data[4];
   inputs_data[0] = input_addr0;
   inputs_data[1] = input_addr1;
@@ -107,8 +108,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
 template <typename T>
 __global__ void ConcatKernel(const T** inputs_data, const int in_num,
-                             const int fixed_in_col, const int out_rows,
-                             const int out_cols, T* output_data) {
+                             const int64_t fixed_in_col, const int64_t out_rows,
+                             const int64_t out_cols, T* output_data) {
   ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                         output_data);
 }
@@ -234,21 +235,41 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
     int in_num = input.size();
-    int in_row = 1;
+    int64_t in_row = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
       in_row *= dim_0[i];
     }
-    int in_col = input[0].numel() / in_row;
-    int out_row = in_row, out_col = 0;
-    std::vector<const T*> inputs_data(in_num);
-    std::vector<int> inputs_col(in_num + 1);
+    int64_t in_col = input[0].numel() / in_row;
+    int64_t out_row = in_row, out_col = 0;
+    int inputs_col_num = in_num + 1;
+    std::vector<const T*> inputs_data_vec(in_num);
+    std::vector<int64_t> inputs_col_vec(inputs_col_num);
+    const T** inputs_data = inputs_data_vec.data();
+    int64_t* inputs_col = inputs_col_vec.data();
+    // There are some differences between the HIP runtime and the NV runtime.
+    // On NV, a host-to-device copy of pageable memory smaller than 64 KB is
+    // automatically asynchronous, whereas HIP can only copy asynchronously
+    // from pinned memory. See
+    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+    // 3.2.6.1. Concurrent Execution between Host and Device:
+    // "Memory copies from host to device of a memory block of 64 KB or less"
+#ifdef PADDLE_WITH_HIP
+    memory::AllocationPtr data_alloc, col_alloc;
+    data_alloc =
+        memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*));
+    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
+    col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
+                              inputs_col_num * sizeof(int64_t));
+    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
+#endif
     inputs_col[0] = 0;
     bool has_same_shape = true;
     for (int i = 0; i < in_num; ++i) {
-      int t_cols = input[i].numel() / in_row;
+      int64_t t_cols = input[i].numel() / in_row;
       if (has_same_shape) {
         if (t_cols != in_col) has_same_shape = false;
       }
@@ -264,12 +285,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     memory::allocation::AllocationPtr tmp_dev_ins_data;
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
-      tmp_dev_ins_data =
-          memory::Alloc(context, inputs_data.size() * sizeof(T*));
+      tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_data.data()),
-                   inputs_data.size() * sizeof(T*), context.stream());
+                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
+                   context.stream());
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
@@ -292,17 +312,31 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          memory::Alloc(context, inputs_col.size() * sizeof(int));
+          memory::Alloc(context, inputs_col_num * sizeof(int64_t));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_col.data()),
-                   inputs_col.size() * sizeof(int), context.stream());
-      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
+                   static_cast<void*>(inputs_col),
+                   inputs_col_num * sizeof(int64_t), context.stream());
+      int64_t* dev_ins_col_data =
+          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
       ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
-          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
+          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
           out_row, out_col, output->data<T>());
     }
+#ifdef PADDLE_WITH_HIP
+    // Prevent the pinned memory from being overwritten and release it only
+    // after the kernel launched on this stream has finished (pinned memory is
+    // allocated again on the next call).
+    auto* data_alloc_released = data_alloc.release();
+    auto* col_alloc_released = col_alloc.release();
+    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
+      memory::allocation::AllocationDeleter deleter;
+      deleter(data_alloc_released);
+      deleter(col_alloc_released);
+    });
+#endif
   }
 };
@@ -313,6 +347,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 class SplitFunctor<platform::CUDADeviceContext, T> {
  public:
+  SplitFunctor();
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
@@ -329,8 +364,27 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     int64_t in_col = 0, in_row = out_row;
     bool has_same_shape = true;
-    std::vector<T*> outputs_data(o_num);
-    std::vector<int64_t> outputs_cols(o_num + 1);
+    int outputs_cols_num = o_num + 1;
+    std::vector<T*> outputs_data_vec(o_num);
+    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
+    T** outputs_data = outputs_data_vec.data();
+    int64_t* outputs_cols = outputs_cols_vec.data();
+    // There are some differences between the HIP runtime and the NV runtime.
+    // On NV, a host-to-device copy of pageable memory smaller than 64 KB is
+    // automatically asynchronous, whereas HIP can only copy asynchronously
+    // from pinned memory. See
+    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+    // 3.2.6.1. Concurrent Execution between Host and Device:
+    // "Memory copies from host to device of a memory block of 64 KB or less"
+#ifdef PADDLE_WITH_HIP
+    memory::AllocationPtr data_alloc, cols_alloc;
+    data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*));
+    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
+    cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
+                               outputs_cols_num * sizeof(int64_t));
+    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
+#endif
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
@@ -354,12 +408,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     memory::allocation::AllocationPtr tmp_dev_outs_data;
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
-      tmp_dev_outs_data =
-          memory::Alloc(context, outputs_data.size() * sizeof(T*));
+      tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_data.data()),
-                   outputs_data.size() * sizeof(T*), context.stream());
+                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
+                   context.stream());
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
@@ -382,20 +435,30 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          memory::Alloc(context,
-                        outputs_cols.size() * sizeof(int64_t));
+          memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_cols.data()),
-                   outputs_cols.size() * sizeof(int64_t), context.stream());
+                   reinterpret_cast<void*>(outputs_cols),
+                   outputs_cols_num * sizeof(int64_t), context.stream());
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
       SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
-          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
+          static_cast<int>(outputs_cols_num), dev_out_gpu_data);
     }
+#ifdef PADDLE_WITH_HIP
+    // Prevent the pinned memory from being overwritten and release it only
+    // after the kernel launched on this stream has finished (pinned memory is
+    // allocated again on the next call).
+    auto* data_alloc_released = data_alloc.release();
+    auto* cols_alloc_released = cols_alloc.release();
+    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
+      memory::allocation::AllocationDeleter deleter;
+      deleter(data_alloc_released);
+      deleter(cols_alloc_released);
+    });
+#endif
   }
 };
...
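The ROCm-specific branches above stage inputs_data / inputs_col (and their split counterparts) in CUDA-pinned memory and release those allocations from a stream callback because, as the added comments note, HIP only overlaps host-to-device copies with the host when the source is pinned. Below is a minimal standalone sketch of the same lifetime rule, written against the plain CUDA runtime API rather than Paddle's memory::Alloc / AddStreamCallback wrappers; the sizes and names are illustrative only:

    #include <cuda_runtime.h>
    #include <cstdint>

    int main() {
      const int n = 4;  // e.g. number of per-input column offsets to upload
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      // Pinned (page-locked) host memory: unlike pageable memory, it can be the
      // source of a host-to-device copy that is asynchronous w.r.t. the host.
      int64_t* h_cols = nullptr;
      cudaMallocHost(reinterpret_cast<void**>(&h_cols), n * sizeof(int64_t));
      for (int i = 0; i < n; ++i) h_cols[i] = (i + 1) * 128;  // dummy offsets

      int64_t* d_cols = nullptr;
      cudaMalloc(reinterpret_cast<void**>(&d_cols), n * sizeof(int64_t));
      cudaMemcpyAsync(d_cols, h_cols, n * sizeof(int64_t),
                      cudaMemcpyHostToDevice, stream);

      // ... a kernel that reads d_cols would be launched on `stream` here ...

      // The copy may still be in flight, so h_cols must not be freed or reused
      // yet. The diff defers the free via AddStreamCallback; a plain program
      // can simply wait on the stream before releasing the pinned buffer.
      cudaStreamSynchronize(stream);
      cudaFreeHost(h_cols);
      cudaFree(d_cols);
      cudaStreamDestroy(stream);
      return 0;
    }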