From 4d7ddb49ad06e86af22209118c2e862a6038095e Mon Sep 17 00:00:00 2001 From: Zhenghai Zhang <65210872+ccsuzzh@users.noreply.github.com> Date: Fri, 30 Jun 2023 10:21:30 +0800 Subject: [PATCH] static graph autogen code for expand (#54628) * static graph autogen code for expand * fix bug * fix bug * fix bug * fix bug * fix bug * fix bug --- paddle/fluid/operators/expand_op.cc | 9 - paddle/fluid/operators/expand_v2_op.cc | 255 ----------------------- paddle/phi/api/yaml/backward.yaml | 20 ++ paddle/phi/api/yaml/legacy_backward.yaml | 19 -- paddle/phi/api/yaml/legacy_ops.yaml | 9 - paddle/phi/api/yaml/op_compat.yaml | 3 +- paddle/phi/api/yaml/ops.yaml | 10 + test/prim/process/test_check_inputs.py | 2 +- 8 files changed, 33 insertions(+), 294 deletions(-) delete mode 100644 paddle/fluid/operators/expand_v2_op.cc diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 43fd505acda..fee4b470493 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -128,23 +125,17 @@ Expand operator tiles the input by given times number. You should set times number for each dimension by providing attribute 'expand_times'. The rank of X should be in [1, 6]. Please note that size of 'expand_times' must be the same with X's rank. Following is a using case: - Input(X) is a 3-D tensor with shape [2, 3, 1]: - [ [[1], [2], [3]], [[4], [5], [6]] ] - Attr(expand_times): [1, 2, 2] - Output(Out) is a 3-D tensor with shape [2, 6, 2]: - [ [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]] ] - )DOC"); } }; diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc deleted file mode 100644 index 6df6422f717..00000000000 --- a/paddle/fluid/operators/expand_v2_op.cc +++ /dev/null @@ -1,255 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/expand_v2_op.h" - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" -#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" -#include "paddle/fluid/prim/utils/static/desc_tensor.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -#define MAX_RANK_SUPPORTED 6 - -namespace paddle { -namespace operators { - -class ExpandV2Op : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = - framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { - if (var_name == "expand_shapes_tensor" || var_name == "Shape") { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } -}; - -class ExpandV2OpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Tensor, default Tensor). A tensor with rank in [1, 6]." - "X is the input to be expanded."); - AddInput("Shape", - "(Tensor), optional). If provided, expand according to " - "this given Shape. It has a higher priority than " - "expand_shapes_tensor and the shape attribute.") - .AsDispensable(); - AddInput("expand_shapes_tensor", - "(Tensor Tensor), epxanded shape for X." - "It has a higher priority than shape attribute, but a lower " - "priority than the input Shape") - .AsDuplicable() - .AsDispensable(); - AddOutput("Out", - "(Tensor, default Tensor). A tensor with rank in [1, 6]." - "The rank of Output(Out) have the same with Input(X). " - "After expanding, size of each dimension of Output(Out) is equal " - "to size of the corresponding dimension of Input(X) multiplying " - "the corresponding value given by Attr(expand_times)."); - AddAttr>("shape", "The expanded shape for each dimension.") - .SetDefault({}); - AddComment(R"DOC( -Expand the input to the given shape. The rank of X -should be in [1, 6] and size of 'shape' must be in [1, 6] also. -Following is a using case: - -Input(X) is a 3-D tensor with shape [2, 3, 1]: - - [ - [[1], [2], [3]], - [[4], [5], [6]] - ] - -Attr(shape): [2, 6, 2] - -Output(Out) is a 3-D tensor with shape [2, 6, 2]: - - [ - [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]], - [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]] - ] - -)DOC"); - } -}; - -class ExpandV2GradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2Grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - framework::GradVarName("Out"), - "ExpandV2Grad"); - - auto x_dims = ctx->GetInputDim("X"); - std::vector expand_shape = ctx->Attrs().Get>("shape"); - if (expand_shape.size() == 0) { - expand_shape = std::vector(x_dims.size(), -1); - } - - auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - auto x_dim_vec = phi::vectorize(x_dims); - auto diff = expand_shape.size() - x_dim_vec.size(); - x_dim_vec.insert(x_dim_vec.begin(), diff, -1); - - for (size_t i = 0; i < expand_shape.size(); ++i) { - if (expand_shape[i] < 0 || x_dim_vec[i] == -1) { - continue; - } else { - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - expand_shape[i], - out_dims[i], - platform::errors::InvalidArgument( - "The size (%d) of the dimension %d of Input(Out@GRAD) should " - "be equal to the crroresponding dimension size of shape(%d).", - out_dims[i], - i, - expand_shape[i])); - } - } - } - auto x_grad_name = framework::GradVarName("X"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { - if (var_name == "expand_shapes_tensor" || var_name == "Shape") { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } -}; - -template -class ExpandV2GradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("expand_v2_grad"); - op->SetInput("X", this->Input("X")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor")); - op->SetInput("Shape", this->Input("Shape")); - op->SetAttrMap(this->Attrs()); - } -}; - -class ExpandV2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase { - using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; - - public: - void Apply() override { - auto x = this->GetSingleForwardInput("X"); - auto out_grad = this->GetSingleOutputGrad("Out"); - auto x_grad = this->GetSingleInputGrad("X"); - auto x_grad_p = this->GetOutputPtr(&x_grad); - auto x_grad_name = this->GetOutputName(x_grad); - auto shape = this->Attr>("shape"); - prim::expand_grad( - x, out_grad, paddle::experimental::IntArray(shape), x_grad_p); - VLOG(6) << "Runing expand_v2 composite func"; - this->RecoverOutputName(x_grad, x_grad_name); - } -}; - -template -class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("expand_v2"); - op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); - op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); - if (this->HasInput("expand_shapes_tensor")) { - op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor")); - } - if (this->HasInput("Shape")) { - op->SetInput("Shape", this->Input("Shape")); - } - op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X"); - -} // namespace operators -} // namespace paddle - -DECLARE_INFER_SHAPE_FUNCTOR(expand_v2, - ExpandInferShapeFunctor, - PD_INFER_META(phi::ExpandInferMeta)); - -namespace ops = paddle::operators; -REGISTER_OPERATOR(expand_v2, - ops::ExpandV2Op, - ops::ExpandV2OpMaker, - ops::ExpandV2CompositeGradOpMaker, - ops::ExpandV2GradOpMaker, - ops::ExpandV2GradOpMaker, - ExpandInferShapeFunctor); -REGISTER_OPERATOR(expand_v2_grad, - ops::ExpandV2GradOp, - ops::ExpandV2DoubleGradOpMaker, - ops::ExpandV2DoubleGradOpMaker, - ops::ExpandV2GradNoNeedBufVarsInferer); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index ccc5152c519..9d9e842c766 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -682,6 +682,26 @@ func : expand_as_grad no_need_buffer : x +- backward_op : expand_double_grad + forward : expand_grad (Tensor x, Tensor grad_out, IntArray shape) -> Tensor(grad_x) + args : (Tensor grad_x_grad, IntArray shape) + output : Tensor(grad_out_grad) + invoke : expand(grad_x_grad, shape) + +- backward_op : expand_grad + forward : expand (Tensor x, IntArray shape) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray shape) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : expand_grad + data_type : out_grad + no_need_buffer : x + backward : expand_double_grad + composite: expand_grad(x, out_grad, shape, x_grad) + - backward_op : expm1_grad forward : expm1 (Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 10135adf6a0..d2c435c8d9f 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -255,25 +255,6 @@ invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad) no_need_buffer : weight -- backward_op : expand_double_grad - forward : expand_grad (Tensor x, Tensor grad_out, IntArray shape) -> Tensor(grad_x) - args : (Tensor grad_x_grad, IntArray shape) - output : Tensor(grad_out_grad) - invoke : expand(grad_x_grad, shape) - -- backward_op : expand_grad - forward : expand (Tensor x, IntArray shape) -> Tensor(out) - args : (Tensor x, Tensor out_grad, IntArray shape) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : expand_grad - no_need_buffer : x - backward : expand_double_grad - composite: expand_grad(x, out_grad, shape, x_grad) - - backward_op : exponential__grad forward : exponential_ (Tensor x, float lam) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index b440ef6cd98..8052e0cd47d 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -302,15 +302,6 @@ kernel : func : equal -- op : expand - args : (Tensor x, IntArray shape) - output : Tensor - infer_meta : - func : ExpandInferMeta - kernel : - func : expand - backward : expand_grad - - op : exponential_ args : (Tensor x, float lam) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 301eb88662a..2628bab6eda 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -914,7 +914,7 @@ attrs : [bool use_mkldnn = false, bool use_cudnn = false] - op : expand (expand_v2) - backward : expand_grad (expand_v2_grad) + backward : expand_grad (expand_v2_grad), expand_double_grad(expand_v2_double_grad) inputs : x : X attrs : @@ -928,6 +928,7 @@ tensors_name : expand_shapes_tensor extra : attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"] + manual_signature : [expand, expand_grad] - op : expand_as (expand_as_v2) backward : expand_as_grad (expand_as_v2_grad) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 34c41c1d0a2..5817122a8cb 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -762,6 +762,16 @@ inplace : (x -> out) backward : exp_grad +- op : expand + args : (Tensor x, IntArray shape = {}) + output : Tensor(out) + infer_meta : + func : ExpandInferMeta + kernel : + func : expand + data_type : x + backward : expand_grad + - op : expand_as args : (Tensor x, Tensor y, int[] target_shape = {}) output : Tensor(out) diff --git a/test/prim/process/test_check_inputs.py b/test/prim/process/test_check_inputs.py index c7392430e5d..bf1d3e7a506 100644 --- a/test/prim/process/test_check_inputs.py +++ b/test/prim/process/test_check_inputs.py @@ -44,7 +44,7 @@ class TestIntarrayInput(unittest.TestCase): tensor_data = paddle.to_tensor(np_data) shape = paddle.to_tensor([2, 3, 4]) net = paddle.jit.to_static(fn) - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): _ = net(tensor_data, shape).numpy() core._set_prim_all_enabled(False) -- GitLab