diff --git a/paddle/fluid/operators/tile_op.cc b/paddle/fluid/operators/tile_op.cc
index dc12f8e8892a022c6f55f4fe3a6237a7a01fa290..e179149c5bb77bd642f744be48109a941c66febf 100644
--- a/paddle/fluid/operators/tile_op.cc
+++ b/paddle/fluid/operators/tile_op.cc
@@ -12,11 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/tile_op.h"
 
 #include <memory>
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
 namespace paddle {
 namespace operators {
 
@@ -26,66 +30,6 @@ class TileOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Tile");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Tile");
-    auto x_dims = ctx->GetInputDim("X");
-    auto repeat_times = ctx->Attrs().Get<std::vector<int>>("repeat_times");
-    if (repeat_times.size() == 0) {
-      repeat_times = std::vector<int>(x_dims.size(), -1);
-    }
-
-    PADDLE_ENFORCE_LE(
-        x_dims.size(), MAX_RANK_SUPPORTED,
-        platform::errors::InvalidArgument(
-            "The rank of the input 'x' for tile op "
-            "must not be greater than %d, but the value received is %d.",
-            MAX_RANK_SUPPORTED, x_dims.size()));
-    PADDLE_ENFORCE_LE(
-        repeat_times.size(), MAX_RANK_SUPPORTED,
-        platform::errors::InvalidArgument(
-            "The size of the shape of input 'repeat_times' for tile op "
-            "must not be greater than %d, but the value received is %d.",
-            MAX_RANK_SUPPORTED, repeat_times.size()));
-    PADDLE_ENFORCE_GE(
-        repeat_times.size(), 1,
-        platform::errors::InvalidArgument(
-            "The size of the shape of input 'repeat_times' for tile op "
-            "must be positive integers, but the value received is %d.",
-            repeat_times.size()));
-
-    auto out_rank =
-        std::max(static_cast<size_t>(x_dims.size()), repeat_times.size());
-    std::vector<int64_t> out_shape(out_rank);
-    auto x_dim_vec = phi::vectorize<int>(x_dims);
-    if (x_dim_vec.size() > repeat_times.size()) {
-      auto diff = x_dim_vec.size() - repeat_times.size();
-      repeat_times.insert(repeat_times.begin(), diff, -1);
-    } else {
-      auto diff = repeat_times.size() - x_dim_vec.size();
-      x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
-    }
-    for (size_t i = 0; i < repeat_times.size(); ++i) {
-      if (x_dim_vec[i] == -1 || repeat_times[i] == -1) {
-        out_shape[i] = -1;
-      } else {
-        PADDLE_ENFORCE_GT(
-            repeat_times[i], 0,
-            platform::errors::InvalidArgument(
-                "Every element of the input 'repeat_times' for tile op must be "
-                "greater than 0, but the value given is %d.",
-                repeat_times[i]));
-        out_shape[i] = x_dim_vec[i] * repeat_times[i];
-      }
-    }
-
-    ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
-    if (out_shape[0] == x_dims[0]) {
-      ctx->ShareLoD("X", "Out");
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -268,38 +212,15 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(TileGradNoNeedBufVarsInferer, "X");
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(tile, TileInferMetaFunctor,
+                            PD_INFER_META(phi::TileInferMeta));
+
 REGISTER_OPERATOR(tile, ops::TileOp, ops::TileOpMaker,
                   ops::TileGradOpMaker<paddle::framework::OpDesc>,
-                  ops::TileGradOpMaker<paddle::imperative::OpBase>);
+                  ops::TileGradOpMaker<paddle::imperative::OpBase>,
+                  TileInferMetaFunctor);
 REGISTER_OPERATOR(tile_grad, ops::TileGradOp,
                   ops::TileDoubleGradOpMaker<paddle::framework::OpDesc>,
                   ops::TileDoubleGradOpMaker<paddle::imperative::OpBase>,
                   ops::TileGradNoNeedBufVarsInferer);
 
-REGISTER_OP_CPU_KERNEL(
-    tile, ops::TileKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TileKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::TileKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::TileKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::TileKernel<paddle::platform::CPUDeviceContext, bool>);
-REGISTER_OP_CPU_KERNEL(
-    tile_grad, ops::TileGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TileGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::TileGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::TileGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-REGISTER_OP_CUDA_KERNEL(
-    tile, ops::TileKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::TileKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::TileKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::float16>,
-    ops::TileKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::TileKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::TileKernel<paddle::platform::CUDADeviceContext, bool>);
-REGISTER_OP_CUDA_KERNEL(
-    tile_grad, ops::TileGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::TileGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::TileGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::float16>,
-    ops::TileGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::TileGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
-#endif
diff --git a/paddle/fluid/operators/tile_op.h b/paddle/fluid/operators/tile_op.h
deleted file mode 100644
index 1698b5e3c6322e2cd9cbe7cf4839e2fc08627b32..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/tile_op.h
+++ /dev/null
@@ -1,306 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <algorithm>
-#include <vector>
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/eigen/eigen_function.h"
-
-#define MAX_RANK_SUPPORTED 6
-
-namespace paddle {
-namespace operators {
-inline std::vector<int> get_repeat_times(
-    const framework::ExecutionContext& ctx) {
-  if (ctx.HasInput("RepeatTimes")) {
-    auto* repeat_tensor = ctx.Input<framework::LoDTensor>("RepeatTimes");
-    auto* repeat_data = repeat_tensor->data<int>();
-    framework::Tensor cpu_repeat_tensor;
-    if (platform::is_gpu_place(repeat_tensor->place()) ||
-        platform::is_xpu_place(repeat_tensor->place()) ||
-        platform::is_npu_place(repeat_tensor->place())) {
-      paddle::framework::TensorCopySync(*repeat_tensor, platform::CPUPlace(),
-                                        &cpu_repeat_tensor);
-      repeat_data = cpu_repeat_tensor.data<int>();
-    }
-    auto vec_repeat_times =
-        std::vector<int>(repeat_data, repeat_data + repeat_tensor->numel());
-    return vec_repeat_times;
-  }
-
-  auto list_repeat_times_tensor =
-      ctx.MultiInput<framework::Tensor>("repeat_times_tensor");
-  if (list_repeat_times_tensor.size() > 0) {
-    // get tensor from
-    std::vector<int> vec_repeat_times;
-    for (size_t i = 0; i < list_repeat_times_tensor.size(); ++i) {
-      auto tensor = list_repeat_times_tensor[i];
-      if (platform::is_gpu_place(tensor->place()) ||
-          platform::is_xpu_place(tensor->place()) ||
-          platform::is_npu_place(tensor->place())) {
-        framework::Tensor temp;
-        paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
-        vec_repeat_times.push_back(*temp.data<int>());
-      } else {
-        vec_repeat_times.push_back(*tensor->data<int>());
-      }
-    }
-    return vec_repeat_times;
-  } else {
-    return ctx.Attr<std::vector<int>>("repeat_times");
-  }
-}
-
-using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
-using framework::To32BitIndex;
-
-template <typename DeviceContext, typename T>
-class TileKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto rank = context.Input<Tensor>("X")->dims().size();
-    PADDLE_ENFORCE_GE(
-        rank, 1, platform::errors::InvalidArgument(
-                     "The rank of the input 'x' for tile op must be a positive "
-                     "integer, but the value received is %d.",
-                     rank));
-    PADDLE_ENFORCE_LE(
-        rank, MAX_RANK_SUPPORTED,
-        platform::errors::InvalidArgument(
-            "The rank of the input 'x' for tile op "
-            "must be less than or equal to %d, but the value received is %d.",
-            MAX_RANK_SUPPORTED, rank));
-    auto repeat_times = get_repeat_times(context);
-    int repeat_times_size = repeat_times.size();
-    PADDLE_ENFORCE_GE(
-        repeat_times_size, 1,
-        platform::errors::InvalidArgument(
-            "The number of elements of the input 'repeat_times' for tile "
-            "op must be positive, but the value received is %d.",
-            repeat_times_size));
-    PADDLE_ENFORCE_LE(
-        repeat_times_size, MAX_RANK_SUPPORTED,
-        platform::errors::InvalidArgument(
-            "The number of elements of the input 'repeat_times' for tile op "
-            "must be less than or equal to %d, but the value received is %d.",
-            MAX_RANK_SUPPORTED, repeat_times_size));
-    rank = std::max(rank, repeat_times_size);
-    switch (rank) {
-      case 1:
-        Tile<1>(context);
-        break;
-      case 2:
-        Tile<2>(context);
-        break;
-      case 3:
-        Tile<3>(context);
-        break;
-      case 4:
-        Tile<4>(context);
-        break;
-      case 5:
-        Tile<5>(context);
-        break;
-      case 6:
-        Tile<6>(context);
-        break;
-    }
-  }
-
- protected:
-  template <int Rank>
-  void Tile(const framework::ExecutionContext& context) const {
-    auto* in0 = context.Input<Tensor>("X");
-
-    auto in_dims = in0->dims();
-    auto repeat_times = get_repeat_times(context);
-    for (size_t i = 0; i < repeat_times.size(); ++i) {
-      PADDLE_ENFORCE_GT(
-          repeat_times[i], 0,
-          platform::errors::InvalidArgument(
-              "All elements of the input 'repeat_times' for tile op must "
-              "be positive integers, but the value received is %d.",
-              repeat_times[i]));
-    }
-    auto vec_in_dims = phi::vectorize<int>(in_dims);
-    if (repeat_times.size() < vec_in_dims.size()) {
-      int diff = vec_in_dims.size() - repeat_times.size();
-      repeat_times.insert(repeat_times.begin(), diff, 1);
-    } else {
-      int diff = repeat_times.size() - vec_in_dims.size();
-      vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
-    }
-    PADDLE_ENFORCE_EQ(
-        repeat_times.size(), vec_in_dims.size(),
-        platform::errors::InvalidArgument(
-            "The rank (%d) of the input 'x' and the rank (%d) of the input "
-            "'repeat_times' for tile op must match after promotion.",
-            vec_in_dims.size(), repeat_times.size()));
-    auto* out0 = context.Output<Tensor>("Out");
-    Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
-    for (size_t i = 0; i < repeat_times.size(); ++i) {
-      bcast_dims[i] = repeat_times[i];
-    }
-
-    framework::DDim new_in_dims = phi::make_ddim(vec_in_dims);
-    framework::DDim out_dims(new_in_dims);
-    for (size_t i = 0; i < repeat_times.size(); ++i) {
-      out_dims[i] *= repeat_times[i];
-    }
-
-    out0->Resize(out_dims);
-    auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
-    out0->mutable_data<T>(context.GetPlace());
-    auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    // use 32-bit index to speed up
-    bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
-    if (use_32bit_index) {
-      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
-          place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
-    } else {
-      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
-                                                                   bcast_dims);
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class TileGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<Tensor>("X");
-    auto repeat_times = get_repeat_times(context);
-    auto x_dims = x->dims();
-    auto vec_in_dims = phi::vectorize<int>(x_dims);
-    if (repeat_times.size() < vec_in_dims.size()) {
-      int diff = vec_in_dims.size() - repeat_times.size();
-      repeat_times.insert(repeat_times.begin(), diff, 1);
-    } else {
-      int diff = repeat_times.size() - vec_in_dims.size();
-      vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
-    }
-    // 1. reshape_dims_vec is the broadcast parameter.
-    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
-    //    each dimension expanded, the gradients should be summed to original
-    //    size.
-    std::vector<int> reshape_dims_vec;
-    std::vector<int> reduce_dims_vec;
-    for (size_t i = 0; i < repeat_times.size(); ++i) {
-      reduce_dims_vec.push_back(reshape_dims_vec.size());
-      reshape_dims_vec.push_back(repeat_times[i]);
-      reshape_dims_vec.push_back(vec_in_dims[i]);
-    }
-
-    int dims = reduce_dims_vec.size();
-
-    bool just_copy = true;
-    for (size_t i = 0; i < repeat_times.size(); i++) {
-      if (repeat_times[i] != 1) {
-        just_copy = false;
-        break;
-      }
-    }
-    // no need reduce, just copy
-    if (just_copy) {
-      auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-      auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
-      dx->mutable_data<T>(context.GetPlace());
-      framework::TensorCopy(*dout, context.GetPlace(),
-                            context.device_context(), dx);
-      // TensorCopy may change the dims of dx
-      dx->Resize(x_dims);
-    } else {
-      PADDLE_ENFORCE_GE(dims, 1,
-                        platform::errors::InvalidArgument(
-                            "Th rank of the input 'Out@GRAD' for tile_grad op "
-                            " must be greater than or equal to 1, but "
-                            "the value received is %d.",
-                            dims));
-      PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
-                        platform::errors::InvalidArgument(
-                            "The rank of the input 'Out@GRAD' for tile_grad op "
-                            "must be less than or equal "
-                            "to %d, but the value received is %d.",
-                            MAX_RANK_SUPPORTED, dims));
-      switch (dims) {
-        case 1:
-          TileBackward<1>(context, reshape_dims_vec, reduce_dims_vec);
-          break;
-        case 2:
-          TileBackward<2>(context, reshape_dims_vec, reduce_dims_vec);
-          break;
-        case 3:
-          TileBackward<3>(context, reshape_dims_vec, reduce_dims_vec);
-          break;
-        case 4:
-          TileBackward<4>(context, reshape_dims_vec, reduce_dims_vec);
-          break;
-        case 5:
-          TileBackward<5>(context, reshape_dims_vec, reduce_dims_vec);
-          break;
-        case 6:
-          TileBackward<6>(context, reshape_dims_vec, reduce_dims_vec);
-          break;
-        default:
-          PADDLE_THROW(platform::errors::InvalidArgument(
-              "Only support tensor with rank being between 1 and 6. But "
-              "received tensor's rank = %d.",
-              dims));
-      }
-    }
-  }
-
- protected:
-  template <int Dims>
-  void TileBackward(const framework::ExecutionContext& context,
-                    const std::vector<int>& reshape_dims_vec,
-                    const std::vector<int>& reduce_dims_vec) const {
-    size_t reshape_size = reshape_dims_vec.size();
-    size_t reduce_size = reduce_dims_vec.size();
-    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
-    out0->mutable_data<T>(context.GetPlace());
-    auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
-    for (size_t i = 0; i < reshape_size; ++i) {
-      reshape_dims[i] = reshape_dims_vec[i];
-    }
-    Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
-    for (size_t i = 0; i < reduce_size; ++i) {
-      reduce_dims[i] = reduce_dims_vec[i];
-    }
-
-    auto out_grad = EigenVector<T>::Flatten(*in0);
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
-        place, x_grad, out_grad, reduce_dims, reshape_dims);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/tile_op_functor.h b/paddle/fluid/operators/tile_op_functor.h
new file mode 100644
index 0000000000000000000000000000000000000000..95bfb9f4e1a9d374c66997567f5d80df8b5d8701
--- /dev/null
+++ b/paddle/fluid/operators/tile_op_functor.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <vector>
+
+#include "paddle/fluid/framework/operator.h"
+
+#define MAX_RANK_SUPPORTED 6
+
+namespace paddle {
+namespace operators {
+
+inline std::vector<int> get_repeat_times(
+    const framework::ExecutionContext& ctx) {
+  if (ctx.HasInput("RepeatTimes")) {
+    auto* repeat_tensor = ctx.Input<framework::LoDTensor>("RepeatTimes");
+    auto* repeat_data = repeat_tensor->data<int>();
+    framework::Tensor cpu_repeat_tensor;
+    if (platform::is_gpu_place(repeat_tensor->place()) ||
+        platform::is_xpu_place(repeat_tensor->place()) ||
+        platform::is_npu_place(repeat_tensor->place())) {
+      paddle::framework::TensorCopySync(*repeat_tensor, platform::CPUPlace(),
+                                        &cpu_repeat_tensor);
+      repeat_data = cpu_repeat_tensor.data<int>();
+    }
+    auto vec_repeat_times =
+        std::vector<int>(repeat_data, repeat_data + repeat_tensor->numel());
+    return vec_repeat_times;
+  }
+
+  auto list_repeat_times_tensor =
+      ctx.MultiInput<framework::Tensor>("repeat_times_tensor");
+  if (list_repeat_times_tensor.size() > 0) {
+    // get tensor from
+    std::vector<int> vec_repeat_times;
+    for (size_t i = 0; i < list_repeat_times_tensor.size(); ++i) {
+      auto tensor = list_repeat_times_tensor[i];
+      if (platform::is_gpu_place(tensor->place()) ||
+          platform::is_xpu_place(tensor->place()) ||
+          platform::is_npu_place(tensor->place())) {
+        framework::Tensor temp;
+        paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
+        vec_repeat_times.push_back(*temp.data<int>());
+      } else {
+        vec_repeat_times.push_back(*tensor->data<int>());
+      }
+    }
+    return vec_repeat_times;
+  } else {
+    return ctx.Attr<std::vector<int>>("repeat_times");
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/tile_op_npu.cc b/paddle/fluid/operators/tile_op_npu.cc
index 9e306c7be537bc7403812f4907541e1a9671c12a..cea6b458aec782923722cb37fe41c1c4d59292e5 100644
--- a/paddle/fluid/operators/tile_op_npu.cc
+++ b/paddle/fluid/operators/tile_op_npu.cc
@@ -11,7 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/tile_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/tile_op_functor.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/tile_op_xpu.cc b/paddle/fluid/operators/tile_op_xpu.cc
index 6b60b167a2465fcb03d8ec088cfa288f9fb14af1..598377587d6f73e0c21abbc4d3819d16eacb1f23 100644
--- a/paddle/fluid/operators/tile_op_xpu.cc
+++ b/paddle/fluid/operators/tile_op_xpu.cc
@@ -11,11 +11,14 @@ limitations under the License. */
 
 #ifdef PADDLE_WITH_XPU
 
-#include "paddle/fluid/operators/tile_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/tile_op_functor.h"
 
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename T>
 class TileXPUKernel : public framework::OpKernel<T> {
  public:
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index c26af34f771d5c1e2476489fc40ad0844933bad5..d6d4efad9fae26e6cbdb914752d8c24ab23d948c 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -395,6 +395,74 @@ void MultinomialInferMeta(const MetaTensor& x,
   out->set_dtype(DataType::INT64);
 }
 
+void TileInferMeta(const MetaTensor& x,
+                   const ScalarArray& repeat_times,
+                   MetaTensor* out,
+                   MetaConfig config) {
+#define MAX_RANK_SUPPORTED 6
+
+  auto repeat_times_data = repeat_times.GetData();
+  auto x_dims = x.dims();
+  if (repeat_times_data.size() == 0) {
+    repeat_times_data = std::vector<int64_t>(x_dims.size(), -1);
+  }
+
+  PADDLE_ENFORCE_LE(
+      x_dims.size(),
+      MAX_RANK_SUPPORTED,
+      errors::InvalidArgument(
+          "The rank of the input 'x' for tile op "
+          "must not be greater than %d, but the value received is %d.",
+          MAX_RANK_SUPPORTED,
+          x_dims.size()));
+  PADDLE_ENFORCE_LE(
+      repeat_times_data.size(),
+      MAX_RANK_SUPPORTED,
+      errors::InvalidArgument(
+          "The size of the shape of input 'repeat_times' for tile op "
+          "must not be greater than %d, but the value received is %d.",
+          MAX_RANK_SUPPORTED,
+          repeat_times_data.size()));
+  PADDLE_ENFORCE_GE(
+      repeat_times_data.size(),
+      1,
+      errors::InvalidArgument(
+          "The size of the shape of input 'repeat_times' for tile op "
+          "must be positive integers, but the value received is %d.",
+          repeat_times_data.size()));
+
+  auto out_rank =
+      std::max(static_cast<size_t>(x_dims.size()), repeat_times_data.size());
+  std::vector<int64_t> out_shape(out_rank);
+  auto x_dim_vec = phi::vectorize<int>(x_dims);
+  if (x_dim_vec.size() > repeat_times_data.size()) {
+    auto diff = x_dim_vec.size() - repeat_times_data.size();
+    repeat_times_data.insert(repeat_times_data.begin(), diff, -1);
+  } else {
+    auto diff = repeat_times_data.size() - x_dim_vec.size();
+    x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
+  }
+  for (size_t i = 0; i < repeat_times_data.size(); ++i) {
+    if (x_dim_vec[i] == -1 || repeat_times_data[i] == -1) {
+      out_shape[i] = -1;
+    } else {
+      PADDLE_ENFORCE_GT(
+          repeat_times_data[i],
+          0,
+          errors::InvalidArgument(
+              "Every element of the input 'repeat_times' for tile op must be "
+              "greater than 0, but the value given is %d.",
+              repeat_times_data[i]));
+      out_shape[i] = x_dim_vec[i] * repeat_times_data[i];
+    }
+  }
+
+  out->set_dims(phi::make_ddim(out_shape));
+  if (out_shape[0] == x_dims[0]) {
+    out->share_lod(x);
+  }
+}
+
 void ReshapeInferMeta(const MetaTensor& x,
                       const ScalarArray& shape,
                       MetaTensor* out,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 59ee613b8b0eb0f7650d00b4ab04a68e1f3b7027..e8be73e943e09c9794376945cc904fe6f2a3d324 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -100,6 +100,11 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
                                 MetaTensor* out,
                                 MetaConfig config = MetaConfig());
 
+void TileInferMeta(const MetaTensor& x,
+                   const ScalarArray& repeat_times,
+                   MetaTensor* out,
+                   MetaConfig config = MetaConfig());
+
 void SumRawInferMeta(const MetaTensor& x,
                      const std::vector<int64_t>& axis,
                      bool keep_dim,
diff --git a/paddle/phi/kernels/cpu/tile_grad_kernel.cc b/paddle/phi/kernels/cpu/tile_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..636ade93742da11cf5e875adb27f72930c9a6686
--- /dev/null
+++ b/paddle/phi/kernels/cpu/tile_grad_kernel.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/tile_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/tile_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(tile_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::TileGradKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/tile_kernel.cc b/paddle/phi/kernels/cpu/tile_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b590ed475aa2611d75a0426294cf5e8d30ab6fa
--- /dev/null
+++ b/paddle/phi/kernels/cpu/tile_kernel.cc
@@ -0,0 +1,23 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/tile_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/tile_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    tile, CPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) {
+}
diff --git a/paddle/phi/kernels/gpu/tile_grad_kernel.cu b/paddle/phi/kernels/gpu/tile_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c092609e623d3f4f3dc4b3d77b1c973e6ddfbcf3
--- /dev/null
+++ b/paddle/phi/kernels/gpu/tile_grad_kernel.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/tile_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/tile_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(tile_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TileGradKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/tile_kernel.cu b/paddle/phi/kernels/gpu/tile_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0c3c29e82c42aefae33a4a9be9e9a7d9ec0c1e99
--- /dev/null
+++ b/paddle/phi/kernels/gpu/tile_kernel.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/tile_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/tile_kernel_impl.h"
+
+PD_REGISTER_KERNEL(tile,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TileKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/tile_grad_kernel_impl.h b/paddle/phi/kernels/impl/tile_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..a2c2720244fe8a714e0e866dd6178a1a238728b7
--- /dev/null
+++ b/paddle/phi/kernels/impl/tile_grad_kernel_impl.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <vector>
+
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/tile_grad_kernel.h"
+
+namespace phi {
+
+template <typename Context, typename T, int Dims>
+void TileBackward(const Context& dev_ctx,
+                  const DenseTensor& out_grad,
+                  const std::vector<int>& reshape_dims_vec,
+                  const std::vector<int>& reduce_dims_vec,
+                  DenseTensor* x_grad) {
+  size_t reshape_size = reshape_dims_vec.size();
+  size_t reduce_size = reduce_dims_vec.size();
+  dev_ctx.template Alloc<T>(x_grad);
+
+  auto eigen_x_grad = EigenVector<T>::Flatten(*x_grad);
+  Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
+  for (size_t i = 0; i < reshape_size; ++i) {
+    reshape_dims[i] = reshape_dims_vec[i];
+  }
+  Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
+  for (size_t i = 0; i < reduce_size; ++i) {
+    reduce_dims[i] = reduce_dims_vec[i];
+  }
+
+  auto eigen_out_grad = EigenVector<T>::Flatten(out_grad);
+  auto& place = *dev_ctx.eigen_device();
+  funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
+      place, eigen_x_grad, eigen_out_grad, reduce_dims, reshape_dims);
+}
+
+template <typename T, typename Context>
+void TileGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const ScalarArray& repeat_times,
+                    DenseTensor* x_grad) {
+  auto x_dims = x.dims();
+  auto vec_x_dims = phi::vectorize<int>(x_dims);
+  auto repeat_times_data = repeat_times.GetData();
+  if (repeat_times_data.size() < vec_x_dims.size()) {
+    int diff = vec_x_dims.size() - repeat_times_data.size();
+    repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
+  } else {
+    int diff = repeat_times_data.size() - vec_x_dims.size();
+    vec_x_dims.insert(vec_x_dims.begin(), diff, 1);
+  }
+  // 1. reshape_dims_vec is the broadcast parameter.
+  // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
+  //    each dimension expanded, the gradients should be summed to original
+  //    size.
+  std::vector<int> reshape_dims_vec;
+  std::vector<int> reduce_dims_vec;
+  for (size_t i = 0; i < repeat_times_data.size(); ++i) {
+    reduce_dims_vec.push_back(reshape_dims_vec.size());
+    reshape_dims_vec.push_back(repeat_times_data[i]);
+    reshape_dims_vec.push_back(vec_x_dims[i]);
+  }
+
+  int dims = reduce_dims_vec.size();
+
+  bool just_copy = true;
+  for (size_t i = 0; i < repeat_times_data.size(); i++) {
+    if (repeat_times_data[i] != 1) {
+      just_copy = false;
+      break;
+    }
+  }
+  // no need reduce, just copy
+  if (just_copy) {
+    dev_ctx.template Alloc<T>(x_grad);
+
+    paddle::framework::TensorCopy(
+        out_grad, dev_ctx.GetPlace(), dev_ctx, x_grad);
+    // TensorCopy may change the dims of dx
+    x_grad->Resize(x_dims);
+  } else {
+    PADDLE_ENFORCE_GE(dims,
+                      1,
+                      errors::InvalidArgument(
+                          "The rank of the input 'Out@GRAD' for tile_grad op "
+                          "must be greater than or equal to 1, but "
+                          "the value received is %d.",
+                          dims));
+    PADDLE_ENFORCE_LE(dims,
+                      MAX_RANK_SUPPORTED,
+                      errors::InvalidArgument(
+                          "The rank of the input 'Out@GRAD' for tile_grad op "
+                          "must be less than or equal "
+                          "to %d, but the value received is %d.",
+                          MAX_RANK_SUPPORTED,
+                          dims));
+    switch (dims) {
+      case 1:
+        TileBackward<Context, T, 1>(
+            dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
+        break;
+      case 2:
+        TileBackward<Context, T, 2>(
+            dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
+        break;
+      case 3:
+        TileBackward<Context, T, 3>(
+            dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
+        break;
+      case 4:
+        TileBackward<Context, T, 4>(
+            dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
+        break;
+      case 5:
+        TileBackward<Context, T, 5>(
+            dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
+        break;
+      case 6:
+        TileBackward<Context, T, 6>(
+            dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
+        break;
+      default:
+        PADDLE_THROW(errors::InvalidArgument(
+            "Only support tensor with rank being between 1 and 6. But "
+            "received tensor's rank = %d.",
+            dims));
+    }
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/tile_kernel_impl.h b/paddle/phi/kernels/impl/tile_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..bafbbde4e680daa04a62ba9ab6b03778b4477107
--- /dev/null
+++ b/paddle/phi/kernels/impl/tile_kernel_impl.h
@@ -0,0 +1,117 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <vector>
+
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/tile_kernel.h"
+
+namespace phi {
+
+template <typename Context, typename T, int Rank>
+void Tile(const Context& dev_ctx,
+          const DenseTensor& x,
+          std::vector<int64_t> repeat_times,
+          DenseTensor* out) {
+  auto x_dims = x.dims();
+  for (size_t i = 0; i < repeat_times.size(); ++i) {
+    PADDLE_ENFORCE_GT(
+        repeat_times[i],
+        0,
+        errors::InvalidArgument(
+            "All elements of the input 'repeat_times' for tile op must "
+            "be positive integers, but the value received is %d.",
+            repeat_times[i]));
+  }
+  auto vec_x_dims = phi::vectorize<int>(x_dims);
+  if (repeat_times.size() < vec_x_dims.size()) {
+    int diff = vec_x_dims.size() - repeat_times.size();
+    repeat_times.insert(repeat_times.begin(), diff, 1);
+  } else {
+    int diff = repeat_times.size() - vec_x_dims.size();
+    vec_x_dims.insert(vec_x_dims.begin(), diff, 1);
+  }
+  PADDLE_ENFORCE_EQ(
+      repeat_times.size(),
+      vec_x_dims.size(),
+      errors::InvalidArgument(
+          "The rank (%d) of the input 'x' and the rank (%d) of the input "
+          "'repeat_times' for tile op must match after promotion.",
+          vec_x_dims.size(),
+          repeat_times.size()));
+
+  Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
+  for (size_t i = 0; i < repeat_times.size(); ++i) {
+    bcast_dims[i] = repeat_times[i];
+  }
+
+  DDim new_x_dims = make_ddim(vec_x_dims);
+  DDim out_dims(new_x_dims);
+  for (size_t i = 0; i < repeat_times.size(); ++i) {
+    out_dims[i] *= repeat_times[i];
+  }
+
+  out->Resize(out_dims);
+  auto eigen_x = EigenTensor<T, Rank>::From(x, new_x_dims);
+  dev_ctx.template Alloc<T>(out);
+
+  auto eigen_out = EigenTensor<T, Rank>::From(*out, out_dims);
+  auto& place = *dev_ctx.eigen_device();
+  // use 32-bit index to speed up
+  bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
+  if (use_32bit_index) {
+    funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
+        place, To32BitIndex(eigen_out), To32BitIndex(eigen_x), bcast_dims);
+  } else {
+    funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
+        place, eigen_out, eigen_x, bcast_dims);
+  }
+}
+
+template <typename T, typename Context>
+void TileKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const ScalarArray& repeat_times,
+                DenseTensor* out) {
+  auto rank = x.dims().size();
+  auto& repeat_times_data = repeat_times.GetData();
+  int repeat_times_size = repeat_times_data.size();
+  rank = std::max(rank, repeat_times_size);
+
+  switch (rank) {
+    case 1:
+      Tile<Context, T, 1>(dev_ctx, x, repeat_times_data, out);
+      break;
+    case 2:
+      Tile<Context, T, 2>(dev_ctx, x, repeat_times_data, out);
+      break;
+    case 3:
+      Tile<Context, T, 3>(dev_ctx, x, repeat_times_data, out);
+      break;
+    case 4:
+      Tile<Context, T, 4>(dev_ctx, x, repeat_times_data, out);
+      break;
+    case 5:
+      Tile<Context, T, 5>(dev_ctx, x, repeat_times_data, out);
+      break;
+    case 6:
+      Tile<Context, T, 6>(dev_ctx, x, repeat_times_data, out);
+      break;
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/tile_grad_kernel.h b/paddle/phi/kernels/tile_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..830276c28e05395c36854426e560fbe6f2642589
--- /dev/null
+++ b/paddle/phi/kernels/tile_grad_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar_array.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+#define MAX_RANK_SUPPORTED 6
+
+namespace phi {
+
+template <typename T, typename Context>
+void TileGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const ScalarArray& repeat_times,
+                    DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/tile_kernel.h b/paddle/phi/kernels/tile_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..924d0149fe34598753f0d02cfa3b087e2e7493f2
--- /dev/null
+++ b/paddle/phi/kernels/tile_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar_array.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+#define MAX_RANK_SUPPORTED 6
+
+namespace phi {
+
+template <typename T, typename Context>
+void TileKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const ScalarArray& repeat_times,
+                DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/tile_sig.cc b/paddle/phi/ops/compat/tile_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..49a6d02225d931f1dc2d3324cb13c2c620f5dfe6
--- /dev/null
+++ b/paddle/phi/ops/compat/tile_sig.cc
@@ -0,0 +1,51 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature TileOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("RepeatTimes")) {
+    return KernelSignature("tile", {"X"}, {"RepeatTimes"}, {"Out"});
+  } else if (ctx.InputSize("repeat_times_tensor") > 0) {
+    return KernelSignature("tile", {"X"}, {"repeat_times_tensor"}, {"Out"});
+  } else {
+    return KernelSignature("tile", {"X"}, {"repeat_times"}, {"Out"});
+  }
+}
+
+KernelSignature TileGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("RepeatTimes")) {
+    return KernelSignature("tile_grad",
+                           {"X", GradVarName("Out")},
+                           {"RepeatTimes"},
+                           {GradVarName("X")});
+  } else if (ctx.InputSize("repeat_times_tensor") > 0) {
+    return KernelSignature("tile_grad",
+                           {"X", GradVarName("Out")},
+                           {"repeat_times_tensor"},
+                           {GradVarName("X")});
+  } else {
+    return KernelSignature("tile_grad",
+                           {"X", GradVarName("Out")},
+                           {"repeat_times"},
+                           {GradVarName("X")});
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(tile, phi::TileOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(tile_grad, phi::TileGradOpArgumentMapping);
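
A minimal standalone sketch of the shape rule that phi::TileInferMeta carries over
from the deleted TileOp::InferShape, for reference while reviewing. InferTileShape
is a hypothetical helper name, not a Paddle API, and the -1 padding behavior below
is assumed to match the hunk above:

#include <cstdint>
#include <iostream>
#include <vector>

// Infer the static output shape of tile(x, repeat_times).
// Both lists are right-aligned; the shorter one is left-padded with -1
// (unknown), and -1 in either operand propagates to the output.
std::vector<int64_t> InferTileShape(std::vector<int64_t> x_dims,
                                    std::vector<int64_t> repeat_times) {
  if (x_dims.size() > repeat_times.size()) {
    repeat_times.insert(repeat_times.begin(),
                        x_dims.size() - repeat_times.size(), -1);
  } else {
    x_dims.insert(x_dims.begin(), repeat_times.size() - x_dims.size(), -1);
  }
  std::vector<int64_t> out(x_dims.size());
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] = (x_dims[i] == -1 || repeat_times[i] == -1)
                 ? -1
                 : x_dims[i] * repeat_times[i];
  }
  return out;
}

int main() {
  // Same rank: x [2, 3] tiled by [2, 2] -> [4, 6].
  for (int64_t d : InferTileShape({2, 3}, {2, 2})) std::cout << d << ' ';
  std::cout << '\n';
  // Rank promotion: x [2, 3] with repeat_times [4, 1, 2]. InferMeta pads x
  // with -1, so it reports [-1, 2, 6]; the kernel pads with 1 at run time
  // and actually produces [4, 2, 6].
  for (int64_t d : InferTileShape({2, 3}, {4, 1, 2})) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

This also shows why TileInferMeta only calls out->share_lod(x) when
out_shape[0] == x_dims[0]: the LoD is only still valid if the leading
dimension is untouched by the tiling.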
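Likewise for the gradient path: per the reshape_dims_vec/reduce_dims_vec comment
in tile_grad_kernel_impl.h above, TileGradKernel views out_grad as interleaved
(repeat_i, dim_i) pairs and sums over the repeat axes. A small sketch of just
that bookkeeping; MakeTileGradDims is a hypothetical name:

#include <cstdio>
#include <vector>

// For x of rank n tiled by repeat_times (already promoted to rank n),
// out_grad is viewed with shape [r0, d0, r1, d1, ..., r_{n-1}, d_{n-1}]
// (reshape_dims) and reduced over the even axes 0, 2, ..., 2n-2
// (reduce_dims), which sums each repeat back onto the original extent.
void MakeTileGradDims(const std::vector<int>& x_dims,
                      const std::vector<int>& repeat_times,
                      std::vector<int>* reshape_dims,
                      std::vector<int>* reduce_dims) {
  for (size_t i = 0; i < repeat_times.size(); ++i) {
    reduce_dims->push_back(static_cast<int>(reshape_dims->size()));
    reshape_dims->push_back(repeat_times[i]);
    reshape_dims->push_back(x_dims[i]);
  }
}

int main() {
  // x: [2, 3], repeat_times: [4, 5] -> out_grad [8, 15] is viewed as
  // [4, 2, 5, 3] and reduced over axes {0, 2}, leaving the original [2, 3].
  std::vector<int> reshape, reduce;
  MakeTileGradDims({2, 3}, {4, 5}, &reshape, &reduce);
  for (int d : reshape) std::printf("%d ", d);
  std::printf("| ");
  for (int d : reduce) std::printf("%d ", d);
  std::printf("\n");  // prints: 4 2 5 3 | 0 2
  return 0;
}

When every repeat factor is 1, both kernels skip this machinery entirely and
take the just_copy path, which is why TileGradKernel re-applies x_dims after
TensorCopy.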