diff --git a/paddle/fluid/operators/crop_tensor_op.cc b/paddle/fluid/operators/crop_tensor_op.cc index f72175d4d53387da1d303bba0ef9dd253df28613..52106c74314a461fc65dbafe522110a8b2f95997 100644 --- a/paddle/fluid/operators/crop_tensor_op.cc +++ b/paddle/fluid/operators/crop_tensor_op.cc @@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/crop_tensor_op.h" +#include "paddle/fluid/framework/op_registry.h" -#include -#include -#include +// TODO(freeliuzc): Delete old infershape +// New infershape has already in unary.h and backward.h namespace paddle { namespace operators { @@ -297,8 +296,8 @@ class CropTensorGradOpMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr op) const override { op->SetType("crop_tensor_grad"); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); if (this->HasInput("OffsetsTensor")) { op->SetInput("OffsetsTensor", this->Input("OffsetsTensor")); } @@ -314,32 +313,10 @@ class CropTensorGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; + REGISTER_OPERATOR(crop_tensor, ops::CropTensorOp, ops::CropTensorOpMaker, ops::CropTensorGradOpMaker, ops::CropTensorGradOpMaker); REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad); -REGISTER_OP_CPU_KERNEL(crop_tensor, - ops::CropTensorKernel, - ops::CropTensorKernel, - ops::CropTensorKernel, - ops::CropTensorKernel); -REGISTER_OP_CPU_KERNEL(crop_tensor_grad, - ops::CropTensorGradKernel, - ops::CropTensorGradKernel, - ops::CropTensorGradKernel, - ops::CropTensorGradKernel); - -REGISTER_OP_CUDA_KERNEL( - crop_tensor, - ops::CropTensorKernel, - ops::CropTensorKernel, - ops::CropTensorKernel, - ops::CropTensorKernel); -REGISTER_OP_CUDA_KERNEL( - crop_tensor_grad, - ops::CropTensorGradKernel, - ops::CropTensorGradKernel, - ops::CropTensorGradKernel, - ops::CropTensorGradKernel); diff --git a/paddle/fluid/operators/crop_tensor_op.h b/paddle/fluid/operators/crop_tensor_op.h deleted file mode 100644 index afaae4d0ac3cde7ee688e8dc234a5758723e12f7..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/crop_tensor_op.h +++ /dev/null @@ -1,350 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/strided_memcpy.h" - -namespace paddle { -namespace operators { // Internal - -template -using EigenTensor = framework::EigenTensor; -using framework::Tensor; - -inline std::vector get_new_data( - const std::vector& list_new_tensor) { - // get tensor from - std::vector vec_new_data; - for (size_t i = 0; i < list_new_tensor.size(); ++i) { - auto tensor = list_new_tensor[i]; - PADDLE_ENFORCE_EQ( - tensor->dims(), - phi::make_ddim({1}), - platform::errors::InvalidArgument( - "The tensor's shape in list of Op(crop_tensor) should be [1], " - "but the value received is %d.", - tensor->dims())); - if (platform::is_gpu_place(tensor->place())) { - framework::Tensor temp; - paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); - - vec_new_data.push_back(static_cast(*temp.data())); - } else { - vec_new_data.push_back(static_cast(*tensor->data())); - } - } - - return vec_new_data; -} - -static framework::DDim ValidateShape(const std::vector shape, - const std::vector offsets, - const framework::DDim& in_dims) { - auto in_dim_size = in_dims.size(); - auto shape_size = shape.size(); - PADDLE_ENFORCE_EQ( - in_dim_size, - shape_size, - platform::errors::InvalidArgument( - "The number of elements (%d) for shape of Op(crop_tensor) should be " - "equal to the number of dimensions (%d) of the input tensor.", - shape_size, - in_dim_size)); - std::vector output_shape(shape.size(), 0); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] <= 0 && in_dims[i] > 0) { - PADDLE_ENFORCE_NE(shape[i], - 0, - platform::errors::InvalidArgument( - "The value (%d) of the %uth element for shape of " - "Op(crop_tensor) should not be zero.", - shape[i], - i)); - PADDLE_ENFORCE_EQ(shape[i], - -1, - platform::errors::InvalidArgument( - "When the value (%d) of the %uth " - "element for shape of Op(crop_tensor)" - " is negative, only -1 is supported.", - shape[i], - i)); - output_shape[i] = in_dims[i] - offsets[i]; - } else { - output_shape[i] = static_cast(shape[i]); - } - } - - return phi::make_ddim(output_shape); -} - -static std::vector GetShape(const framework::ExecutionContext& ctx) { - std::vector res; - int rank = ctx.Input("X")->dims().size(); - auto list_new_shape_tensor = ctx.MultiInput("ShapeTensor"); - if (list_new_shape_tensor.size() > 0) { - // have offsets tensor list - PADDLE_ENFORCE_EQ( - list_new_shape_tensor.size(), - rank, - platform::errors::InvalidArgument( - "The number of tensors (%d) for the input ShapeTensor of " - "Op(crop_tensor) must be equal to the number of " - "dimensions (%d) of the input.", - list_new_shape_tensor.size(), - rank)); - res = get_new_data(list_new_shape_tensor); - - return res; - } - - auto* shape_tensor = ctx.HasInput("Shape") - ? ctx.Input("Shape") - : nullptr; - if (shape_tensor) { - auto* shape_data = shape_tensor->data(); - framework::Tensor cpu_shape_tensor; - if (platform::is_gpu_place(shape_tensor->place())) { - paddle::framework::TensorCopySync( - *shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); - shape_data = cpu_shape_tensor.data(); - } - res = std::vector(shape_data, shape_data + shape_tensor->numel()); - } - - return res; -} - -static std::vector GetOffsets(const framework::ExecutionContext& ctx) { - std::vector res; - int rank = ctx.Input("X")->dims().size(); - auto list_new_offsets_tensor = - ctx.MultiInput("OffsetsTensor"); - if (list_new_offsets_tensor.size() > 0) { - // have offsets tensor list - res = get_new_data(list_new_offsets_tensor); - - return res; - } - - if (ctx.HasInput("Offsets")) { - const auto* offsets_tensor = ctx.Input("Offsets"); - PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), - 1, - platform::errors::InvalidArgument( - "The number of dimensions of input 'Offsets' must " - "be 1, but the value received is: %d.", - offsets_tensor->dims().size())); - PADDLE_ENFORCE_EQ(rank, - offsets_tensor->dims()[0], - platform::errors::InvalidArgument( - "The number of elements (%d) for " - "input 'Offsets' must be equal to " - "the number of dimensions (%d) of the input tensor.", - offsets_tensor->dims()[0], - rank)); - - const int* offsets_data; - framework::Tensor cpu_tmp_tensor; - if (platform::is_cpu_place(offsets_tensor->place())) { - offsets_data = offsets_tensor->data(); - } else { - framework::TensorCopySync( - *offsets_tensor, platform::CPUPlace(), &cpu_tmp_tensor); - offsets_data = cpu_tmp_tensor.data(); - } - res = std::vector(offsets_data, offsets_data + rank); - } else { - res = ctx.Attr>("offsets"); - PADDLE_ENFORCE_EQ( - rank, - static_cast(res.size()), - platform::errors::InvalidArgument("The number of elements (%d) for " - "input 'Offsets' must be equal to " - "the number of dimensions (%d) " - "of the input tensor.", - static_cast(res.size()), - rank)); - } - return res; -} - -template -void CropTensorFunction(const framework::ExecutionContext& context) { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - auto x_dims = x->dims(); - auto out_dims = out->dims(); - - // get shape from Input(ShapeTensor) of Input(Shape) - std::vector shape = GetShape(context); - // out_dims set by arrt(shape) - if (shape.size() == 0) { - for (int i = 0; i < out_dims.size(); ++i) { - shape.push_back(out_dims[i]); - } - } - - auto offsets = GetOffsets(context); - out_dims = ValidateShape(shape, offsets, x->dims()); - out->mutable_data(out_dims, context.GetPlace()); - for (size_t i = 0; i < offsets.size(); ++i) { - PADDLE_ENFORCE_LE(offsets[i] + shape[i], - x_dims[i], - platform::errors::InvalidArgument( - "The sum of the %uth elements of " - "offsets (%d) and shape (%d) of Op(crop_tensor) " - "should be less than or " - "equal to the size of %uth dimension of the input.", - i, - offsets[i], - shape[i], - i)); - } - - auto x_tensor = EigenTensor::From(*x); - auto out_tensor = EigenTensor::From(*out); - Eigen::DSizes e_offsets; - Eigen::DSizes e_shape; - for (size_t i = 0; i < D; ++i) { - e_offsets[i] = offsets[i]; - e_shape[i] = out->dims()[i]; - } - auto& place = - *context.template device_context().eigen_device(); - EigenSlice, T, D>::Eval( - place, out_tensor, x_tensor, e_offsets, e_shape); -} - -template -class CropTensorKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - int rank = context.Input("X")->dims().size(); - PADDLE_ENFORCE_GE( - rank, - 1, - platform::errors::InvalidArgument( - "The number of dimensions of the input 'x' for " - "Op(crop_tensor) must be greater than or equal to 1, but the " - "value received is %d.", - rank)); - PADDLE_ENFORCE_LE( - rank, - 6, - platform::errors::InvalidArgument( - "The number of dimensions of the input 'x' for " - "Op(crop_tensor) must be less than or equal to 6, but the " - "value received is %d.", - rank)); - switch (rank) { - case 1: - CropTensorFunction(context); - break; - case 2: - CropTensorFunction(context); - break; - case 3: - CropTensorFunction(context); - break; - case 4: - CropTensorFunction(context); - break; - case 5: - CropTensorFunction(context); - break; - case 6: - CropTensorFunction(context); - break; - } - } -}; - -template -void CropTensorGradFunction(const framework::ExecutionContext& context) { - auto* d_x = context.Output(framework::GradVarName("X")); - auto* x = context.Input("X"); - if (d_x != nullptr) { - auto* d_out = context.Input(framework::GradVarName("Out")); - d_x->mutable_data(x->dims(), context.GetPlace()); - auto offsets = GetOffsets(context); - Eigen::array, D> paddings; - for (size_t i = 0; i < D; ++i) { - paddings[i].first = offsets[i]; - paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; - } - auto d_x_tensor = EigenTensor::From(*d_x); - auto d_out_tensor = EigenTensor::From(*d_out); - auto& place = - *context.template device_context().eigen_device(); - EigenPad, T, D>::Eval( - place, d_x_tensor, d_out_tensor, paddings, static_cast(0)); - } -} - -template -class CropTensorGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - size_t rank = - context.Input(framework::GradVarName("Out"))->dims().size(); - PADDLE_ENFORCE_GE( - rank, - 1, - platform::errors::InvalidArgument( - "The number of dimensions of the input 'Out@GRAD' for " - "Op(crop_tensor_grad) must be greater than or equal to 1, but the " - "value received is %d.", - rank)); - PADDLE_ENFORCE_LE( - rank, - 6, - platform::errors::InvalidArgument( - "The number of dimensions of the input 'Out@GRAD' for " - "Op(crop_tensor_grad) must be less than or equal to 6, but the " - "value received is %d.", - rank)); - switch (rank) { - case 1: - CropTensorGradFunction(context); - break; - case 2: - CropTensorGradFunction(context); - break; - case 3: - CropTensorGradFunction(context); - break; - case 4: - CropTensorGradFunction(context); - break; - case 5: - CropTensorGradFunction(context); - break; - case 6: - CropTensorGradFunction(context); - break; - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 77c58816de694e7990207bc325c380f1373d55d6..6a4afd3d0626bc4b6feec563375163a079360606 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -513,6 +513,16 @@ func : cosh backward : cosh_grad +- api : crop_tensor + args : (Tensor x, IntArray shape, IntArray offsets) + output : Tensor(out) + infer_meta : + func : CropTensorInferMeta + kernel : + func : crop_tensor + data_type : x + backward : crop_tensor_grad + # Part of python API paddle.nn.functional.cross_entropy - api : cross_entropy_with_softmax args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 310cf7c151ff22e7babcfa9f7f4cc993866d5cb5..9d73c044dbac95146a77eb56a6fe503e192537c6 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -481,6 +481,16 @@ func : cosh_grad inplace : (out_grad -> x_grad) +- backward_api : crop_tensor_grad + forward : crop_tensor (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray offsets) + output : Tensor(x_grad) + infer_meta : + func : CropTensorGradInferMeta + kernel : + func : crop_tensor_grad + data_type : x + - backward_api : cross_entropy_with_softmax_grad forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss) args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a33b9587c153ca2f7d67966a3a1a3838ebece5ee..bfae939820ead6c2c88d5b21ad3b25f34670aaa3 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -156,6 +156,18 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x, } } +void CropTensorGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& x, + const IntArray& offsets, + MetaTensor* x_grad) { + auto x_dims = x.dims(); + + if (x_grad != nullptr) { + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); + } +} + void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, const MetaTensor& softmax, const MetaTensor& loss_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 5551b6bcbf183b789c2bea7698e66cce1f933bfe..16d9b82e0644233d1c1deb027acb94766e7d08bd 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -89,6 +89,11 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x, MetaTensor* dfilter, MetaTensor* ddout); +void CropTensorGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& x, + const IntArray& offsets, + MetaTensor* x_grad); + void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, const MetaTensor& softmax, const MetaTensor& loss_grad, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c018e58a59a37b537fe55f11ce346e516b0390a9..3b31b165b42595346b0ed1d61546c94b935de4a6 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -300,6 +300,47 @@ void CumInferMeta(const MetaTensor& x, out->share_lod(x); } +void CropTensorInferMeta(const MetaTensor& x, + const IntArray& shape, + const IntArray& offsets, + MetaTensor* out, + MetaConfig config) { + PADDLE_ENFORCE_NE( + out, + nullptr, + errors::InvalidArgument("CropTensor should have output tensor out.")); + + auto x_dim = x.dims(); + auto shape_dims = shape.GetData(); + auto offsets_vec = offsets.GetData(); + + PADDLE_ENFORCE_EQ(shape_dims.size(), + x_dim.size(), + errors::InvalidArgument( + "The number of elements (%d) of attribute 'shape' for " + "CropTensor must be equal to the number of " + "dimensions (%d) of the input.", + shape_dims.size(), + x_dim.size())); + + if (config.is_runtime) { + out->share_lod(x); + } + + auto out_dims = std::vector(shape.size(), -1); + for (size_t i = 0; i < shape_dims.size(); ++i) { + if (shape_dims[i] > 0) { + out_dims[i] = static_cast(shape_dims[i]); + } else { + if (shape_dims[i] == -1 && offsets_vec[i] != -1 && x_dim[i] != -1) { + out_dims[i] = x_dim[i] - static_cast(offsets_vec[i]); + } + } + } + out->set_dims(phi::make_ddim(out_dims)); + out->set_dtype(x.dtype()); +} + void DiagEmbedInferMeta( const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) { auto x_dims = x.dims(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1449e8cfe197d190dce3ae2e47c6f4678e6095d8..c1db2561f0bef08ada289b561a82c07075658a03 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -66,6 +66,12 @@ void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out); void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); +void CropTensorInferMeta(const MetaTensor& x, + const IntArray& shape, + const IntArray& offsets, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void CumInferMeta(const MetaTensor& x, int axis, bool flatten, diff --git a/paddle/phi/kernels/cpu/crop_tensor_grad_kernel.cc b/paddle/phi/kernels/cpu/crop_tensor_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ac553ec9786b0975e2af0b602cb86f128077dc3 --- /dev/null +++ b/paddle/phi/kernels/cpu/crop_tensor_grad_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/crop_tensor_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(crop_tensor_grad, + CPU, + ALL_LAYOUT, + phi::CropTensorGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/crop_tensor_kernel.cc b/paddle/phi/kernels/cpu/crop_tensor_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cd42d5fa82395274db0c8ee8c036c1e1ff2a51b --- /dev/null +++ b/paddle/phi/kernels/cpu/crop_tensor_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/crop_tensor_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/crop_tensor_kernel_impl.h" + +PD_REGISTER_KERNEL(crop_tensor, + CPU, + ALL_LAYOUT, + phi::CropTensorKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/crop_tensor_grad_kernel.h b/paddle/phi/kernels/crop_tensor_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..97f1fbf5b029aa0ad226afe912dc97525693392f --- /dev/null +++ b/paddle/phi/kernels/crop_tensor_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CropTensorGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const IntArray& offsets, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/crop_tensor_kernel.h b/paddle/phi/kernels/crop_tensor_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..079959eb05c145a42936a0be1ccd623a8b06b64e --- /dev/null +++ b/paddle/phi/kernels/crop_tensor_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CropTensorKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + const IntArray& offsets, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/crop_tensor_grad_kernel.cu b/paddle/phi/kernels/gpu/crop_tensor_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0af80233cb1ef84061f2c3ef1e26123c52734df9 --- /dev/null +++ b/paddle/phi/kernels/gpu/crop_tensor_grad_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/crop_tensor_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(crop_tensor_grad, + GPU, + ALL_LAYOUT, + phi::CropTensorGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/crop_tensor_kernel.cu b/paddle/phi/kernels/gpu/crop_tensor_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..5aa4900c5097b54203767d013d929360b75a4c9e --- /dev/null +++ b/paddle/phi/kernels/gpu/crop_tensor_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/crop_tensor_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/crop_tensor_kernel_impl.h" + +PD_REGISTER_KERNEL(crop_tensor, + GPU, + ALL_LAYOUT, + phi::CropTensorKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h b/paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..0d3e579fe8bc8a9a38cd85afad2e55bf36160b27 --- /dev/null +++ b/paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h @@ -0,0 +1,105 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/crop_tensor_grad_kernel.h" + +#include + +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void CropTensorGradFunction(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const IntArray& offsets, + DenseTensor* x_grad) { + if (x_grad != nullptr) { + x_grad->Resize(x.dims()); + dev_ctx.template Alloc(x_grad); + + auto offsets_vec = offsets.GetData(); + std::array, D> paddings; + for (size_t i = 0; i < D; ++i) { + paddings[i].first = offsets_vec[i]; + paddings[i].second = + x_grad->dims()[i] - out_grad.dims()[i] - offsets_vec[i]; + } + auto x_grad_tensor = EigenTensor::From(*x_grad); + auto out_grad_tensor = EigenTensor::From(out_grad); + auto& place = *dev_ctx.eigen_device(); + + funcs::EigenPad, T, D>::Eval( + place, x_grad_tensor, out_grad_tensor, paddings, static_cast(0)); + } +} + +template +void CropTensorGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const IntArray& offsets, + DenseTensor* x_grad) { + size_t rank = out_grad.dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1, + errors::InvalidArgument( + "The number of dimensions of the input 'Out@GRAD' for " + "Op(crop_tensor_grad) must be greater than or equal to 1, but the " + "value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, + 6, + errors::InvalidArgument( + "The number of dimensions of the input 'Out@GRAD' for " + "Op(crop_tensor_grad) must be less than or equal to 6, but the " + "value received is %d.", + rank)); + switch (rank) { + case 1: + CropTensorGradFunction( + dev_ctx, out_grad, x, offsets, x_grad); + break; + case 2: + CropTensorGradFunction( + dev_ctx, out_grad, x, offsets, x_grad); + break; + case 3: + CropTensorGradFunction( + dev_ctx, out_grad, x, offsets, x_grad); + break; + case 4: + CropTensorGradFunction( + dev_ctx, out_grad, x, offsets, x_grad); + break; + case 5: + CropTensorGradFunction( + dev_ctx, out_grad, x, offsets, x_grad); + break; + case 6: + CropTensorGradFunction( + dev_ctx, out_grad, x, offsets, x_grad); + break; + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/crop_tensor_kernel_impl.h b/paddle/phi/kernels/impl/crop_tensor_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..e6d7f8f67265936d8443311e6fdd2735abc1a340 --- /dev/null +++ b/paddle/phi/kernels/impl/crop_tensor_kernel_impl.h @@ -0,0 +1,174 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/crop_tensor_kernel.h" + +#include +#include + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +static phi::DDim ValidateShape(const std::vector& shape, + const std::vector& offsets, + const phi::DDim& in_dims) { + auto in_dim_size = in_dims.size(); + auto shape_size = shape.size(); + PADDLE_ENFORCE_EQ( + in_dim_size, + shape_size, + errors::InvalidArgument( + "The number of elements (%d) for shape of Op(crop_tensor) should be " + "equal to the number of dimensions (%d) of the input tensor.", + shape_size, + in_dim_size)); + std::vector output_shape(shape.size(), 0); + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] <= 0 && in_dims[i] > 0) { + PADDLE_ENFORCE_NE(shape[i], + 0, + errors::InvalidArgument( + "The value (%d) of the %uth element for shape of " + "Op(crop_tensor) should not be zero.", + shape[i], + i)); + PADDLE_ENFORCE_EQ( + shape[i], + -1, + errors::InvalidArgument("When the value (%d) of the %uth " + "element for shape of Op(crop_tensor)" + " is negative, only -1 is supported.", + shape[i], + i)); + output_shape[i] = in_dims[i] - offsets[i]; + } else { + output_shape[i] = static_cast(shape[i]); + } + } + + return phi::make_ddim(output_shape); +} + +template +void CropTensorFunction(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + const IntArray& offsets, + DenseTensor* out) { + auto x_dims = x.dims(); + auto rank = x.dims().size(); + auto out_dims = out->dims(); + + auto shape_vec = shape.GetData(); + + if (shape_vec.size() == 0) { + for (int i = 0; i < out_dims.size(); ++i) { + shape_vec.push_back(out_dims[i]); + } + } + + auto offsets_vec = offsets.GetData(); + + PADDLE_ENFORCE_EQ( + rank, + static_cast(offsets_vec.size()), + errors::InvalidArgument("The number of elements (%d) for " + "input 'Offsets' must be equal to " + "the number of dimensions (%d) " + "of the input tensor.", + static_cast(offsets_vec.size()), + rank)); + + out_dims = ValidateShape(shape_vec, offsets_vec, x.dims()); + out->Resize(out_dims); + dev_ctx.template Alloc(out); + for (size_t i = 0; i < offsets_vec.size(); ++i) { + PADDLE_ENFORCE_LE(offsets_vec[i] + shape_vec[i], + x_dims[i], + errors::InvalidArgument( + "The sum of the %uth elements of " + "offsets (%d) and shape (%d) of Op(crop_tensor) " + "should be less than or " + "equal to the size of %uth dimension of the input.", + i, + offsets_vec[i], + shape_vec[i], + i)); + } + + auto x_tensor = EigenTensor::From(x); + auto out_tensor = EigenTensor::From(*out); + Eigen::DSizes e_offsets; + Eigen::DSizes e_shape; + for (size_t i = 0; i < D; ++i) { + e_offsets[i] = offsets_vec[i]; + e_shape[i] = out->dims()[i]; + } + auto& place = *dev_ctx.eigen_device(); + phi::funcs::EigenSlice, T, D>::Eval( + place, out_tensor, x_tensor, e_offsets, e_shape); +} + +template +void CropTensorKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + const IntArray& offsets, + DenseTensor* out) { + int rank = x.dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1, + errors::InvalidArgument( + "The number of dimensions of the input 'x' for " + "Op(crop_tensor) must be greater than or equal to 1, but the " + "value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, + 6, + errors::InvalidArgument( + "The number of dimensions of the input 'x' for " + "Op(crop_tensor) must be less than or equal to 6, but the " + "value received is %d.", + rank)); + switch (rank) { + case 1: + CropTensorFunction(dev_ctx, x, shape, offsets, out); + break; + case 2: + CropTensorFunction(dev_ctx, x, shape, offsets, out); + break; + case 3: + CropTensorFunction(dev_ctx, x, shape, offsets, out); + break; + case 4: + CropTensorFunction(dev_ctx, x, shape, offsets, out); + break; + case 5: + CropTensorFunction(dev_ctx, x, shape, offsets, out); + break; + case 6: + CropTensorFunction(dev_ctx, x, shape, offsets, out); + break; + } +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/crop_tensor_sig.cc b/paddle/phi/ops/compat/crop_tensor_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..994a7de8fb403648a7304bfa425d249fbec16190 --- /dev/null +++ b/paddle/phi/ops/compat/crop_tensor_sig.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CropTensorOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.InputSize("ShapeTensor") > 0) { + if (ctx.InputSize("OffsetsTensor") > 0) { + return KernelSignature( + "crop_tensor", {"X"}, {"ShapeTensor", "OffsetsTensor"}, {"Out"}); + } else if (ctx.HasInput("Offsets")) { + return KernelSignature( + "crop_tensor", {"X"}, {"ShapeTensor", "Offsets"}, {"Out"}); + } else { + return KernelSignature( + "crop_tensor", {"X"}, {"ShapeTensor", "offsets"}, {"Out"}); + } + } else if (ctx.HasInput("Shape")) { + if (ctx.InputSize("OffsetsTensor") > 0) { + return KernelSignature( + "crop_tensor", {"X"}, {"Shape", "OffsetsTensor"}, {"Out"}); + } else if (ctx.HasInput("Offsets")) { + return KernelSignature( + "crop_tensor", {"X"}, {"Shape", "Offsets"}, {"Out"}); + } else { + return KernelSignature( + "crop_tensor", {"X"}, {"Shape", "offsets"}, {"Out"}); + } + } else { + if (ctx.InputSize("OffsetsTensor") > 0) { + return KernelSignature( + "crop_tensor", {"X"}, {"shape", "OffsetsTensor"}, {"Out"}); + } else if (ctx.HasInput("Offsets")) { + return KernelSignature( + "crop_tensor", {"X"}, {"shape", "Offsets"}, {"Out"}); + } else { + return KernelSignature( + "crop_tensor", {"X"}, {"shape", "offsets"}, {"Out"}); + } + } +} + +KernelSignature CropTensorGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.InputSize("OffsetsTensor") > 0) { + return KernelSignature( + "crop_tensor_grad", {"X", "Out@GRAD"}, {"OffsetsTensor"}, {"X@GRAD"}); + } else if (ctx.HasInput("Offsets")) { + return KernelSignature( + "crop_tensor_grad", {"X", "Out@GRAD"}, {"Offsets"}, {"X@GRAD"}); + } else { + return KernelSignature( + "crop_tensor_grad", {"X", "Out@GRAD"}, {"offsets"}, {"X@GRAD"}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(crop_tensor, phi::CropTensorOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(crop_tensor_grad, + phi::CropTensorGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py b/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py index 49805c578bf47acce178649ab59c863d223adecd..aa70e32cdc5b37fd4d7dfd7ed434db9c074fda85 100644 --- a/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py @@ -51,6 +51,7 @@ class TestCropTensorOp(OpTest): self.offset_by_input = False self.unk_dim_idx = -1 self.attrs = {} + self.python_api = paddle.crop self.initTestCase() if self.shape_by_input: @@ -146,6 +147,7 @@ class TestCropTensorOpTensorAttr(OpTest): self.OffsetsTensor = False self.ShapeTensor = True self.attrs = {} + self.python_api = paddle.crop self.initTestCase() if self.ShapeTensor: diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f3b67cf743debc0b76fb1357be51b5f8bf028303..8d7d91e2f2ec967abd9aed23beccb2206b66d53b 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -640,6 +640,7 @@ def crop(x, shape=None, offsets=None, name=None): # if offsets = [1, 1], out = [[5,6], [8,9]] """ + helper = LayerHelper('crop_tensor', **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], 'crop_tensor') @@ -650,6 +651,9 @@ def crop(x, shape=None, offsets=None, name=None): if offsets is None: offsets = [0] * len(x.shape) + if in_dygraph_mode(): + return _C_ops.final_state_crop_tensor(x, shape, offsets) + out = helper.create_variable_for_type_inference(x.dtype) ipts = {'X': x} attrs = {}