提交 60e7ee06 编写于 作者: C chengduoZH

refine concat_op

上级 cf883d9c
......@@ -184,6 +184,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_functor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad, false)
REGISTER_OP_CPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>)
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
......@@ -27,55 +28,17 @@ class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
auto out_stride = framework::stride_numel(out->dims());
size_t output_offset = 0;
// If axis >=1, copy to out immediately need to call many times
// of cuda memcpy. Copy the input to cpu and do the stride copy,
// then copy to gpu output.
if (platform::is_gpu_place(place) && axis >= 1) {
platform::CPUPlace copy_place;
auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
framework::Tensor cpu_out;
cpu_out.Resize(out->dims());
cpu_out.mutable_data<T>(copy_place);
auto& dev_ctx = ctx.device_context();
std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
for (auto* in : ins) {
std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
cpu_ins.emplace_back(std::move(cpu_in));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx.Wait();
for (auto& in : cpu_ins) {
auto& cpu_in = *in.get();
auto in_stride = framework::stride_numel(cpu_in.dims());
StridedNumelCopyWithAxis<T>(
cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
cpu_in.data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
framework::TensorCopy(cpu_out, place, dev_ctx, out);
} else {
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
}
};
......
......@@ -20,6 +20,7 @@ if(WITH_GPU)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -37,6 +38,7 @@ else()
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
cc_library(concat_functor SRCS concat.cc DEPS device_context tensor)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......@@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor)
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
namespace paddle {
namespace operators {
namespace math {
/*
* All tensors' dimension should be the same.
*/
template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int num = input.size();
std::vector<paddle::framework::DDim> origin_dim(num);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto out_dim = output->dims();
// get the matrix size
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
bool sameShape = true;
// reshape to matrix
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
if (sameShape) {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
input[i].Resize({rows, t_cols});
}
output->Resize({out_rows, out_cols});
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < rows; ++k) {
// offset k * out_cols
T* dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input[j].dims()[1];
const T* src_prt = input[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
sizeof(T) * col_len);
col_idx += col_len;
}
}
// recover origin dim
// for (int j = 0; j < num; ++j) {
// input[j]->Resize(origin_dim[j]);
// }
output->Resize(out_dim);
}
};
template class ConcatFunctor<platform::CPUDeviceContext, int>;
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
// TODO(zcd): This can be replaced by tensor,
// if that, maybe we should add int8 to VarType::Type.
// Or replaced by tensorArray.
static constexpr int MaxSize = 32;
template <typename T>
struct CUDADeviceArray {
T data[MaxSize];
int size;
};
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
const T* orig = first;
const T* it = nullptr;
T step = 0;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (!(val < *it)) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first - orig;
}
template <typename T>
__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
const CUDADeviceArray<int> input_cols,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
int curr_offset = input_cols.data[segment];
int curr_segment = segment;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
const T* input_ptr = inputs.data[curr_segment];
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
/*
* All tensors' dimension should be the same.
*/
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int num = input.size();
// std::vector<paddle::framework::DDim> origin_dim(num);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto out_dim = output->dims();
// get the matrix size
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
bool sameShape = true;
CUDADeviceArray<const T*> inputs_data;
CUDADeviceArray<int> inputs_cols;
inputs_data.size = num;
inputs_cols.size = num + 1;
inputs_cols.data[0] = 0;
// reshape to matrix
// check input shape is valid
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
if (sameShape) {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
input[i].Resize({rows, t_cols});
inputs_cols.data[i + 1] = out_cols;
inputs_data.data[i] = input[i].data<T>();
}
output->Resize({out_rows, out_cols});
// computation
const int kThreadsPerBlock = 256;
int block_cols = std::min(out_cols, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
dim3 block_size = dim3(block_cols, block_rows, 1);
int grid_cols = (out_cols + block_cols - 1) / block_cols;
int grid_rows = (out_rows + block_rows - 1) / block_rows;
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
// recover origin dim
// for (int j = 0; j < num; ++j) {
// input[j].Resize(origin_dim[j]);
// }
output->Resize(out_dim);
}
};
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
/*
* the tensor's shape of input will be changed,
* so the second parameter is not const.
*
*/
template <typename DeviceContext, typename T>
class ConcatFunctor {
public:
void operator()(const DeviceContext& context,
std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
using namespace paddle::framework;
using namespace paddle::platform;
template <typename DeviceContext, typename Place>
void testConcat() {
Tensor input_a_cpu;
Tensor input_b_cpu;
Tensor out_cpu;
Tensor input_a;
Tensor input_b;
Tensor out;
DeviceContext* context = new DeviceContext(Place());
// DeviceContext context(Place());
/**
* cast1:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [3, 3, 4]
* output:
* out.shape: [5, 3, 4]
*/
auto dim_a = make_ddim({2, 3, 4});
auto dim_b = make_ddim({3, 3, 4});
auto dim_out = make_ddim({5, 3, 4});
input_a.mutable_data<int>(dim_a, Place());
input_b.mutable_data<int>(dim_b, Place());
out.mutable_data<int>(dim_out, Place());
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.mutable_data<int>(dim_a, CPUPlace());
input_b_cpu.mutable_data<int>(dim_b, CPUPlace());
out_cpu.mutable_data<int>(dim_out, CPUPlace());
}
int* a_ptr;
int* b_ptr;
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 3 * 3 * 4; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
std::vector<Tensor> input;
input.push_back(input_a);
input.push_back(input_b);
paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
concat_functor(*context, input, 0, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
int* out_ptr;
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
int cols = 2 * 3 * 4;
int idx_a = 0, idx_b = 0;
for (int j = 0; j < 5 * 3 * 4; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a]);
++idx_a;
}
}
//
/**
* cast2:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 4, 4]
* output:
* out.shape: [2, 7, 4]
*/
dim_a = make_ddim({2, 3, 4});
dim_b = make_ddim({2, 4, 4});
dim_out = make_ddim({2, 7, 4});
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
}
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 2 * 4 * 4; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
input.clear();
input.push_back(input_a);
input.push_back(input_b);
concat_functor(*context, input, 1, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
cols = 3 * 4;
idx_a = 0, idx_b = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 28; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], a_ptr[idx_a]);
++idx_a;
}
}
}
/**
* cast3:
* inputs:
* t_a.shape: [2, 3, 5]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 3, 9]
*/
dim_a = make_ddim({2, 3, 4});
dim_b = make_ddim({2, 3, 5});
dim_out = make_ddim({2, 3, 9});
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
}
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
} else {
a_ptr = input_a.data<int>();
b_ptr = input_b.data<int>();
}
for (int i = 0; i < 2 * 3 * 4; ++i) {
a_ptr[i] = i;
}
for (int i = 0; i < 2 * 3 * 5; ++i) {
b_ptr[i] = i;
}
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(input_a_cpu, Place(), *context, &input_a);
TensorCopy(input_b_cpu, Place(), *context, &input_b);
}
input.clear();
input.push_back(input_a);
input.push_back(input_b);
concat_functor(*context, input, 2, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) {
TensorCopy(out, CPUPlace(), *context, &out_cpu);
out_ptr = out_cpu.data<int>();
} else {
out_ptr = out.data<int>();
}
// check the data
cols = 4;
idx_a = 0, idx_b = 0;
for (int i = 0; i < 6; ++i) {
for (int j = 0; j < 9; ++j) {
if (j >= cols) {
PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], b_ptr[idx_b]);
++idx_b;
} else {
PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], a_ptr[idx_a]);
++idx_a;
}
}
}
}
TEST(math, concat) {
testConcat<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
testConcat<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
#endif
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册