提交 00e596ed 编写于 作者: C chengduoZH

get max threads of GPU

上级 60e7ee06
......@@ -32,6 +32,7 @@ class ConcatKernel : public framework::OpKernel<T> {
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j];
......@@ -49,17 +50,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
size_t input_offset = 0;
auto in_stride = framework::stride_numel(in->dims());
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs[j] = *outs[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor;
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
}
};
......
......@@ -25,16 +25,12 @@ template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int num = input.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto out_dim = output->dims();
// get the matrix size
int rows = 1;
......@@ -42,40 +38,72 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
bool sameShape = true;
// reshape to matrix
// get input's cols
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
if (sameShape) {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
input[i].Resize({rows, t_cols});
input_cols[i] = t_cols;
}
output->Resize({out_rows, out_cols});
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < rows; ++k) {
// offset k * out_cols
for (int k = 0; k < out_rows; ++k) {
T* dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input[j].dims()[1];
int col_len = input_cols[j];
const T* src_prt = input[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int num = outputs.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// recover origin dim
// for (int j = 0; j < num; ++j) {
// input[j]->Resize(origin_dim[j]);
// }
output->Resize(out_dim);
// get the matrix size
int input_rows = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}
int input_cols = 0;
// get outputs' cols
std::vector<int64_t> output_cols(outputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = outputs[i].numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = output_cols[j];
T* dst_ptr = outputs[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
......@@ -84,6 +112,11 @@ template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -22,7 +22,7 @@ namespace math {
// TODO(zcd): This can be replaced by tensor,
// if that, maybe we should add int8 to VarType::Type.
// Or replaced by tensorArray.
static constexpr int MaxSize = 32;
static constexpr int MaxSize = 8;
template <typename T>
struct CUDADeviceArray {
T data[MaxSize];
......@@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
int curr_offset = input_cols.data[segment];
......@@ -69,13 +68,73 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
const T* input_ptr = inputs.data[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
template <typename T>
__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
const int input_col, const int output_rows,
const int output_cols, T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
float inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
const T* input_ptr = inputs.data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * input_col + in_offset];
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col,
CUDADeviceArray<int> output_cols,
CUDADeviceArray<T*> outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1;
int curr_offset = output_cols.data[segment];
int curr_segment = segment;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs.data[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input[tid_y * input_col + tid_x];
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols,
CUDADeviceArray<T*> outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
float inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* output_ptr = outputs.data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] =
input[tid_y * input_col + tid_x];
}
}
/*
* All tensors' dimension should be the same.
*/
......@@ -83,17 +142,13 @@ template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int num = input.size();
// std::vector<paddle::framework::DDim> origin_dim(num);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto out_dim = output->dims();
PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
MaxSize);
// get the matrix size
int rows = 1;
auto dim_0 = input[0].dims();
......@@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
input[i].Resize({rows, t_cols});
inputs_cols.data[i + 1] = out_cols;
inputs_data.data[i] = input[i].data<T>();
}
output->Resize({out_rows, out_cols});
// computation
const int kThreadsPerBlock = 256;
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = std::min(out_cols, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
dim3 block_size = dim3(block_cols, block_rows, 1);
int grid_cols = (out_cols + block_cols - 1) / block_cols;
int grid_rows = (out_rows + block_rows - 1) / block_rows;
int dev_id = paddle::platform::GetCurrentDeviceId();
int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
int max_threads_per_mp =
paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
int max_threads = multi_process * max_threads_per_mp;
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
if (sameShape) {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
inputs_data, cols, out_rows, out_cols, output->data<T>());
} else {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
}
}
};
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int num = outputs.size();
PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
MaxSize);
// get the matrix size
int input_row = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_row *= dim_0[i];
}
int output_col_0 = outputs[0].numel() / input_row;
int input_col = 0;
bool sameShape = true;
CUDADeviceArray<T*> outputs_data;
CUDADeviceArray<int> outputs_cols;
outputs_data.size = num;
outputs_cols.size = num + 1;
outputs_cols.data[0] = 0;
// recover origin dim
// for (int j = 0; j < num; ++j) {
// input[j].Resize(origin_dim[j]);
// }
output->Resize(out_dim);
for (int i = 0; i < num; ++i) {
int t_col = outputs[i].numel() / input_row;
if (sameShape) {
if (t_col != output_col_0) sameShape = false;
}
input_col += t_col;
outputs_cols.data[i + 1] = input_col;
outputs_data.data[i] = outputs[i].data<T>();
}
// computation
const int kThreadsPerBlock = 256;
int block_cols = std::min(input_col, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
dim3 block_size = dim3(block_cols, block_rows, 1);
int grid_cols = (input_col + block_cols - 1) / block_cols;
int grid_rows = (input_row + block_rows - 1) / block_rows;
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, output_col_0, outputs_data);
} else {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
}
}
};
......@@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -20,18 +20,23 @@ namespace operators {
namespace math {
/*
* the tensor's shape of input will be changed,
* so the second parameter is not const.
*
*/
template <typename DeviceContext, typename T>
class ConcatFunctor {
public:
void operator()(const DeviceContext& context,
std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output);
};
template <typename DeviceContext, typename T>
class ConcatGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const int axis, std::vector<framework::Tensor>& outputs);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -251,6 +251,80 @@ void testConcat() {
}
}
}
/**
* cast4:
* inputs:
* axis = 1
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 6, 4]
*/
dim_a = make_ddim({2, 3, 4});
dim_b = make_ddim({2, 3, 4});
dim_out = make_ddim({2, 6, 4});
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
}
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
input.clear();
input.push_back(input_a);
input.push_back(input_b);
concat_functor(*context, input, 1, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
// check the data
cols = 12;
idx_a = 0, idx_b = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 24; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]);
++idx_a;
}
}
}
}
TEST(math, concat) {
......
......@@ -33,6 +33,26 @@ int GetCUDADeviceCount() {
return count;
}
int GetCUDAMultiProcessors(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
int count;
PADDLE_ENFORCE(
cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMultiProcessors");
return count;
}
int GetCUDAMaxThreadsPerMultiProcessor(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
int count;
PADDLE_ENFORCE(cudaDeviceGetAttribute(
&count, cudaDevAttrMaxThreadsPerMultiProcessor, id),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMaxThreadsPerMultiProcessor");
return count;
}
int GetCurrentDeviceId() {
int device_id;
PADDLE_ENFORCE(
......
......@@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse =
//! Get the total number of GPU devices in system.
int GetCUDADeviceCount();
//! Get the MultiProcessors of the ith GPU.
int GetCUDAMultiProcessors(int i);
//! Get the MaxThreads of each MultiProcessor of the ith GPU.
int GetCUDAMaxThreadsPerMultiProcessor(int i);
//! Get the current GPU device id in system.
int GetCurrentDeviceId();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册