From 8b1048b460e6f7093746e648db39604200c7c5cd Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 16 Feb 2022 21:24:17 +0800 Subject: [PATCH] Revert "[pten] remove concat fluid kernel (#39268)" This reverts commit 552db8dc00262858221586cc52ef652b703346eb. --- paddle/fluid/operators/concat_op.cc | 14 +++++- paddle/fluid/operators/concat_op.cu.cc | 13 ++++- paddle/fluid/operators/concat_op.h | 48 +++++++++++++++++++ .../operators/tensor_array_to_tensor_op.cc | 2 +- 4 files changed, 74 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 5d6f5f4d859..86e76b95311 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -244,7 +244,19 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad, ops::ConcatDoubleGradOpMaker, ops::ConcatDoubleGradOpMaker, ops::ConcatOpGradNoNeedBufferVarInferer); - +REGISTER_OP_CPU_KERNEL( + concat, ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel>, + ops::ConcatKernel>); REGISTER_OP_CPU_KERNEL( concat_grad, ops::ConcatGradKernel, diff --git a/paddle/fluid/operators/concat_op.cu.cc b/paddle/fluid/operators/concat_op.cu.cc index f7b64f16e2d..d622dbf4d31 100644 --- a/paddle/fluid/operators/concat_op.cu.cc +++ b/paddle/fluid/operators/concat_op.cu.cc @@ -19,7 +19,18 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; - +REGISTER_OP_CUDA_KERNEL( + concat, ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel, + ops::ConcatKernel>, + ops::ConcatKernel>); REGISTER_OP_CUDA_KERNEL( concat_grad, ops::ConcatGradKernel, diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 7b53b9df6f9..1d9c10bdb8c 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -39,6 +39,54 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { } return axis > 0 ? axis : 0; } + +template +class ConcatKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + framework::LoDTensor* out = ctx.Output("Out"); + PADDLE_ENFORCE_NOT_NULL(ins[0], + platform::errors::NotFound( + "The first input tensor is not initalized.")); + auto axis = ctx.Attr("axis"); + bool need_resize_out_dims = false; + if (ctx.HasInput("AxisTensor")) { + auto* axis_tensor = ctx.Input("AxisTensor"); + axis = GetDataFromTensor(axis_tensor)[0]; + need_resize_out_dims = true; + } + axis = ComputeAxis(static_cast(axis), + static_cast(ins[0]->dims().size())); + + if (need_resize_out_dims) { + const size_t n = ins.size(); + std::vector ins_dims(n); + for (size_t i = 0; i < n; i++) { + ins_dims[i] = ins[i]->dims(); + } + + framework::DDim out_dims = + pten::funcs::ComputeAndCheckShape(true, ins_dims, axis); + out->Resize(out_dims); + } + auto place = ctx.GetPlace(); + out->mutable_data(place); + + // call new kernel + auto& dev_ctx = ctx.device_context(); + std::vector pt_ins; + for (auto& in : ins) { + pt_ins.push_back(*in); + } + + pten::ConcatKernel( + static_cast::TYPE&>(dev_ctx), + pt_ins, axis, out); + } +}; + template class ConcatGradKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc index fa49f254d97..eb20e1c2cd2 100644 --- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc +++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc @@ -299,7 +299,7 @@ class TensorArrayToTensorGradOpMaker : public framework::SingleGradOpMaker { } // namespace operators } // namespace paddle -USE_OP_ITSELF(concat); +USE_OP(concat); namespace ops = paddle::operators; REGISTER_OPERATOR( -- GitLab