diff --git a/paddle/fluid/operators/flatten_op_mlu.cc b/paddle/fluid/operators/flatten_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fbc0d989cb4510dd7bf8cd5a4705d925ad773398
--- /dev/null
+++ b/paddle/fluid/operators/flatten_op_mlu.cc
@@ -0,0 +1,245 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/fluid/operators/flatten_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class FlattenMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto &axes = context.Attr<int>("axis");
+    auto x_dims = in->dims();
+    auto out_dims = framework::make_ddim(GetOutputShape(axes, x_dims));
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::MLUDeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+
+  static std::vector<int32_t> GetOutputShape(const int axis,
+                                             const framework::DDim &in_dims) {
+    int64_t outer = 1, inner = 1;
+    for (int i = 0; i < in_dims.size(); ++i) {
+      if (i < axis) {
+        outer *= in_dims[i];
+      } else {
+        inner *= in_dims[i];
+      }
+    }
+    std::vector<int32_t> out_shape(2);
+    out_shape[0] = outer;
+    out_shape[1] = inner;
+    return out_shape;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FlattenGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopy(
+        *d_out, ctx.GetPlace(),
+        ctx.template device_context<platform::MLUDeviceContext>(), d_x);
+    d_x->Resize(in_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Flatten2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto &axes = context.Attr<int>("axis");
+
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto x_dims = in->dims();
+
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto out_dims = framework::make_ddim(
+        FlattenMLUKernel<DeviceContext, T>::GetOutputShape(axes, x_dims));
+
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::MLUDeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Flatten2GradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopy(
+        *d_out, ctx.GetPlace(),
+        ctx.template device_context<platform::MLUDeviceContext>(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto *out = context.Output<framework::LoDTensor>("Out");
+    out->mutable_data(context.GetPlace(), in->type());
+    auto &start_axis = context.Attr<int>("start_axis");
+    auto &stop_axis = context.Attr<int>("stop_axis");
+
+    // make out dims
+    auto in_dims = in->dims();
+    auto out_dims =
+        framework::make_ddim(GetOutputShape(start_axis, stop_axis, in_dims));
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::MLUDeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+  static std::vector<int64_t> GetOutputShape(const int start_axis,
+                                             const int stop_axis,
+                                             const framework::DDim &in_dims) {
+    int64_t outer = 1;
+    std::vector<int64_t> out_shape;
+    int in_dims_size = in_dims.size();
+    out_shape.reserve(in_dims_size - stop_axis + start_axis);
+    int real_start_axis = start_axis, real_stop_axis = stop_axis;
+    if (start_axis < 0) {
+      real_start_axis = start_axis + in_dims_size;
+    }
+    if (stop_axis < 0) {
+      real_stop_axis = stop_axis + in_dims_size;
+    }
+
+    for (int i = 0; i < real_start_axis; ++i) {
+      out_shape.push_back(in_dims[i]);
+    }
+    for (int i = real_start_axis; i <= real_stop_axis; i++) {
+      if (in_dims[i] == -1 || outer == -1) {
+        outer = -1;
+      } else {
+        outer *= in_dims[i];
+      }
+    }
+    out_shape.push_back(outer);
+    for (int i = real_stop_axis + 1; i < in_dims_size; i++) {
+      out_shape.push_back(in_dims[i]);
+    }
+
+    return out_shape;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopy(
+        *d_out, ctx.GetPlace(),
+        ctx.template device_context<platform::MLUDeviceContext>(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_MLU_KERNEL(
+    flatten, ops::FlattenMLUKernel<paddle::platform::MLUDeviceContext, float>,
+    ops::FlattenMLUKernel<paddle::platform::MLUDeviceContext, double>,
+    ops::FlattenMLUKernel<paddle::platform::MLUDeviceContext, uint8_t>,
+    ops::FlattenMLUKernel<paddle::platform::MLUDeviceContext, int>,
+    ops::FlattenMLUKernel<paddle::platform::MLUDeviceContext, int8_t>,
+    ops::FlattenMLUKernel<paddle::platform::MLUDeviceContext, int64_t>);
+REGISTER_OP_MLU_KERNEL(
+    flatten_grad,
+    ops::FlattenGradMLUKernel<paddle::platform::MLUDeviceContext, float>,
+    ops::FlattenGradMLUKernel<paddle::platform::MLUDeviceContext, double>,
+    ops::FlattenGradMLUKernel<paddle::platform::MLUDeviceContext, uint8_t>,
+    ops::FlattenGradMLUKernel<paddle::platform::MLUDeviceContext, int>,
+    ops::FlattenGradMLUKernel<paddle::platform::MLUDeviceContext, int8_t>,
+    ops::FlattenGradMLUKernel<paddle::platform::MLUDeviceContext, int64_t>);
+REGISTER_OP_MLU_KERNEL(
+    flatten2, ops::Flatten2MLUKernel<paddle::platform::MLUDeviceContext, float>,
+    ops::Flatten2MLUKernel<paddle::platform::MLUDeviceContext, double>,
+    ops::Flatten2MLUKernel<paddle::platform::MLUDeviceContext, uint8_t>,
+    ops::Flatten2MLUKernel<paddle::platform::MLUDeviceContext, int>,
+    ops::Flatten2MLUKernel<paddle::platform::MLUDeviceContext, int8_t>,
+    ops::Flatten2MLUKernel<paddle::platform::MLUDeviceContext, int64_t>);
+REGISTER_OP_MLU_KERNEL(
+    flatten2_grad,
+    ops::Flatten2GradMLUKernel<paddle::platform::MLUDeviceContext, float>,
+    ops::Flatten2GradMLUKernel<paddle::platform::MLUDeviceContext, double>,
+    ops::Flatten2GradMLUKernel<paddle::platform::MLUDeviceContext, uint8_t>,
+    ops::Flatten2GradMLUKernel<paddle::platform::MLUDeviceContext, int>,
+    ops::Flatten2GradMLUKernel<paddle::platform::MLUDeviceContext, int8_t>,
+    ops::Flatten2GradMLUKernel<paddle::platform::MLUDeviceContext, int64_t>);
+REGISTER_OP_MLU_KERNEL(
+    flatten_contiguous_range,
+    ops::FlattenContiguousRangeMLUKernel<paddle::platform::MLUDeviceContext,
+                                         float>,
+    ops::FlattenContiguousRangeMLUKernel<paddle::platform::MLUDeviceContext,
+                                         double>,
+    ops::FlattenContiguousRangeMLUKernel<paddle::platform::MLUDeviceContext,
+                                         uint8_t>,
+    ops::FlattenContiguousRangeMLUKernel<paddle::platform::MLUDeviceContext,
+                                         int>,
+    ops::FlattenContiguousRangeMLUKernel<paddle::platform::MLUDeviceContext,
+                                         int8_t>,
+    ops::FlattenContiguousRangeMLUKernel<paddle::platform::MLUDeviceContext,
+                                         int64_t>);
+REGISTER_OP_MLU_KERNEL(
+    flatten_contiguous_range_grad,
+    ops::FlattenContiguousRangeGradMLUKernel<paddle::platform::MLUDeviceContext,
+                                             float>,
+    ops::FlattenContiguousRangeGradMLUKernel<paddle::platform::MLUDeviceContext,
+                                             double>,
+    ops::FlattenContiguousRangeGradMLUKernel<paddle::platform::MLUDeviceContext,
+                                             uint8_t>,
+    ops::FlattenContiguousRangeGradMLUKernel<paddle::platform::MLUDeviceContext,
+                                             int>,
+    ops::FlattenContiguousRangeGradMLUKernel<paddle::platform::MLUDeviceContext,
+                                             int8_t>,
+    ops::FlattenContiguousRangeGradMLUKernel<paddle::platform::MLUDeviceContext,
+                                             int64_t>);
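
Note: the kernels in this patch do no device computation beyond a tensor copy; the substance is the shape rule in GetOutputShape, which keeps the leading and trailing dims and collapses every dim in [start_axis, stop_axis] into one (propagating -1 for unknown extents, and resolving negative axes against the rank). Below is a minimal standalone sketch of that rule for reference only; the helper name FlattenShape and the example dims are illustrative and not part of the patch.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative re-implementation of the shape rule used by
// FlattenContiguousRangeMLUKernel::GetOutputShape (sketch, not the patch code).
std::vector<int64_t> FlattenShape(const std::vector<int64_t> &in_dims,
                                  int start_axis, int stop_axis) {
  const int rank = static_cast<int>(in_dims.size());
  if (start_axis < 0) start_axis += rank;  // negative axes count from the end
  if (stop_axis < 0) stop_axis += rank;
  assert(0 <= start_axis && start_axis <= stop_axis && stop_axis < rank);

  std::vector<int64_t> out;
  out.reserve(rank - (stop_axis - start_axis));
  for (int i = 0; i < start_axis; ++i) out.push_back(in_dims[i]);

  int64_t collapsed = 1;
  for (int i = start_axis; i <= stop_axis; ++i) {
    // -1 (unknown extent) is sticky: any unknown dim in the range makes the
    // collapsed dim unknown as well.
    collapsed =
        (in_dims[i] == -1 || collapsed == -1) ? -1 : collapsed * in_dims[i];
  }
  out.push_back(collapsed);

  for (int i = stop_axis + 1; i < rank; ++i) out.push_back(in_dims[i]);
  return out;
}

int main() {
  // A [2, 3, 4, 5] tensor flattened over axes 1..2 becomes [2, 12, 5].
  for (int64_t d : FlattenShape({2, 3, 4, 5}, 1, 2)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 12 5
  return 0;
}

Because only the metadata (dims) changes, the kernels can implement flatten as TensorCopy followed by Resize, which is why a single shape helper covers flatten, flatten2, and flatten_contiguous_range alike.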