From 1804f8347ece34dcd7d9c3f7d37b930019237ba2 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Wed, 28 Dec 2022 14:01:54 +0800 Subject: [PATCH] generate the static graph code of some ops (#49212) * generate the static op of some ops * add the VERSION of pixel_shuffle * change the API doc of isclose * change the API doc of isclose * fix the isclose op comment --- paddle/fluid/operators/frame_op.cc | 104 -------------- paddle/fluid/operators/gather_nd_op.cc | 150 -------------------- paddle/fluid/operators/isclose_op.cc | 97 ------------- paddle/fluid/operators/overlap_add_op.cc | 102 ------------- paddle/fluid/operators/pixel_shuffle_op.cc | 115 --------------- paddle/phi/api/yaml/backward.yaml | 39 +++++ paddle/phi/api/yaml/legacy_backward.yaml | 40 ------ paddle/phi/api/yaml/legacy_ops.yaml | 46 ------ paddle/phi/api/yaml/op_compat.yaml | 41 ++++++ paddle/phi/api/yaml/op_version.yaml | 9 ++ paddle/phi/api/yaml/ops.yaml | 48 +++++++ paddle/phi/ops/compat/frame_sig.cc | 28 ---- paddle/phi/ops/compat/gather_scatter_sig.cc | 26 ---- paddle/phi/ops/compat/isclose_sig.cc | 50 ------- paddle/phi/ops/compat/overlap_add_sig.cc | 30 ---- paddle/phi/ops/compat/pixel_shuffle_sig.cc | 37 ----- python/paddle/tensor/logic.py | 20 ++- 17 files changed, 151 insertions(+), 831 deletions(-) delete mode 100644 paddle/fluid/operators/frame_op.cc delete mode 100644 paddle/fluid/operators/gather_nd_op.cc delete mode 100644 paddle/fluid/operators/isclose_op.cc delete mode 100644 paddle/fluid/operators/overlap_add_op.cc delete mode 100644 paddle/fluid/operators/pixel_shuffle_op.cc delete mode 100644 paddle/phi/ops/compat/frame_sig.cc delete mode 100644 paddle/phi/ops/compat/gather_scatter_sig.cc delete mode 100644 paddle/phi/ops/compat/isclose_sig.cc delete mode 100644 paddle/phi/ops/compat/overlap_add_sig.cc delete mode 100644 paddle/phi/ops/compat/pixel_shuffle_sig.cc diff --git a/paddle/fluid/operators/frame_op.cc b/paddle/fluid/operators/frame_op.cc deleted file mode 100644 index 8acb372a559..00000000000 --- a/paddle/fluid/operators/frame_op.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/enforce.h" - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class FrameOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(in_dtype, ctx.GetPlace()); - } -}; - -class FrameOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of frame op."); - AddOutput("Out", "(Tensor), The output tensor of frame op."); - AddAttr( - "frame_length", - "Length of the frame and `0 < frame_length <= x.shape[axis]`."); - AddAttr("hop_length", - "Number of steps to advance between adjacent frames and " - "`0 < hop_length`."); - AddAttr("axis", - "Specify the axis to operate on the input Tensors. Its value " - "should be 0(the first dimension) or -1(the last dimension).") - .SetDefault(-1); - AddComment(R"DOC( - Slice the N-dimensional (where N >= 1) input into (overlapping) frames. - )DOC"); - } -}; - -class FrameOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(in_dtype, ctx.GetPlace()); - } -}; - -template -class FrameOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - void Apply(GradOpPtr retv) const override { - retv->SetType("frame_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(frame, - FrameInferShapeFunctor, - PD_INFER_META(phi::FrameInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(frame_grad, - FrameGradInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -REGISTER_OPERATOR(frame, - ops::FrameOp, - ops::FrameOpMaker, - ops::FrameOpGradMaker, - ops::FrameOpGradMaker, - FrameInferShapeFunctor); - -REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad, FrameGradInferShapeFunctor); diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc deleted file mode 100644 index 3198e35b8a4..00000000000 --- a/paddle/fluid/operators/gather_nd_op.cc +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class GatherNdOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType( - x_type, - x_type == framework::proto::VarType::BOOL - ? x->place() // to be consistent with compare and logical ops - : ctx.device_context().GetPlace()); - } -}; - -class GatherNdGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); - } -}; - -class GatherNdOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The source input of gather_nd op"); - AddInput("Index", "The index input of gather_nd op"); - AddOutput("Out", "The output of gather_nd op"); - AddComment(R"DOC( - Gather_Nd Operator. - - This function is actually a high-dimensional extension of gather - and supports for simultaneous indexing by multiple axes. Out is - obtained by gathering slices from X into a tensor with shape - Index.shape[:-1] + X.shape[Index.shape[-1]:]. 
- - Example: - - Given: - X = [[[ 0, 1, 2, 3], - [ 4, 5, 6, 7], - [ 8, 9, 10, 11]], - [[12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23]]] - - X.shape = (2, 3, 4) - - *Case 1: - - Index = [[1]] - - we get: - Out = - [[12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23]] - - *Case 2: - - Index = [[0,2]] - - we get: - - Out = [8, 9, 10, 11] - - *Case 3: - - Index = [[1, 2, 3]] - - we get: - - Out = [23] - -)DOC"); - } -}; - -template -class GatherNdGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("gather_nd_grad"); - op->SetInput("Index", this->Input("Index")); - op->SetInput("X", this->Input("X")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInferer, "X"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(gather_nd, - GatherNdInferShapeFunctor, - PD_INFER_META(phi::GatherNdInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(gather_nd_grad, - GatherNdGradInferShapeFunctor, - PD_INFER_META(phi::GatherNdGradInferMeta)); - -REGISTER_OPERATOR(gather_nd, - ops::GatherNdOp, - ops::GatherNdOpMaker, - ops::GatherNdGradOpMaker, - ops::GatherNdGradOpMaker, - GatherNdInferShapeFunctor); - -REGISTER_OPERATOR(gather_nd_grad, - ops::GatherNdGradOp, - ops::GatherNdGradNoNeedBufferVarInferer, - GatherNdGradInferShapeFunctor); diff --git a/paddle/fluid/operators/isclose_op.cc b/paddle/fluid/operators/isclose_op.cc deleted file mode 100644 index 8d0cd10097f..00000000000 --- a/paddle/fluid/operators/isclose_op.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class IscloseOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Input", - "The input tensor, it's data type should be float32, float64."); - AddInput("Other", - "The input tensor, it's data type should be float32, float64."); - AddInput("Rtol", "The relative tolerance.").AsDispensable(); - AddInput("Atol", "The absolute tolerance.").AsDispensable(); - AddOutput("Out", "The output tensor, it's data type is bool."); - AddAttr("rtol", - "The relative tolerance. Default: :math:`1e-5` .") - .SetDefault("1e-5"); - AddAttr("atol", - "The absolute tolerance. 
Default: :math:`1e-8` .") - .SetDefault("1e-8"); - AddAttr("equal_nan", - "If :math:`True` , then two :math:`NaNs` will be " - "compared as equal. Default: :math:`False` .") - .SetDefault(false); - - AddComment(R"DOC( -This operator checks if all :math:`x` and :math:`y` satisfy the condition: - -.. math:: - \left| x - y \right| \leq atol + rtol \times \left| y \right| - -elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this -operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if -two tensors are elementwise equal within a tolerance. -)DOC"); - } -}; - -class IscloseOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); - } -}; - -class IscloseOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext* ctx) const override { - ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(isclose, - IscloseInferShapeFunctor, - PD_INFER_META(phi::ValueCompareInferMeta)); -REGISTER_OPERATOR( - isclose, - ops::IscloseOp, - ops::IscloseOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - ops::IscloseOpVarTypeInference, - IscloseInferShapeFunctor); diff --git a/paddle/fluid/operators/overlap_add_op.cc b/paddle/fluid/operators/overlap_add_op.cc deleted file mode 100644 index 4ead2161357..00000000000 --- a/paddle/fluid/operators/overlap_add_op.cc +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class OverlapAddOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(in_dtype, ctx.GetPlace()); - } -}; - -class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of overlap_add op."); - AddOutput("Out", "(Tensor), The output tensor of overlap_add op."); - AddAttr("hop_length", - "Number of steps to advance between adjacent frames and " - "`0 < hop_length <= frame_length`."); - AddAttr("axis", - "Specify the axis to operate on the input Tensors. Its value " - "should be 0(the first dimension) or -1(the last dimension).") - .SetDefault(-1); - AddComment(R"DOC( - Reconstructs a tensor consisted of overlap added sequences from input frames. - )DOC"); - } -}; - -class OverlapAddOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(in_dtype, ctx.GetPlace()); - } -}; - -template -class OverlapAddOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr retv) const override { - retv->SetType("overlap_add_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(overlap_add, - OverlapAddInferShapeFunctor, - PD_INFER_META(phi::OverlapAddInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(overlap_add_grad, - OverlapAddGradInferShapeFunctor, - PD_INFER_META(phi::OverlapAddGradInferMeta)); - -REGISTER_OPERATOR(overlap_add, - ops::OverlapAddOp, - ops::OverlapAddOpMaker, - ops::OverlapAddOpGradMaker, - ops::OverlapAddOpGradMaker, - OverlapAddInferShapeFunctor); - -REGISTER_OPERATOR(overlap_add_grad, - ops::OverlapAddOpGrad, - OverlapAddGradInferShapeFunctor); diff --git a/paddle/fluid/operators/pixel_shuffle_op.cc b/paddle/fluid/operators/pixel_shuffle_op.cc deleted file mode 100644 index 098395d8850..00000000000 --- a/paddle/fluid/operators/pixel_shuffle_op.cc +++ /dev/null @@ -1,115 +0,0 @@ -/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class PixelShuffleOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Tensor, default Tensor), " - "the input feature data of PixelShuffleOp, the layout is [N, C, " - "H, W] or [N, H, W, C]."); - AddOutput("Out", - "(Tensor, default Tensor), the output of " - "PixelShuffleOp. The layout is [N, C/factor^2, H*factor, " - "W*factor] or [N, H*factor, W*factor, C/factor^2]."); - AddAttr("upscale_factor", - "the factor to increase spatial resolution by.") - .SetDefault(1) - .AddCustomChecker([](const int& upscale_factor) { - PADDLE_ENFORCE_GE(upscale_factor, - 1, - platform::errors::InvalidArgument( - "upscale_factor should be larger than 0.")); - }); - AddAttr( - "data_format", - "An optional string from: \"NHWC\", \"NCHW\". " - "Defaults to \"NHWC\", Specify the data format of the input data.") - .SetDefault("NCHW"); - - AddComment(R"DOC( - Pixel Shuffle operator - This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` - to a tensor of shape :math:`(C, H \times r, W \times r)`. - - This is useful for implementing efficient sub-pixel convolution - with a stride of :math:`1/r`. - - Please refer to the paper: - `Real-Time Single Image and Video Super-Resolution Using an Efficient - Sub-Pixel Convolutional Neural Network `_ - by Shi et. al (2016) for more details. 
- )DOC"); - } -}; - -template -class PixelShuffleGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr op) const override { - op->SetType("pixel_shuffle_grad"); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetAttrMap(this->Attrs()); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; - -class PixelShuffleGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(pixel_shuffle, - PixelShuffleInferShapeFunctor, - PD_INFER_META(phi::PixelShuffleInferMeta)); - -REGISTER_OPERATOR(pixel_shuffle, - ops::PixelShuffleOp, - ops::PixelShuffleOpMaker, - ops::PixelShuffleGradMaker, - ops::PixelShuffleGradMaker, - PixelShuffleInferShapeFunctor); - -DECLARE_INFER_SHAPE_FUNCTOR(pixel_shuffle_grad, - PixelShuffleGradInferShapeFunctor, - PD_INFER_META(phi::PixelShuffleGradInferMeta)); -REGISTER_OPERATOR(pixel_shuffle_grad, - ops::PixelShuffleGradOp, - PixelShuffleGradInferShapeFunctor); - -REGISTER_OP_VERSION(pixel_shuffle) - .AddCheckpoint( - R"ROC( - Compatible upgrade of pixel_shuffle, add a new attribute [data_format])ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "data_format", "Specify the data format of the input data", true)); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index e9ef3bebfc6..dead42d03f7 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -491,6 +491,26 @@ data_type : out_grad no_need_buffer : x +- backward_op : frame_grad + forward : frame(Tensor x, int frame_length, int hop_length, int axis=-1) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int frame_length, int hop_length, int axis) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : frame_grad + +- backward_op : gather_nd_grad + forward : gather_nd (Tensor x, Tensor index) -> Tensor(out) + args : (Tensor x, Tensor index, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : GatherNdGradInferMeta + kernel : + func : gather_nd_grad + no_need_buffer : x + - backward_op : gelu_grad forward : gelu(Tensor x, bool approximate) -> Tensor(out) args : (Tensor x, Tensor out_grad, bool approximate) @@ -799,6 +819,25 @@ data_type : input optional : weight +- backward_op : overlap_add_grad + forward : overlap_add(Tensor x, int hop_length, int axis) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int hop_length, int axis) + output : Tensor(x_grad) + infer_meta : + func : OverlapAddGradInferMeta + kernel : + func : overlap_add_grad + data_type : x + +- backward_op : pixel_shuffle_grad + forward : pixel_shuffle (Tensor x, int upscale_factor=1, str data_format="NCHW") -> Tensor(out) + args : (Tensor out_grad, int upscale_factor, str data_format) + output : Tensor(x_grad) + infer_meta : + func : PixelShuffleGradInferMeta + kernel : + func : pixel_shuffle_grad + - backward_op : poisson_grad forward : poisson (Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 3e7f0210adf..8d7af90a90a 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -539,16 +539,6 @@ kernel : func : fmin_grad -- backward_op : frame_grad - forward : 
frame(Tensor x, int frame_length, int hop_length, int axis) -> Tensor(out) - args : (Tensor x, Tensor out_grad, int frame_length, int hop_length, int axis) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : frame_grad - - backward_op : frobenius_norm_grad forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all) @@ -571,17 +561,6 @@ func : gather_grad no_need_buffer : x -- backward_op : gather_nd_grad - forward : gather_nd (Tensor x, Tensor index) -> Tensor(out) - args : (Tensor x, Tensor index, Tensor out_grad) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : gather_nd_grad - no_need_buffer : x - - backward_op : group_norm_grad forward : group_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int groups, str data_layout) -> Tensor(y), Tensor(mean), Tensor(variance) args : (Tensor x, Tensor scale, Tensor bias, Tensor y, Tensor mean, Tensor variance, Tensor y_grad, float epsilon, int groups, str data_layout) @@ -992,16 +971,6 @@ kernel : func : norm_grad -- backward_op : overlap_add_grad - forward : overlap_add(Tensor x, int hop_length, int axis) -> Tensor(out) - args : (Tensor x, Tensor out_grad, int hop_length, int axis) - output : Tensor(x_grad) - infer_meta : - func : OverlapAddGradInferMeta - kernel : - func : overlap_add_grad - data_type : x - - backward_op : p_norm_grad forward : p_norm(Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, float porder, int axis, float epsilon, bool keepdim, bool asvector) @@ -1055,15 +1024,6 @@ no_need_buffer : x backward : pad_double_grad -- backward_op : pixel_shuffle_grad - forward : pixel_shuffle (Tensor x, int upscale_factor, str data_format) -> Tensor(out) - args : (Tensor out_grad, int upscale_factor, str data_format) - output : Tensor(x_grad) - infer_meta : - func : PixelShuffleGradInferMeta - kernel : - func : pixel_shuffle_grad - - backward_op : pool2d_double_grad forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x) args : (Tensor x, Tensor grad_x_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index c0f92cd0175..b93ca2944ab 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -743,15 +743,6 @@ func : fmin backward : fmin_grad -- op : frame - args : (Tensor x, int frame_length, int hop_length, int axis) - output : Tensor(out) - infer_meta : - func : FrameInferMeta - kernel : - func : frame - backward : frame_grad - - op : frobenius_norm args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) output : Tensor(out) @@ -822,16 +813,6 @@ data_type: x backward : gather_grad -- op : gather_nd - args : (Tensor x, Tensor index) - output : Tensor - infer_meta : - func : GatherNdInferMeta - kernel : - func : gather_nd - data_type : x - backward : gather_nd_grad - - op : gaussian args : (IntArray shape, float mean, float std, int seed, DataType 
dtype, Place place={}) output: Tensor(out) @@ -960,15 +941,6 @@ intermediate : saved_mean, saved_variance backward : instance_norm_grad -- op : isclose - args : (Tensor x, Tensor y, Scalar rtol, Scalar atol, bool equal_nan) - output : Tensor(out) - infer_meta : - func : ValueCompareInferMeta - param: [x, y] - kernel : - func : isclose - - op : kldiv_loss args : (Tensor x, Tensor label, str reduction) output : Tensor(out) @@ -1414,15 +1386,6 @@ output : Tensor(out) invoke : full_like(x, 1, dtype, place) -- op : overlap_add - args: (Tensor x, int hop_length, int axis) - output: Tensor - infer_meta: - func: OverlapAddInferMeta - kernel: - func: overlap_add - backward: overlap_add_grad - - op : p_norm args : (Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) output : Tensor(out) @@ -1450,15 +1413,6 @@ func : pad3d backward : pad3d_grad -- op : pixel_shuffle - args : (Tensor x, int upscale_factor, str data_format) - output : Tensor - infer_meta : - func : PixelShuffleInferMeta - kernel : - func : pixel_shuffle - backward : pixel_shuffle_grad - - op : pool2d args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 574d68a8347..7e960d73bbb 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -547,6 +547,13 @@ outputs : out : Y +- op : frame + backward : frame_grad + inputs : + x : X + outputs : + out : Out + - op : frobenius_norm backward : frobenius_norm_grad extra : @@ -561,6 +568,13 @@ extra : attrs : [bool overwrite = true] +- op : gather_nd + backward : gather_nd_grad + inputs : + {x : X, index : Index} + outputs : + out : Out + - op : gather_tree inputs : {ids : Ids, parents : Parents} @@ -663,6 +677,19 @@ outputs : out : Out +- op : isclose + inputs : + {x : Input, y : Other} + outputs : + out : Out + scalar : + rtol : + data_type : std::string + tensor_name : Rtol + atol : + data_type : std::string + tensor_name : Atol + - op : isfinite (isfinite_v2) inputs : x : X @@ -899,6 +926,13 @@ outputs : {out : Out, total_weight : Total_weight} +- op : overlap_add + backward : overlap_add_grad + inputs : + x : X + outputs : + out : Out + - op : pad2d backward : pad2d_grad extra : @@ -914,6 +948,13 @@ extra : attrs : [bool use_mkldnn = false] +- op : pixel_shuffle + backward : pixel_shuffle_grad + inputs : + x : X + outputs : + out : Out + - op : poisson inputs : x : X diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 6805da4daad..77e722a2978 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -16,6 +16,15 @@ comment : In order to specify interpolation mode default : std::string("bilinear") +- op : pixel_shuffle + version : + - checkpoint : Compatible upgrade of pixel_shuffle, add a new attribute [data_format] + action : + - add_attr : + name : data_format + comment : Specify the data format of the input data + default : "true" + - op : roll version : - checkpoint : Upgrade roll add 1 attribute [axis], delete 1 attribute[dims]. 
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 16ae0e9f710..0e85b2d8dff 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -431,6 +431,25 @@ func: fold backward: fold_grad +- op : frame + args : (Tensor x, int frame_length, int hop_length, int axis=-1) + output : Tensor(out) + infer_meta : + func : FrameInferMeta + kernel : + func : frame + backward : frame_grad + +- op : gather_nd + args : (Tensor x, Tensor index) + output : Tensor + infer_meta : + func : GatherNdInferMeta + kernel : + func : gather_nd + data_type : x + backward : gather_nd_grad + - op : gather_tree args : (Tensor ids, Tensor parents) output : Tensor(out) @@ -535,6 +554,16 @@ kernel : func : is_empty +- op : isclose + args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false) + output : Tensor(out) + infer_meta : + func : ValueCompareInferMeta + param: [x, y] + kernel : + func : isclose + data_type : x + - op : isfinite args : (Tensor x) output : Tensor(out) @@ -761,6 +790,25 @@ kernel : func : npu_identity +- op : overlap_add + args: (Tensor x, int hop_length, int axis=-1) + output: Tensor + infer_meta: + func: OverlapAddInferMeta + kernel: + func: overlap_add + data_type : x + backward: overlap_add_grad + +- op : pixel_shuffle + args : (Tensor x, int upscale_factor=1, str data_format="NCHW") + output : Tensor + infer_meta : + func : PixelShuffleInferMeta + kernel : + func : pixel_shuffle + backward : pixel_shuffle_grad + - op : poisson args : (Tensor x) output : Tensor diff --git a/paddle/phi/ops/compat/frame_sig.cc b/paddle/phi/ops/compat/frame_sig.cc deleted file mode 100644 index cbe24095b0f..00000000000 --- a/paddle/phi/ops/compat/frame_sig.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature FrameGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("frame_grad", - {"X", "Out@GRAD"}, - {"frame_length", "hop_length", "axis"}, - {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(frame_grad, phi::FrameGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/gather_scatter_sig.cc b/paddle/phi/ops/compat/gather_scatter_sig.cc deleted file mode 100644 index e37ba0ff401..00000000000 --- a/paddle/phi/ops/compat/gather_scatter_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(gather_nd_grad, phi::GatherNdGradArgumentMapping); diff --git a/paddle/phi/ops/compat/isclose_sig.cc b/paddle/phi/ops/compat/isclose_sig.cc deleted file mode 100644 index 08632e99095..00000000000 --- a/paddle/phi/ops/compat/isclose_sig.cc +++ /dev/null @@ -1,50 +0,0 @@ - -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature IscloseOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasInput("Rtol")) { - if (ctx.HasInput("Atol")) { - return KernelSignature("isclose", - {"Input", "Other"}, - {"Rtol", "Atol", "equal_nan"}, - {"Out"}); - - } else { - return KernelSignature("isclose", - {"Input", "Other"}, - {"Rtol", "atol", "equal_nan"}, - {"Out"}); - } - } else { - if (ctx.HasInput("Atol")) { - return KernelSignature("isclose", - {"Input", "Other"}, - {"rtol", "Atol", "equal_nan"}, - {"Out"}); - } else { - return KernelSignature("isclose", - {"Input", "Other"}, - {"rtol", "atol", "equal_nan"}, - {"Out"}); - } - } -} - -} // namespace phi -PD_REGISTER_ARG_MAPPING_FN(isclose, phi::IscloseOpArgumentMapping); diff --git a/paddle/phi/ops/compat/overlap_add_sig.cc b/paddle/phi/ops/compat/overlap_add_sig.cc deleted file mode 100644 index c694b97f8bb..00000000000 --- a/paddle/phi/ops/compat/overlap_add_sig.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature OverlapAddGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("overlap_add_grad", - {"X", "Out@GRAD"}, - {"hop_length", "axis"}, - {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(overlap_add_grad, - phi::OverlapAddGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/pixel_shuffle_sig.cc b/paddle/phi/ops/compat/pixel_shuffle_sig.cc deleted file mode 100644 index 96cb01a38fc..00000000000 --- a/paddle/phi/ops/compat/pixel_shuffle_sig.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature PixelShuffleOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "pixel_shuffle", {"X"}, {"upscale_factor", "data_format"}, {"Out"}); -} - -KernelSignature PixelShuffleGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("pixel_shuffle_grad", - {"Out@GRAD"}, - {"upscale_factor", "data_format"}, - {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(pixel_shuffle, phi::PixelShuffleOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(pixel_shuffle_grad, - phi::PixelShuffleGradOpArgumentMapping); diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 0d66daf000c..c53d530e609 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -936,20 +936,28 @@ def bitwise_not(x, out=None, name=None): @templatedoc() def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - """ - ${comment} + r""" + Checks if all :math:`x` and :math:`y` satisfy the condition: + + .. math:: + + \left| x - y \right| \leq atol + rtol \times \left| y \right| + + elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this + operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if + two tensors are elementwise equal within a tolerance. Args: - x(Tensor): ${input_comment}. - y(Tensor): ${other_comment}. + x(Tensor): The input tensor, it's data type should be float32, float64. + y(Tensor): The input tensor, it's data type should be float32, float64. rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` . atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` . - equal_nan(equalnantype, optional): ${equal_nan_comment}. + equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` . name (str, optional): Name for the operation. For more information, please refer to :ref:`api_guide_Name`. Default: None. Returns: - Tensor: ${out_comment}. + Tensor: The output tensor, it's data type is bool. Examples: .. code-block:: python -- GitLab