diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 6ac70eacaf9b5b1c9205f84ab3c7047e4b1bffc9..92c8ab6d9ff11ec6acd46a39877eb67d624748a9 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -33,14 +33,26 @@ class ConcatKernel : public framework::OpKernel { auto place = ctx.GetPlace(); out->mutable_data(place); - // TODO(zcd): Sometimes direct copies will be faster - std::vector inputs(ins.size()); - for (size_t j = 0; j < ins.size(); ++j) { - inputs[j] = *ins[j]; + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && ins.size() < 10) { + size_t output_offset = 0; + for (auto* in : ins) { + auto in_stride = framework::stride_numel(in->dims()); + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, + out->data() + output_offset, out_stride, + in->data(), in_stride, in_stride[axis]); + output_offset += in_stride[axis]; + } + } else { + std::vector inputs(ins.size()); + for (size_t j = 0; j < ins.size(); ++j) { + inputs[j] = *ins[j]; + } + auto& dev_ctx = ctx.template device_context(); + paddle::operators::math::ConcatFunctor concat_functor; + concat_functor(dev_ctx, inputs, static_cast(axis), out); } - auto& dev_ctx = ctx.template device_context(); - paddle::operators::math::ConcatFunctor concat_functor; - concat_functor(dev_ctx, inputs, static_cast(axis), out); } }; @@ -52,17 +64,31 @@ class ConcatGradKernel : public framework::OpKernel { auto outs = ctx.MultiOutput(framework::GradVarName("X")); int64_t axis = static_cast(ctx.Attr("axis")); - // TODO(zcd): Sometimes direct copies will be faster - std::vector outputs(outs.size()); - for (size_t j = 0; j < outs.size(); ++j) { - outs[j]->mutable_data(ctx.GetPlace()); - outputs[j] = *outs[j]; - } + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + size_t input_offset = 0; + auto in_stride = framework::stride_numel(in->dims()); + + for (auto& out : outs) { + out->mutable_data(ctx.GetPlace()); + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride, out_stride[axis]); + input_offset += out_stride[axis]; + } + } else { + std::vector outputs(outs.size()); + for (size_t j = 0; j < outs.size(); ++j) { + outs[j]->mutable_data(ctx.GetPlace()); + outputs[j] = *outs[j]; + } - auto& dev_ctx = ctx.template device_context(); - paddle::operators::math::ConcatGradFunctor - concat_grad_functor; - concat_grad_functor(dev_ctx, *in, static_cast(axis), outputs); + auto& dev_ctx = ctx.template device_context(); + paddle::operators::math::ConcatGradFunctor + concat_grad_functor; + concat_grad_functor(dev_ctx, *in, static_cast(axis), outputs); + } } };