Unverified · Commit 84aea8a8 · Authored by: chengduo · Committed by: GitHub

Merge pull request #8669 from chengduoZH/feature/concat_op

Refine concat_op
......@@ -201,6 +201,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_functor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad, false)
REGISTER_OP_CPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>)
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
......@@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
auto out_stride = framework::stride_numel(out->dims());
size_t output_offset = 0;
// If axis >= 1, copying directly to out would require many separate
// cuda memcpy calls. Instead, copy the inputs to CPU, do the strided
// copy there, and then copy the result to the GPU output.
if (platform::is_gpu_place(place) && axis >= 1) {
platform::CPUPlace copy_place;
auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
framework::Tensor cpu_out;
cpu_out.Resize(out->dims());
cpu_out.mutable_data<T>(copy_place);
auto& dev_ctx = ctx.device_context();
std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
for (auto* in : ins) {
std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
cpu_ins.emplace_back(std::move(cpu_in));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx.Wait();
for (auto& in : cpu_ins) {
auto& cpu_in = *in.get();
auto in_stride = framework::stride_numel(cpu_in.dims());
StridedNumelCopyWithAxis<T>(
cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
cpu_in.data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
framework::TensorCopy(cpu_out, place, dev_ctx, out);
} else {
// Sometimes direct copies are faster; this may need deeper analysis.
if (axis == 0 && ins.size() < 10) {
size_t output_offset = 0;
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
}
}
};
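// Illustrative note (not part of this diff): with two inputs of shape
// [2, 3, 4] concatenated along axis 0, the branch above performs plain
// strided copies into contiguous regions of the output; with axis == 1
// (or with ten or more inputs) it falls through to math::ConcatFunctor,
// which copies row by row on CPU and launches a single kernel on GPU
// instead of issuing one memcpy per input per row.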
......@@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
size_t input_offset = 0;
auto in_stride = framework::stride_numel(in->dims());
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
// Sometimes direct copies are faster; this may need deeper analysis.
if (axis == 0 && outs.size() < 10) {
size_t input_offset = 0;
auto in_stride = framework::stride_numel(in->dims());
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
}
} else {
std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs[j] = *outs[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor;
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
}
}
};
......
......@@ -20,6 +20,7 @@ if(WITH_GPU)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -37,6 +38,7 @@ else()
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
cc_library(concat_functor SRCS concat.cc DEPS device_context tensor)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......@@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor)
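# Illustrative note (not part of this diff): assuming cc_test registers the
# target with add_test under the same name, the new unit test can be run after
# a build with `ctest -R concat_test`.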
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
namespace paddle {
namespace operators {
namespace math {
/*
* All tensors must have the same rank, and the size of every
* dimension must match, except along the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < out_rows; ++k) {
T* dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const T* src_ptr = input[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_ptr,
sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
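// Worked example (illustrative, not part of this diff): concatenating inputs
// of shape [2, 3, 4] and [2, 3, 5] along axis 2 gives rows = 2 * 3 = 6,
// input_cols = {4, 5} and out_cols = 9; each of the 6 output rows is built by
// copying a 4-element slice from input[0] followed by a 5-element slice from
// input[1].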
/*
* All tensors must have the same rank, and the size of every
* dimension must match, except along the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs.size();
int input_rows = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}
int input_cols = 0;
std::vector<int64_t> output_cols(outputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = outputs[i].numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = output_cols[j];
T* dst_ptr = outputs[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
template class ConcatFunctor<platform::CPUDeviceContext, int>;
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
const T* orig = first;
const T* it = nullptr;
T step = 0;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (!(val < *it)) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first - orig;
}
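// upper_bound is a device-side binary search that returns the index of the
// first element in [first, first + count) strictly greater than val. The
// kernels below call it on the cumulative column offsets {0, c0, c0 + c1, ...}
// and subtract 1 to find which input (or output) segment a given column
// belongs to; e.g. with offsets {0, 4, 9}, column 5 maps to segment 1.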
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
int curr_offset = input_cols[segment];
int curr_segment = segment;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
template <typename T>
__global__ void KernelConcat(T** inputs, const int input_col,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* input_ptr = inputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * input_col + in_offset];
}
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int* output_cols,
int col_size, T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
int curr_offset = output_cols[segment];
int curr_segment = segment;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input[tid_y * input_col + tid_x];
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols,
T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* output_ptr = outputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] =
input[tid_y * input_col + tid_x];
}
}
/*
* All tensors must have the same rank, and the size of every
* dimension must match, except along the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
framework::Vector<int> inputs_cols(num + 1);
inputs_cols[0] = 0;
T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
bool sameShape = true;
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
if (sameShape) {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
inputs_cols[i + 1] = out_cols;
inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
}
T** ins_gpu =
reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
// computation
// Set the thread block and grid sizes according to the current device.
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (out_cols < kThreadsPerBlock) {  // block_cols is aligned to 32.
block_cols = ((out_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
ins_gpu, cols, out_rows, out_cols, output->data<T>());
} else {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
out_cols, output->data<T>());
}
}
};
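// Illustrative sizing arithmetic (assuming, for example, a GPU with 20
// multiprocessors and 2048 threads per multiprocessor, so max_threads = 40960
// and max_blocks = 40): for out_rows = 6 and out_cols = 9 as in case 3 of
// concat_test.cc, block_cols = ((9 + 31) >> 5) << 5 = 32, block_rows =
// 1024 / 32 = 32, grid_cols = min((9 + 31) / 32, 40) = 1 and grid_rows =
// min(40, max(6 / 32, 1)) = 1, i.e. a single 32 x 32 thread block.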
/*
* All tensors must have the same rank, and the size of every
* dimension must match, except along the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs.size();
int input_row = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_row *= dim_0[i];
}
int output_col_0 = outputs[0].numel() / input_row;
int input_col = 0;
bool sameShape = true;
framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(num + 1);
outputs_cols[0] = 0;
T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
for (int i = 0; i < num; ++i) {
int t_col = outputs[i].numel() / input_row;
if (sameShape) {
if (t_col != output_col_0) sameShape = false;
}
input_col += t_col;
outputs_cols[i + 1] = input_col;
outputs_ptr[i] = outputs[i].data<T>();
}
T** outs_gpu =
reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
// computation
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (input_col < kThreadsPerBlock) {  // block_cols is aligned to 32.
block_cols = ((input_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((input_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
} else {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, outs_col_gpu,
static_cast<int>(outputs_cols.size()), outs_gpu);
}
}
};
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
/*
* \brief Concatenate the input tensors along the given axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template <typename DeviceContext, typename T>
class ConcatFunctor {
public:
void operator()(const DeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output);
};
/*
* \brief Split the input tensor along the given axis into the output tensors.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template <typename DeviceContext, typename T>
class ConcatGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const int axis, std::vector<framework::Tensor>& outputs);
};
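// Minimal usage sketch (illustrative only; the tensor and context names below
// are placeholders, and concat_test.cc further down gives a complete example):
//
//   std::vector<framework::Tensor> inputs = {t_a, t_b};
//   math::ConcatFunctor<platform::CPUDeviceContext, float> concat;
//   concat(cpu_ctx, inputs, /*axis=*/0, &out);  // out must be pre-sized
//
//   math::ConcatGradFunctor<platform::CPUDeviceContext, float> concat_grad;
//   concat_grad(cpu_ctx, out_grad, /*axis=*/0, input_grads);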
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
using namespace paddle::framework;
using namespace paddle::platform;
template <typename DeviceContext, typename Place>
void testConcat() {
Tensor input_a_cpu;
Tensor input_b_cpu;
Tensor out_cpu;
Tensor input_a;
Tensor input_b;
Tensor out;
DeviceContext* context = new DeviceContext(Place());
// DeviceContext context(Place());
/**
* case 1:
* inputs:
* axis = 0
* t_a.shape: [2, 3, 4]
* t_b.shape: [3, 3, 4]
* output:
* out.shape: [5, 3, 4]
*/
auto dim_a = make_ddim({2, 3, 4});
auto dim_b = make_ddim({3, 3, 4});
auto dim_out = make_ddim({5, 3, 4});
input_a.mutable_data<int>(dim_a, Place());
input_b.mutable_data<int>(dim_b, Place());
out.mutable_data<int>(dim_out, Place());
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.mutable_data<int>(dim_a, CPUPlace());
input_b_cpu.mutable_data<int>(dim_b, CPUPlace());
out_cpu.mutable_data<int>(dim_out, CPUPlace());
}
int* a_ptr;
int* b_ptr;
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 3 * 3 * 4; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
std::vector<Tensor> input;
input.push_back(input_a);
input.push_back(input_b);
paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
concat_functor(*context, input, 0, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
int* out_ptr;
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
int cols = 2 * 3 * 4;
int idx_a = 0, idx_b = 0;
for (int j = 0; j < 5 * 3 * 4; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a]);
++idx_a;
}
}
//
/**
* cast2:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 4, 4]
* output:
* out.shape: [2, 7, 4]
*/
dim_a = make_ddim({2, 3, 4});
dim_b = make_ddim({2, 4, 4});
dim_out = make_ddim({2, 7, 4});
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
}
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 2 * 4 * 4; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
input.clear();
input.push_back(input_a);
input.push_back(input_b);
concat_functor(*context, input, 1, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
cols = 3 * 4;
idx_a = 0, idx_b = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 28; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], a_ptr[idx_a]);
++idx_a;
}
}
}
/**
* case 3:
* inputs:
* axis = 2
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 3, 5]
* output:
* out.shape: [2, 3, 9]
*/
dim_a = make_ddim({2, 3, 4});
dim_b = make_ddim({2, 3, 5});
dim_out = make_ddim({2, 3, 9});
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
}
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 2 * 3 * 5; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
input.clear();
input.push_back(input_a);
input.push_back(input_b);
concat_functor(*context, input, 2, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
// check the data
cols = 4;
idx_a = 0, idx_b = 0;
for (int i = 0; i < 6; ++i) {
for (int j = 0; j < 9; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], a_ptr[idx_a]);
++idx_a;
}
}
}
/**
* case 4:
* inputs:
* axis = 1
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 6, 4]
*/
dim_a = make_ddim({2, 3, 4});
dim_b = make_ddim({2, 3, 4});
dim_out = make_ddim({2, 6, 4});
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
}
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
input.clear();
input.push_back(input_a);
input.push_back(input_b);
concat_functor(*context, input, 1, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
// check the data
cols = 12;
idx_a = 0, idx_b = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 24; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]);
++idx_a;
}
}
}
}
TEST(math, concat) {
testConcat<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
testConcat<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
#endif
}
......@@ -127,6 +127,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
SetDeviceId(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device);
max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
......@@ -160,6 +162,10 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaGetLastError());
}
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
return multi_process * max_threads_per_mp;
}
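// Note (not part of this diff): multi_process * max_threads_per_mp is the
// device's total number of resident threads; math::ConcatFunctor uses it to
// bound its grid size via max_blocks = max(GetMaxPhysicalThreadCount() / 1024, 1).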
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
......
......@@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return place in the device context. */
Place GetPlace() const override;
/*! \brief Return the max physical thread count in the device context */
int GetMaxPhysicalThreadCount() const;
/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;
......@@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
int multi_process;
int max_threads_per_mp;
};
template <>
......
......@@ -33,6 +33,26 @@ int GetCUDADeviceCount() {
return count;
}
int GetCUDAMultiProcessors(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count");
int count;
PADDLE_ENFORCE(
cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMultiProcessors");
return count;
}
int GetCUDAMaxThreadsPerMultiProcessor(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count");
int count;
PADDLE_ENFORCE(cudaDeviceGetAttribute(
&count, cudaDevAttrMaxThreadsPerMultiProcessor, id),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMaxThreadsPerMultiProcessor");
return count;
}
int GetCurrentDeviceId() {
int device_id;
PADDLE_ENFORCE(
......
......@@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse =
//! Get the total number of GPU devices in system.
int GetCUDADeviceCount();
//! Get the number of multiprocessors of the i-th GPU.
int GetCUDAMultiProcessors(int i);
//! Get the maximum number of threads per multiprocessor of the i-th GPU.
int GetCUDAMaxThreadsPerMultiProcessor(int i);
//! Get the current GPU device id in system.
int GetCurrentDeviceId();
......