提交 82bd82c1 编写于 作者: C chengduoZH

follow comments and refine code

上级 00e596ed
...@@ -33,6 +33,7 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -33,6 +33,7 @@ class ConcatKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
out->mutable_data<T>(place); out->mutable_data<T>(place);
// TODO(zcd): Sometimes direct copies will be faster
std::vector<framework::Tensor> inputs(ins.size()); std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) { for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j]; inputs[j] = *ins[j];
...@@ -51,6 +52,7 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -51,6 +52,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
// TODO(zcd): Sometimes direct copies will be faster
std::vector<framework::Tensor> outputs(outs.size()); std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) { for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace()); outs[j]->mutable_data<T>(ctx.GetPlace());
......
...@@ -19,7 +19,8 @@ namespace operators { ...@@ -19,7 +19,8 @@ namespace operators {
namespace math { namespace math {
/* /*
* All tensors' dimension should be the same. * All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/ */
template <typename T> template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> { class ConcatFunctor<platform::CPUDeviceContext, T> {
...@@ -27,12 +28,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { ...@@ -27,12 +28,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis, const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) { framework::Tensor* output) {
// assume the the max size of input is less than 8 and see the performance // TODO(zcd): Add input data validity checking
// save origin dim
int num = input.size(); int num = input.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// get the matrix size
int rows = 1; int rows = 1;
auto dim_0 = input[0].dims(); auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -40,7 +38,6 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { ...@@ -40,7 +38,6 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
} }
int out_rows = rows, out_cols = 0; int out_rows = rows, out_cols = 0;
// get input's cols
std::vector<int64_t> input_cols(input.size()); std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows; int t_cols = input[i].numel() / rows;
...@@ -64,18 +61,19 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { ...@@ -64,18 +61,19 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
} }
}; };
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T> template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> { class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const int axis, const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) { std::vector<framework::Tensor>& outputs) {
// assume the the max size of input is less than 8 and see the performance // TODO(zcd): Add input data validity checking
// save origin dim
int num = outputs.size(); int num = outputs.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// get the matrix size
int input_rows = 1; int input_rows = 1;
auto dim_0 = outputs[0].dims(); auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -83,7 +81,6 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -83,7 +81,6 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
} }
int input_cols = 0; int input_cols = 0;
// get outputs' cols
std::vector<int64_t> output_cols(outputs.size()); std::vector<int64_t> output_cols(outputs.size());
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
int t_cols = outputs[i].numel() / input_rows; int t_cols = outputs[i].numel() / input_rows;
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_helper.h"
...@@ -19,16 +20,6 @@ namespace paddle { ...@@ -19,16 +20,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
// TODO(zcd): This can be replaced by tensor,
// if that, maybe we should add int8 to VarType::Type.
// Or replaced by tensorArray.
static constexpr int MaxSize = 8;
template <typename T>
struct CUDADeviceArray {
T data[MaxSize];
int size;
};
template <typename T> template <typename T>
__device__ T upper_bound(const T* first, T count, T val) { __device__ T upper_bound(const T* first, T count, T val) {
const T* orig = first; const T* orig = first;
...@@ -49,25 +40,24 @@ __device__ T upper_bound(const T* first, T count, T val) { ...@@ -49,25 +40,24 @@ __device__ T upper_bound(const T* first, T count, T val) {
} }
template <typename T> template <typename T>
__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
const CUDADeviceArray<int> input_cols,
const int output_rows, const int output_cols, const int output_rows, const int output_cols,
T* output) { T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1; int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
int curr_offset = input_cols.data[segment]; int curr_offset = input_cols[segment];
int curr_segment = segment; int curr_segment = segment;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset; T curr_col_offset;
while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) { while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset; curr_offset = curr_col_offset;
++curr_segment; ++curr_segment;
} }
int local_col = tid_x - curr_offset; int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset; int segment_width = curr_col_offset - curr_offset;
const T* input_ptr = inputs.data[curr_segment]; T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] = output[tid_y * output_cols + tid_x] =
...@@ -76,41 +66,41 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, ...@@ -76,41 +66,41 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
} }
template <typename T> template <typename T>
__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, __global__ void KernelConcat(T** inputs, const int input_col,
const int input_col, const int output_rows, const int output_rows, const int output_cols,
const int output_cols, T* output) { T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
float inv_input_col = 1.0 / input_col; float inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col; int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col; int in_offset = tid_x - split * input_col;
const T* input_ptr = inputs.data[split]; T* input_ptr = inputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
output[tid_y * output_cols + tid_x] = output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * input_col + in_offset]; input_ptr[tid_y * input_col + in_offset];
}
} }
} }
template <typename T> template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row, __global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int input_col, const int* output_cols,
CUDADeviceArray<int> output_cols, int col_size, T** outputs) {
CUDADeviceArray<T*> outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1; int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
int curr_offset = output_cols.data[segment]; int curr_offset = output_cols[segment];
int curr_segment = segment; int curr_segment = segment;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset; T curr_col_offset;
while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) { while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset; curr_offset = curr_col_offset;
++curr_segment; ++curr_segment;
} }
int local_col = tid_x - curr_offset; int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset; int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs.data[curr_segment]; T* output_ptr = outputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] = output_ptr[tid_y * segment_width + local_col] =
...@@ -121,13 +111,13 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, ...@@ -121,13 +111,13 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
template <typename T> template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row, __global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols, const int input_col, const int output_cols,
CUDADeviceArray<T*> outputs) { T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
float inv_input_col = 1.0 / input_col; float inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col; int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col; int in_offset = tid_x - split * input_col;
T* output_ptr = outputs.data[split]; T* output_ptr = outputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] = output_ptr[tid_y * output_cols + in_offset] =
...@@ -136,7 +126,8 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, ...@@ -136,7 +126,8 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
} }
/* /*
* All tensors' dimension should be the same. * All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/ */
template <typename T> template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> { class ConcatFunctor<platform::CUDADeviceContext, T> {
...@@ -144,12 +135,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -144,12 +135,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis, const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) { framework::Tensor* output) {
// assume the the max size of input is less than 8 and see the performance // TODO(zcd): Add input data validity checking
// save origin dim
int num = input.size(); int num = input.size();
PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
MaxSize);
// get the matrix size
int rows = 1; int rows = 1;
auto dim_0 = input[0].dims(); auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -157,25 +144,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -157,25 +144,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
} }
int cols = input[0].numel() / rows; int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0; int out_rows = rows, out_cols = 0;
bool sameShape = true;
CUDADeviceArray<const T*> inputs_data; paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
CUDADeviceArray<int> inputs_cols; paddle::framework::Vector<int> inputs_cols(num + 1);
inputs_data.size = num; inputs_cols[0] = 0;
inputs_cols.size = num + 1; T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
inputs_cols.data[0] = 0;
// reshape to matrix bool sameShape = true;
// check input shape is valid
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows; int t_cols = input[i].numel() / rows;
if (sameShape) { if (sameShape) {
if (t_cols != cols) sameShape = false; if (t_cols != cols) sameShape = false;
} }
out_cols += t_cols; out_cols += t_cols;
inputs_cols.data[i + 1] = out_cols; inputs_cols[i + 1] = out_cols;
inputs_data.data[i] = input[i].data<T>(); inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
} }
T** ins_gpu =
reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
// computation // computation
// set the thread block and grid according to CurrentDeviceId // set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024; const int kThreadsPerBlock = 1024;
...@@ -198,27 +187,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -198,27 +187,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
if (sameShape) { if (sameShape) {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
inputs_data, cols, out_rows, out_cols, output->data<T>()); ins_gpu, cols, out_rows, out_cols, output->data<T>());
} else { } else {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
inputs_data, inputs_cols, out_rows, out_cols, output->data<T>()); ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
out_cols, output->data<T>());
} }
} }
}; };
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T> template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> { class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const int axis, const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) { std::vector<framework::Tensor>& outputs) {
// assume the the max size of input is less than 8 and see the performance // TODO(zcd): Add input data validity checking
// save origin dim
int num = outputs.size(); int num = outputs.size();
PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
MaxSize);
// get the matrix size
int input_row = 1; int input_row = 1;
auto dim_0 = outputs[0].dims(); auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -229,11 +218,10 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -229,11 +218,10 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
int input_col = 0; int input_col = 0;
bool sameShape = true; bool sameShape = true;
CUDADeviceArray<T*> outputs_data; paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
CUDADeviceArray<int> outputs_cols; paddle::framework::Vector<int> outputs_cols(num + 1);
outputs_data.size = num; outputs_cols[0] = 0;
outputs_cols.size = num + 1; T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
outputs_cols.data[0] = 0;
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
int t_col = outputs[i].numel() / input_row; int t_col = outputs[i].numel() / input_row;
...@@ -241,12 +229,16 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -241,12 +229,16 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
if (t_col != output_col_0) sameShape = false; if (t_col != output_col_0) sameShape = false;
} }
input_col += t_col; input_col += t_col;
outputs_cols.data[i + 1] = input_col; outputs_cols[i + 1] = input_col;
outputs_data.data[i] = outputs[i].data<T>(); outputs_ptr[i] = outputs[i].data<T>();
} }
T** outs_gpu =
reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
// computation // computation
const int kThreadsPerBlock = 256; const int kThreadsPerBlock = 1024;
int block_cols = std::min(input_col, kThreadsPerBlock); int block_cols = std::min(input_col, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1); int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
dim3 block_size = dim3(block_cols, block_rows, 1); dim3 block_size = dim3(block_cols, block_rows, 1);
...@@ -257,10 +249,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -257,10 +249,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
if (sameShape) { if (sameShape) {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, output_col_0, outputs_data); input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
} else { } else {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, outputs_cols, outputs_data); input.data<T>(), input_row, input_col, outs_col_gpu,
static_cast<int>(outputs_cols.size()), outs_gpu);
} }
} }
}; };
......
...@@ -20,7 +20,16 @@ namespace operators { ...@@ -20,7 +20,16 @@ namespace operators {
namespace math { namespace math {
/* /*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
* *
* Output = [[1,2],
* [3,4],
* [5,6]]
*/ */
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ConcatFunctor { class ConcatFunctor {
...@@ -30,6 +39,18 @@ class ConcatFunctor { ...@@ -30,6 +39,18 @@ class ConcatFunctor {
framework::Tensor* output); framework::Tensor* output);
}; };
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ConcatGradFunctor { class ConcatGradFunctor {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册