Unverified · Commit 69827f30 · authored by Qiao Longfei, committed by GitHub

Merge pull request #11527 from jacquesqiao/concat-grad-support-data-input

concat support data as input
@@ -60,34 +60,45 @@ template <typename DeviceContext, typename T>
 class ConcatGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
 
+    // get output tensor that the name is not kEmptyVarName
+    std::vector<framework::Tensor*> outputs;
+    for (size_t j = 0; j < outs.size(); ++j) {
+      if (out_var_names[j] != framework::kEmptyVarName) {
+        outs[j]->mutable_data<T>(ctx.GetPlace());
+        outputs.push_back(outs[j]);
+      } else {
+        outputs.push_back(nullptr);
+      }
+    }
+
     // Sometimes direct copies will be faster, this maybe need deeply analysis.
     if (axis == 0 && outs.size() < 10) {
       size_t input_offset = 0;
-      auto in_stride = framework::stride_numel(in->dims());
+      const auto in_stride = framework::stride_numel(out_grad->dims());
 
-      for (auto& out : outs) {
-        out->mutable_data<T>(ctx.GetPlace());
-        auto out_stride = framework::stride_numel(out->dims());
-        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                    out_stride, in->data<T>() + input_offset,
-                                    in_stride, out_stride[axis]);
+      for (size_t i = 0; i < outs.size(); ++i) {
+        auto out_stride = framework::stride_numel(ins[i]->dims());
+        auto* out = outputs[i];
+        if (out != nullptr) {
+          StridedNumelCopyWithAxis<T>(
+              ctx.device_context(), axis, out->data<T>(), out_stride,
+              out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
+        }
         input_offset += out_stride[axis];
       }
     } else {
-      std::vector<framework::Tensor> outputs(outs.size());
-      for (size_t j = 0; j < outs.size(); ++j) {
-        outs[j]->mutable_data<T>(ctx.GetPlace());
-        outputs[j] = *outs[j];
-      }
-
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
       paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
           concat_grad_functor;
-      concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), &outputs);
+      concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
+                          &outputs);
     }
   }
 };
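Note on the kernel change above: gradient outputs whose variable name is kEmptyVarName become nullptr slots, and the axis-0 fast path keeps advancing the read offset even for a skipped slot, so the remaining outputs are still cut from the right place. Below is a minimal, framework-free sketch of that pattern, using plain std::vector buffers instead of Paddle tensors; every name in it is illustrative and not part of the Paddle API.

```cpp
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Upstream gradient of concat([x0, x1, x2], axis=0), flattened row-major:
  // x0 has 2 rows, x1 has 1 row, x2 has 3 rows; every row has 2 columns.
  std::vector<float> out_grad = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  std::vector<size_t> rows = {2, 1, 3};  // per-input row counts, from "X"
  const size_t cols = 2;

  // The slot for x1 is nullptr: its gradient was not requested
  // (its output variable name would be kEmptyVarName in the operator).
  std::vector<float> d_x0(rows[0] * cols), d_x2(rows[2] * cols);
  std::vector<std::vector<float>*> outputs = {&d_x0, nullptr, &d_x2};

  size_t offset = 0;  // read offset into out_grad, in elements
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t len = rows[i] * cols;
    if (outputs[i] != nullptr) {
      std::memcpy(outputs[i]->data(), out_grad.data() + offset,
                  len * sizeof(float));
    }
    offset += len;  // advance even when the slot is skipped
  }

  for (float v : d_x2) std::cout << v << ' ';  // prints 6 7 8 9 10 11
  std::cout << '\n';
}
```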
@@ -70,35 +70,40 @@ template <typename T>
 class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input, const int axis,
-                  std::vector<framework::Tensor>* outputs) {
+                  const framework::Tensor& input,
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
-    int num = outputs->size();
+    size_t num = outputs->size();
     int input_rows = 1;
-    auto dim_0 = outputs->at(0).dims();
+    auto dim_0 = ref_inputs[0]->dims();
     for (int i = 0; i < axis; ++i) {
       input_rows *= dim_0[i];
     }
     int input_cols = 0;
 
     std::vector<int64_t> output_cols(outputs->size());
-    for (int i = 0; i < num; ++i) {
-      int t_cols = outputs->at(i).numel() / input_rows;
+    for (size_t i = 0; i < num; ++i) {
+      int t_cols = ref_inputs[i]->numel() / input_rows;
       input_cols += t_cols;
       output_cols[i] = t_cols;
     }
     auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
 
     // computation
-    for (int k = 0; k < input_rows; ++k) {
+    for (size_t k = 0; k < input_rows; ++k) {
       const T* src_ptr = input.data<T>() + k * input_cols;
       int col_idx = 0;
       for (int j = 0; j < num; ++j) {
         int col_len = output_cols[j];
-        T* dst_ptr = outputs->at(j).data<T>() + k * col_len;
-        memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
-                     sizeof(T) * col_len);
+        auto* out_tensor = outputs->at(j);
+        if (out_tensor != nullptr) {
+          T* dst_ptr = out_tensor->data<T>() + k * col_len;
+          memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
+                       sizeof(T) * col_len);
+        }
         col_idx += col_len;
       }
     }
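For the general (non axis-0) case, the CPU functor above collapses the dimensions before the concat axis into a row count and the dimensions from the axis onward into per-output column counts, then slices each row of the concatenated gradient into segments, skipping null outputs. A hedged, self-contained sketch of that row/column split follows; the shapes and the helper name are made up for illustration.

```cpp
#include <cstring>
#include <iostream>
#include <vector>

// Split one row-major "matrix" of concatenated gradients back into pieces.
// rows: product of the dims before the concat axis (shared by all outputs).
// cols[j]: product of the dims from the concat axis onward for output j.
void SplitRows(const std::vector<float>& input, size_t rows,
               const std::vector<size_t>& cols,
               std::vector<std::vector<float>*>* outputs) {
  size_t input_cols = 0;
  for (size_t c : cols) input_cols += c;

  for (size_t k = 0; k < rows; ++k) {
    const float* src = input.data() + k * input_cols;
    size_t col_idx = 0;
    for (size_t j = 0; j < cols.size(); ++j) {
      auto* out = outputs->at(j);
      if (out != nullptr) {  // skip gradients that are not needed
        std::memcpy(out->data() + k * cols[j], src + col_idx,
                    cols[j] * sizeof(float));
      }
      col_idx += cols[j];
    }
  }
}

int main() {
  // concat([a(2x1), b(2x2)], axis=1) -> the upstream gradient is 2x3.
  std::vector<float> grad = {0, 1, 2, 3, 4, 5};
  std::vector<float> d_a(2 * 1), d_b(2 * 2);
  std::vector<std::vector<float>*> outs = {&d_a, &d_b};
  SplitRows(grad, /*rows=*/2, /*cols=*/{1, 2}, &outs);
  // d_a == {0, 3}, d_b == {1, 2, 4, 5}
  for (float v : d_a) std::cout << v << ' ';
  std::cout << '\n';
}
```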
@@ -102,11 +102,13 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
     T* output_ptr = outputs_data[curr_segment];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * segment_width + local_col] =
-          input_data[tid_y * in_col + tid_x];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * segment_width + local_col] =
+            input_data[tid_y * in_col + tid_x];
+    }
   }
 }
 
 template <typename T>
@@ -118,11 +120,13 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
     int split = tid_x / fixed_out_col;
     int in_offset = tid_x - split * fixed_out_col;
     T* output_ptr = outputs_data[split];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * fixed_out_col + in_offset] =
-          input_data[tid_y * in_col + tid_x];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * fixed_out_col + in_offset] =
+            input_data[tid_y * in_col + tid_x];
+    }
   }
 }
 
 /*
@@ -203,17 +207,18 @@ template <typename T>
 class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& input, const int axis,
-                  std::vector<framework::Tensor>* outputs) {
+                  const framework::Tensor& input,
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
     int out_row = 1;
-    auto dim_0 = outputs->at(0).dims();
+    auto dim_0 = ref_inputs[0]->dims();
     for (int i = 0; i < axis; ++i) {
       out_row *= dim_0[i];
     }
-    int out_col = outputs->at(0).numel() / out_row;
+    int out0_col = ref_inputs[0]->numel() / out_row;
     int in_col = 0, in_row = out_row;
     bool sameShape = true;

@@ -223,13 +228,17 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
-      int t_col = outputs->at(i).numel() / out_row;
+      int t_col = outputs->at(i)->numel() / out_row;
       if (sameShape) {
-        if (t_col != out_col) sameShape = false;
+        if (t_col != out0_col) sameShape = false;
       }
       in_col += t_col;
       outputs_cols[i + 1] = in_col;
-      outputs_ptr[i] = outputs->at(i).data<T>();
+      if (outputs->at(i) != nullptr) {
+        outputs_ptr[i] = outputs->at(i)->data<T>();
+      } else {
+        outputs_ptr[i] = nullptr;
+      }
     }
 
     T** dev_out_gpu_data =

@@ -255,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     if (sameShape) {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
+          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
     } else {
       const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
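The CUDA kernels above map each global column index to a (segment, local column) pair: the fixed-width variant divides by fixed_out_col, the variable-width variant scans the cumulative outputs_cols offsets, and a null output pointer simply turns the write into a no-op for that segment. The host-side C++ sketch below only illustrates that index arithmetic; it is not the kernel itself, and all names in it are assumptions.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Cumulative column offsets: outputs with 3, 2 and 4 columns give
// {0, 3, 5, 9}; the last entry is the total number of input columns.
struct SegmentMap {
  std::vector<int> offsets;

  // Returns {segment index, column index local to that segment}.
  std::pair<int, int> Locate(int col) const {
    int seg = 0;
    while (col >= offsets[seg + 1]) ++seg;  // scan, roughly like the kernel
    return {seg, col - offsets[seg]};
  }
};

int main() {
  SegmentMap map{{0, 3, 5, 9}};
  // Simulate what one thread per column does: find its segment, then skip
  // the write when that segment's output pointer is null.
  std::vector<bool> has_output = {true, false, true};  // segment 1 skipped
  for (int col = 0; col < 9; ++col) {
    auto [seg, local] = map.Locate(col);
    assert(local >= 0 && local < map.offsets[seg + 1] - map.offsets[seg]);
    if (!has_output[seg]) continue;  // mirrors `if (output_ptr != nullptr)`
    // ... here the kernel writes input(row, col) to output[seg](row, local)
  }
  return 0;
}
```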
@@ -57,7 +57,8 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  const int axis, std::vector<framework::Tensor>* outputs);
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const int axis, std::vector<framework::Tensor*>* outputs);
 };
 
 } // namespace math