diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index 5d6f5f4d8593d2dd6640af25b9ed0eb954fe9454..86e76b95311cf6228838f57d8dcb761d48154dc5 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -244,7 +244,19 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
                   ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
                   ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
                   ops::ConcatOpGradNoNeedBufferVarInferer);
-
+REGISTER_OP_CPU_KERNEL(
+    concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::complex<float>>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     concat_grad,
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
diff --git a/paddle/fluid/operators/concat_op.cu.cc b/paddle/fluid/operators/concat_op.cu.cc
index f7b64f16e2d8bc42063685bd62e9d2bddc6fbd33..d622dbf4d31829938666c420c50a4b7e610f261b 100644
--- a/paddle/fluid/operators/concat_op.cu.cc
+++ b/paddle/fluid/operators/concat_op.cu.cc
@@ -19,7 +19,18 @@ limitations under the License. */
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-
+REGISTER_OP_CUDA_KERNEL(
+    concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext,
+                      plat::complex<float>>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext,
+                      plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     concat_grad,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 7b53b9df6f95134f3aaafa7c34bef71eaf805d3c..1d9c10bdb8cc6a698a4a1b6ab376e90b67eb2a03 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -39,6 +39,54 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
   }
   return axis > 0 ? axis : 0;
 }
+
+template <typename DeviceContext, typename T>
+class ConcatKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
+    framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
+    PADDLE_ENFORCE_NOT_NULL(ins[0],
+                            platform::errors::NotFound(
+                                "The first input tensor is not initalized."));
+    auto axis = ctx.Attr<int>("axis");
+    bool need_resize_out_dims = false;
+    if (ctx.HasInput("AxisTensor")) {
+      auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
+      axis = GetDataFromTensor<int>(axis_tensor)[0];
+      need_resize_out_dims = true;
+    }
+    axis = ComputeAxis(static_cast<int64_t>(axis),
+                       static_cast<int64_t>(ins[0]->dims().size()));
+
+    if (need_resize_out_dims) {
+      const size_t n = ins.size();
+      std::vector<framework::DDim> ins_dims(n);
+      for (size_t i = 0; i < n; i++) {
+        ins_dims[i] = ins[i]->dims();
+      }
+
+      framework::DDim out_dims =
+          pten::funcs::ComputeAndCheckShape(true, ins_dims, axis);
+      out->Resize(out_dims);
+    }
+    auto place = ctx.GetPlace();
+    out->mutable_data<T>(place);
+
+    // call new kernel
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+    std::vector<pten::DenseTensor> pt_ins;
+    for (auto& in : ins) {
+      pt_ins.push_back(*in);
+    }
+
+    pten::ConcatKernel<T>(
+        static_cast<const typename framework::ConvertToPtenContext<
+            DeviceContext>::TYPE&>(dev_ctx),
+        pt_ins, axis, out);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class ConcatGradKernel : public framework::OpKernel<T> {
  public:
diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc
index fa49f254d972a38ad54922c8a303654dedc36682..eb20e1c2cd2748a5ab4db28df0c4798837c7bf21 100644
--- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc
+++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc
@@ -299,7 +299,7 @@ class TensorArrayToTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace operators
 }  // namespace paddle
 
-USE_OP_ITSELF(concat);
+USE_OP(concat);
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(