From 17fb92b355a7f8d0f505c3221087f69d16571f94 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 28 Oct 2022 21:39:47 +0800 Subject: [PATCH] generate static graph code for some ops by yaml (#47416) --- paddle/fluid/operators/angle_op.cc | 94 -------------- paddle/fluid/operators/argsort_op.cc | 115 ------------------ paddle/fluid/operators/bmm_op.cc | 103 ---------------- paddle/fluid/operators/bmm_op.h | 63 ---------- paddle/fluid/operators/determinant_op.cc | 67 ---------- paddle/phi/api/yaml/backward.yaml | 43 +++++++ .../generator/templates/operator_utils.c.j2 | 2 +- paddle/phi/api/yaml/legacy_backward.yaml | 43 ------- paddle/phi/api/yaml/legacy_ops.yaml | 36 ------ paddle/phi/api/yaml/op_compat.yaml | 24 ++++ paddle/phi/api/yaml/ops.yaml | 36 ++++++ paddle/phi/kernels/cpu/angle_grad_kernel.cc | 4 +- paddle/phi/kernels/gpu/angle_grad_kernel.cu | 4 +- paddle/phi/ops/compat/angle_sig.cc | 30 ----- paddle/phi/ops/compat/argsort_sig.cc | 29 ----- paddle/phi/ops/compat/bmm_sig.cc | 26 ---- paddle/phi/ops/compat/determinant_sig.cc | 28 ----- 17 files changed, 110 insertions(+), 637 deletions(-) delete mode 100644 paddle/fluid/operators/angle_op.cc delete mode 100644 paddle/fluid/operators/argsort_op.cc delete mode 100644 paddle/fluid/operators/bmm_op.cc delete mode 100644 paddle/fluid/operators/bmm_op.h delete mode 100644 paddle/phi/ops/compat/angle_sig.cc delete mode 100644 paddle/phi/ops/compat/argsort_sig.cc delete mode 100644 paddle/phi/ops/compat/bmm_sig.cc delete mode 100644 paddle/phi/ops/compat/determinant_sig.cc diff --git a/paddle/fluid/operators/angle_op.cc b/paddle/fluid/operators/angle_op.cc deleted file mode 100644 index ccd5584e8d..0000000000 --- a/paddle/fluid/operators/angle_op.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class AngleOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class AngleOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of angle op."); - AddOutput("Out", "(Tensor), The output tensor of angle op."); - AddComment(R"DOC( -Angle Operator. - -This operator is used to perform elementwise angle for input $X$. -$$out = angle(x)$$ - -)DOC"); - } -}; - -class AngleGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(dtype, ctx.GetPlace()); - } -}; - -template -class AngleGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr retv) const override { - retv->SetType("angle_grad"); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetInput("X", this->Input("X")); - retv->SetAttrMap(this->Attrs()); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(angle, - AngleInferShapeFunctor, - PD_INFER_META(phi::RealAndImagInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(angle_grad, - AngleGradInferShapeFunctor, - PD_INFER_META(phi::AngleGradInferMeta)); - -REGISTER_OPERATOR(angle, - ops::AngleOp, - ops::AngleOpMaker, - ops::AngleGradMaker, - ops::AngleGradMaker, - AngleInferShapeFunctor); - -REGISTER_OPERATOR(angle_grad, ops::AngleGradOp, AngleGradInferShapeFunctor); diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc deleted file mode 100644 index f17723bf83..0000000000 --- a/paddle/fluid/operators/argsort_op.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class ArgsortOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class ArgsortGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X")); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); - } -}; - -class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor) The input of Argsort op."); - AddOutput("Out", - "(Tensor) The sorted tensor of Argsort op, with the same " - "shape as Input(X)."); - AddOutput("Indices", - "(Tensor) The indices of a tensor giving the sorted order, with " - "the same shape as Input(X)."); - AddComment(R"DOC( -Argsort operator - -Performs sorting on the input tensor along the given axis and outputs two -tensors, Output(Out) and Output(Indices). They reserve the same shape -with Input(X), and Output(Out) represents the sorted tensor while -Output(Indices) gives the sorted order along the given axis Attr(axis). - - )DOC"); - AddAttr("axis", - "(int, default -1) The axis along which to sort the tensor. " - "When axis < 0, the actual axis will be the |axis|'th " - "counting backwards. Default -1, the last dimension.") - .SetDefault(-1); - AddAttr( - "descending", - "(bool, default false) The descending attribute is a flag to tell" - "algorithm how to sort the input data." - "If descending is true, will sort by descending order," - "else if false, sort by ascending order. Default value is false.") - .SetDefault(false); - } -}; - -template -class ArgsortGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("argsort_grad"); - op->SetInput("Indices", this->Output("Indices")); - op->SetInput("X", this->Input("X")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(ArgsortGradNoNeedBufferVarsInferer, "X"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(argsort, - ArgsortInferShapeFunctor, - PD_INFER_META(phi::ArgsortInferMeta)); -REGISTER_OPERATOR(argsort, - ops::ArgsortOp, - ops::ArgsortOpMaker, - ops::ArgsortGradOpMaker, - ops::ArgsortGradOpMaker, - ArgsortInferShapeFunctor); -REGISTER_OPERATOR(argsort_grad, - ops::ArgsortGradOp, - ops::ArgsortGradNoNeedBufferVarsInferer); diff --git a/paddle/fluid/operators/bmm_op.cc b/paddle/fluid/operators/bmm_op.cc deleted file mode 100644 index b27594eed3..0000000000 --- a/paddle/fluid/operators/bmm_op.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ - -#include "paddle/fluid/operators/bmm_op.h" - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class BmmOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); - } -}; - -class BmmOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The first input tensor of Bmm op."); - AddInput("Y", "(Tensor), The second input tensor of Bmm op."); - AddOutput("Out", "(Tensor), The output tensor of Bmm op."); - AddComment(R"DOC( -The Bmm operator is used to perform batched matrix multiplication -over the last two dimensions of the input tensors `X` and `Y` -which are both 3-dimentionsal. - -Examples: -- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] - - )DOC"); - } -}; - -class BmmOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); - } -}; - -template -class BmmOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr retv) const override { - retv->SetType("bmm_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput("Y", this->Input("Y")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(bmm, - BmmInferShapeFunctor, - PD_INFER_META(phi::BmmInferMeta)); -DECLARE_INFER_SHAPE_FUNCTOR(bmm_grad, - BmmGradInferShapeFunctor, - PD_INFER_META(phi::BmmGradInferMeta)); -REGISTER_OPERATOR(bmm, - ops::BmmOp, - ops::BmmOpMaker, - ops::BmmOpGradMaker, - ops::BmmOpGradMaker, - BmmInferShapeFunctor); -REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad, BmmGradInferShapeFunctor); diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h deleted file mode 100644 index 5ca8df0182..0000000000 --- a/paddle/fluid/operators/bmm_op.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ - -#ifndef PADDLE_FLUID_OPERATORS_BMM_OP_H_ -#define PADDLE_FLUID_OPERATORS_BMM_OP_H_ - -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { - -using Tensor = phi::DenseTensor; - -static void ReshapeTensorIntoMatrixSequence( - phi::DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) { - int64_t h, w; - h = descriptor.height_; - w = descriptor.width_; - if (descriptor.trans_) { - std::swap(w, h); - } - - x->Resize({descriptor.batch_size_, h, w}); -} - -static void ReshapeXYOutIntoMatrixSequence(phi::DenseTensor *x, - phi::DenseTensor *y, - phi::DenseTensor *out, - bool trans_x, - bool trans_y) { - auto x_dim = x->dims(); - auto y_dim = y->dims(); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, false); - auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, false); - - out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, - mat_dim_y.width_}); - - ReshapeTensorIntoMatrixSequence(x, mat_dim_x); - ReshapeTensorIntoMatrixSequence(y, mat_dim_y); -} - -} // namespace operators -} // namespace paddle -#endif // PADDLE_FLUID_OPERATORS_BMM_OP_H_ diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc index 89d5d2ded1..56e39747af 100644 --- a/paddle/fluid/operators/determinant_op.cc +++ b/paddle/fluid/operators/determinant_op.cc @@ -23,57 +23,6 @@ namespace paddle { namespace operators { -class DeterminantOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Input", "(Tensor) The input tensor of determinant."); - AddOutput("Out", - "(Tensor) The output Tensor containing the determinant" - "value of a square matrix or batches of square matrices "); - - AddComment(R"DOC( -Determinant Operator.)DOC"); - } -}; - -class DeterminantGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); - } -}; - -template -class DeterminantGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("determinant_grad"); - grad_op->SetInput("Input", this->Input("Input")); - grad_op->SetInput("Out", this->Output("Out")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("Input"), - this->InputGrad("Input")); - grad_op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(DeterminantGradNoNeedBufferVarsInferer, - "Input"); - class SlogDeterminantOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -154,22 +103,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer, namespace ops = paddle::operators; namespace plat = paddle::platform; -DECLARE_INFER_SHAPE_FUNCTOR(determinant, - DeterminantInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); -REGISTER_OPERATOR(determinant, - ops::DeterminantOp, - ops::DeterminantOpMaker, - ops::DeterminantGradOpMaker, - ops::DeterminantGradOpMaker, - DeterminantInferShapeFunctor); - -DECLARE_INFER_SHAPE_FUNCTOR(determinant_grad, - DeterminantGradInferShapeFunctor, - PD_INFER_META(phi::GeneralUnaryGradInferMeta)); -REGISTER_OPERATOR(determinant_grad, - ops::DeterminantGradOp, - DeterminantGradInferShapeFunctor); DECLARE_INFER_SHAPE_FUNCTOR(slogdeterminant, SlogDeterminantInferShapeFunctor, diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index db97795b5b..cb51e8fa13 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -20,6 +20,28 @@ func : acosh_grad inplace : (out_grad -> x_grad) +- backward_op : angle_grad + forward : angle (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : angle_grad + +- backward_op : argsort_grad + forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices) + args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : argsort_grad + data_type : out_grad + no_need_buffer : x + - backward_op : asin_grad forward : asin (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -74,6 +96,16 @@ func : atanh_grad inplace : (out_grad -> x_grad) +- backward_op : bmm_grad + forward : bmm (Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : BmmGradInferMeta + kernel : + func : bmm_grad + data_type : out_grad + - backward_op : cholesky_grad forward : cholesky (Tensor x, bool upper) -> Tensor(out) args : (Tensor out, Tensor out_grad, bool upper) @@ -127,6 +159,17 @@ func : cross_grad data_type : out_grad +- backward_op : det_grad + forward : det (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : determinant_grad + data_type : out_grad + - backward_op : diag_grad forward : diag (Tensor x, int offset, float padding_value) -> Tensor(out) args : (Tensor x, Tensor out_grad, int offset) diff --git a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 index d2b0cf3290..60fd251f44 100644 --- a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 @@ -109,7 +109,7 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu {% endfor %} {{get_output_list(api["outputs"], kernel_args)}}; {% if api["kernel"]["func"] | length == 1 %} - KernelSignature sig("{{api["name"]}}", std::move(inputs), std::move(attrs), std::move(outputs)); + KernelSignature sig("{{api["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs)); return sig; {% else %}{# it has kernel for selected rows #} const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{api["kernel"]["func"][1]}}" : "{{api["kernel"]["func"][0]}}"; diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index ced3d75bb9..916f5c405d 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -100,30 +100,6 @@ kernel : func : amin_grad -- backward_op : angle_grad - forward : angle (Tensor x) -> Tensor(out) - args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : angle_grad - data_transform: - skip_transform : out_grad - -- backward_op : argsort_grad - forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices) - args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : argsort_grad - data_type : out_grad - no_need_buffer : x - - backward_op : as_complex_grad forward : as_complex (Tensor x) -> Tensor(out) args : (Tensor out_grad) @@ -222,15 +198,6 @@ kernel : func : bilinear_tensor_product_grad -- backward_op : bmm_grad - forward : bmm (Tensor x, Tensor y) -> Tensor(out) - args : (Tensor x, Tensor y, Tensor out_grad) - output : Tensor(x_grad), Tensor(y_grad) - infer_meta : - func : BmmGradInferMeta - kernel : - func : bmm_grad - - backward_op : brelu_grad forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out) args : (Tensor x, Tensor out_grad, float t_min, float t_max) @@ -515,16 +482,6 @@ kernel : func : depthwise_conv2d_transpose_grad -- backward_op : det_grad - forward : det (Tensor x) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : determinant_grad - - backward_op : divide_double_grad forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y) args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index b0d79886c1..de290bd169 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -144,15 +144,6 @@ func : amin backward : amin_grad -- op : angle - args : (Tensor x) - output : Tensor - infer_meta : - func : RealAndImagInferMeta - kernel : - func : angle - backward : angle_grad - - op : any args : (Tensor x, int64_t[] axis={}, bool keepdim=false) output : Tensor(out) @@ -191,15 +182,6 @@ kernel : func : arg_min -- op : argsort - args : (Tensor x, int axis=-1, bool descending=false) - output : Tensor(out), Tensor(indices) - infer_meta : - func : ArgsortInferMeta - kernel : - func : argsort - backward : argsort_grad - - op : as_complex args : (Tensor x) output : Tensor @@ -355,15 +337,6 @@ kernel : func : bitwise_xor -- op : bmm - args : (Tensor x, Tensor y) - output : Tensor - infer_meta : - func : BmmInferMeta - kernel : - func : bmm - backward : bmm_grad - - op : box_coder args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type, bool box_normalized, int axis, float[] variance) output : Tensor(output_box) @@ -618,15 +591,6 @@ func : depthwise_conv2d_transpose backward : depthwise_conv2d_transpose_grad -- op : det - args : (Tensor x) - output : Tensor - infer_meta : - func : UnchangedInferMeta - kernel : - func : determinant - backward : det_grad - - op : diag_embed args : (Tensor input, int offset, int dim1, int dim2) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 59d258f0b0..304027861e 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -41,9 +41,20 @@ - op : angle backward : angle_grad + inputs : + x : X + outputs : + out : Out extra : attrs : [bool use_mkldnn = false] +- op : argsort + inputs : + x : X + outputs : + out : Out + indices : Indices + - op : asin inputs : x : X @@ -101,6 +112,12 @@ extra : attrs : [bool use_mkldnn = false] +- op : bmm + inputs : + {x : X, y : Y} + outputs : + out : Out + - op : ceil backward : ceil_grad extra : @@ -226,6 +243,13 @@ extra : attrs : [float moving_rate = 0.9] +- op : det (determinant) + backward : det_grad (determinant_grad) + inputs : + x : Input + outputs : + out : Out + - op : diag (diag_v2) backward : diag_grad (diag_v2_grad) inputs : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index ec1ba17be6..e61b7490a1 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -16,6 +16,24 @@ func : acosh backward : acosh_grad +- op : angle + args : (Tensor x) + output : Tensor + infer_meta : + func : RealAndImagInferMeta + kernel : + func : angle + backward : angle_grad + +- op : argsort + args : (Tensor x, int axis=-1, bool descending=false) + output : Tensor(out), Tensor(indices) + infer_meta : + func : ArgsortInferMeta + kernel : + func : argsort + backward : argsort_grad + - op : asin args : (Tensor x) output : Tensor @@ -69,6 +87,15 @@ kernel : func : bernoulli +- op : bmm + args : (Tensor x, Tensor y) + output : Tensor + infer_meta : + func : BmmInferMeta + kernel : + func : bmm + backward : bmm_grad + - op : cholesky args : (Tensor x, bool upper=false) output : Tensor @@ -115,6 +142,15 @@ data_type : x backward : cross_grad +- op : det + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : determinant + backward : det_grad + - op : diag args : (Tensor x, int offset = 0, float padding_value = 0.0) output : Tensor diff --git a/paddle/phi/kernels/cpu/angle_grad_kernel.cc b/paddle/phi/kernels/cpu/angle_grad_kernel.cc index d12501916d..e3b10f0fc4 100644 --- a/paddle/phi/kernels/cpu/angle_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/angle_grad_kernel.cc @@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(angle_grad, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} diff --git a/paddle/phi/kernels/gpu/angle_grad_kernel.cu b/paddle/phi/kernels/gpu/angle_grad_kernel.cu index 062c39a9d1..e32c50e4c4 100644 --- a/paddle/phi/kernels/gpu/angle_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/angle_grad_kernel.cu @@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(angle_grad, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} diff --git a/paddle/phi/ops/compat/angle_sig.cc b/paddle/phi/ops/compat/angle_sig.cc deleted file mode 100644 index 63b10e6bf4..0000000000 --- a/paddle/phi/ops/compat/angle_sig.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature AngleOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("angle", {"X"}, {}, {"Out"}); -} - -KernelSignature AngleGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("angle_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(angle, phi::AngleOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(angle_grad, phi::AngleGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/argsort_sig.cc b/paddle/phi/ops/compat/argsort_sig.cc deleted file mode 100644 index 70531f1691..0000000000 --- a/paddle/phi/ops/compat/argsort_sig.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature ArgsortGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("argsort_grad", - {"Indices", "X", "Out@GRAD"}, - {"axis", "descending"}, - {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(argsort_grad, phi::ArgsortGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/bmm_sig.cc b/paddle/phi/ops/compat/bmm_sig.cc deleted file mode 100644 index 415a90c3d3..0000000000 --- a/paddle/phi/ops/compat/bmm_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature BmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "bmm_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(bmm_grad, phi::BmmGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/determinant_sig.cc b/paddle/phi/ops/compat/determinant_sig.cc deleted file mode 100644 index ee1d53704c..0000000000 --- a/paddle/phi/ops/compat/determinant_sig.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature DeterminantGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "determinant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(determinant_grad, - phi::DeterminantGradOpArgumentMapping); -- GitLab