diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc
index 43fd505acdae415695c8f501a21b8805ed52a0af..fee4b47049301087390123c900f73c14db7ff44c 100644
--- a/paddle/fluid/operators/expand_op.cc
+++ b/paddle/fluid/operators/expand_op.cc
@@ -1,11 +1,8 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
     http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -128,23 +125,17 @@ Expand operator tiles the input by given times number. You should set times
 number for each dimension by providing attribute 'expand_times'. The rank of X
 should be in [1, 6]. Please note that size of 'expand_times' must be the same
 with X's rank. Following is a using case:
-
 Input(X) is a 3-D tensor with shape [2, 3, 1]:
-
         [
            [[1], [2], [3]],
            [[4], [5], [6]]
         ]
-
 Attr(expand_times): [1, 2, 2]
-
 Output(Out) is a 3-D tensor with shape [2, 6, 2]:
-
         [
             [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
             [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
         ]
-
 )DOC");
   }
 };
diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc
deleted file mode 100644
index 6df6422f7173c9cf0fcde2624d402484de85b322..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/expand_v2_op.cc
+++ /dev/null
@@ -1,255 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/expand_v2_op.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
-#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
-#include "paddle/fluid/prim/utils/static/desc_tensor.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/unary.h"
-
-#define MAX_RANK_SUPPORTED 6
-
-namespace paddle {
-namespace operators {
-
-class ExpandV2Op : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return phi::KernelKey(input_data_type, ctx.GetPlace());
-  }
-
-  phi::KernelKey GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const phi::KernelKey& expected_kernel_type) const override {
-    if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
-      return phi::KernelKey(phi::Backend::ALL_BACKEND,
-                            expected_kernel_type.layout(),
-                            expected_kernel_type.dtype());
-    }
-    return phi::KernelKey(
-        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
-  }
-};
-
-class ExpandV2OpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
-             "X is the input to be expanded.");
-    AddInput("Shape",
-             "(Tensor<int>), optional). If provided, expand according to "
-             "this given Shape. It has a higher priority than "
-             "expand_shapes_tensor and the shape attribute.")
-        .AsDispensable();
-    AddInput("expand_shapes_tensor",
-             "(Tensor Tensor<int>), epxanded shape for X."
-             "It has a higher priority than shape attribute, but a lower "
-             "priority than the input Shape")
-        .AsDuplicable()
-        .AsDispensable();
-    AddOutput("Out",
-              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
-              "The rank of Output(Out) have the same with Input(X). "
-              "After expanding, size of each dimension of Output(Out) is equal "
-              "to size of the corresponding dimension of Input(X) multiplying "
-              "the corresponding value given by Attr(expand_times).");
-    AddAttr<std::vector<int>>("shape", "The expanded shape for each dimension.")
-        .SetDefault({});
-    AddComment(R"DOC(
-Expand the input to the given shape. The rank of X
-should be in [1, 6] and size of 'shape' must be in [1, 6] also.
-
-Following is a using case:
-
-Input(X) is a 3-D tensor with shape [2, 3, 1]:
-
-        [
-           [[1], [2], [3]],
-           [[4], [5], [6]]
-        ]
-
-Attr(shape): [2, 6, 2]
-
-Output(Out) is a 3-D tensor with shape [2, 6, 2]:
-
-        [
-            [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
-            [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
-        ]
-
-)DOC");
-  }
-};
-
-class ExpandV2GradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2Grad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   framework::GradVarName("Out"),
-                   "ExpandV2Grad");
-
-    auto x_dims = ctx->GetInputDim("X");
-    std::vector<int> expand_shape =
-        ctx->Attrs().Get<std::vector<int>>("shape");
-    if (expand_shape.size() == 0) {
-      expand_shape = std::vector<int>(x_dims.size(), -1);
-    }
-
-    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    auto x_dim_vec = phi::vectorize(x_dims);
-    auto diff = expand_shape.size() - x_dim_vec.size();
-    x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
-
-    for (size_t i = 0; i < expand_shape.size(); ++i) {
-      if (expand_shape[i] < 0 || x_dim_vec[i] == -1) {
-        continue;
-      } else {
-        if (ctx->IsRuntime()) {
-          PADDLE_ENFORCE_EQ(
-              expand_shape[i],
-              out_dims[i],
-              platform::errors::InvalidArgument(
-                  "The size (%d) of the dimension %d of Input(Out@GRAD) should "
-                  "be equal to the crroresponding dimension size of shape(%d).",
-                  out_dims[i],
-                  i,
-                  expand_shape[i]));
-        }
-      }
-    }
-    auto x_grad_name = framework::GradVarName("X");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
-        ctx, framework::GradVarName("Out"));
-    return phi::KernelKey(input_data_type, ctx.GetPlace());
-  }
-
-  phi::KernelKey GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const phi::KernelKey& expected_kernel_type) const override {
-    if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
-      return phi::KernelKey(phi::Backend::ALL_BACKEND,
-                            expected_kernel_type.layout(),
-                            expected_kernel_type.dtype());
-    }
-    return phi::KernelKey(
-        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
-  }
-};
-
-template <typename T>
-class ExpandV2GradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("expand_v2_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor"));
-    op->SetInput("Shape", this->Input("Shape"));
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-class ExpandV2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
-  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
-
- public:
-  void Apply() override {
-    auto x = this->GetSingleForwardInput("X");
-    auto out_grad = this->GetSingleOutputGrad("Out");
-    auto x_grad = this->GetSingleInputGrad("X");
-    auto x_grad_p = this->GetOutputPtr(&x_grad);
-    auto x_grad_name = this->GetOutputName(x_grad);
-    auto shape = this->Attr<std::vector<int>>("shape");
-    prim::expand_grad<prim::DescTensor>(
-        x, out_grad, paddle::experimental::IntArray(shape), x_grad_p);
-    VLOG(6) << "Runing expand_v2 composite func";
-    this->RecoverOutputName(x_grad, x_grad_name);
-  }
-};
-
-template <typename T>
-class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("expand_v2");
-    op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
-    op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
-    if (this->HasInput("expand_shapes_tensor")) {
-      op->SetInput("expand_shapes_tensor",
-                   this->Input("expand_shapes_tensor"));
-    }
-    if (this->HasInput("Shape")) {
-      op->SetInput("Shape", this->Input("Shape"));
-    }
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
-
-}  // namespace operators
-}  // namespace paddle
-
-DECLARE_INFER_SHAPE_FUNCTOR(expand_v2,
-                            ExpandInferShapeFunctor,
-                            PD_INFER_META(phi::ExpandInferMeta));
-
-namespace ops = paddle::operators;
-REGISTER_OPERATOR(expand_v2,
-                  ops::ExpandV2Op,
-                  ops::ExpandV2OpMaker,
-                  ops::ExpandV2CompositeGradOpMaker,
-                  ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
-                  ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>,
-                  ExpandInferShapeFunctor);
-REGISTER_OPERATOR(expand_v2_grad,
-                  ops::ExpandV2GradOp,
-                  ops::ExpandV2DoubleGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ExpandV2DoubleGradOpMaker<paddle::imperative::OpBase>,
-                  ops::ExpandV2GradNoNeedBufVarsInferer);
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index ccc5152c51964678aafaf8f5164cc25fcf411f3d..9d9e842c766e72658b9b3e1f2c9e0d79562847ef 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -682,6 +682,26 @@
     func : expand_as_grad
   no_need_buffer : x
 
+- backward_op : expand_double_grad
+  forward : expand_grad (Tensor x, Tensor grad_out, IntArray shape) -> Tensor(grad_x)
+  args : (Tensor grad_x_grad, IntArray shape)
+  output : Tensor(grad_out_grad)
+  invoke : expand(grad_x_grad, shape)
+
+- backward_op : expand_grad
+  forward : expand (Tensor x, IntArray shape) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, IntArray shape)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : expand_grad
+    data_type : out_grad
+  no_need_buffer : x
+  backward : expand_double_grad
+  composite: expand_grad(x, out_grad, shape, x_grad)
+
 - backward_op : expm1_grad
   forward : expm1 (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 10135adf6a0f453fdfcd7f360a0ded4dc3578e09..d2c435c8d9f1d4675a9a7e7c0a6f6c2d81177475 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -255,25 +255,6 @@
     invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
   no_need_buffer : weight
 
-- backward_op : expand_double_grad
-  forward : expand_grad (Tensor x, Tensor grad_out, IntArray shape) -> Tensor(grad_x)
-  args : (Tensor grad_x_grad, IntArray shape)
-  output : Tensor(grad_out_grad)
-  invoke : expand(grad_x_grad, shape)
-
-- backward_op : expand_grad
-  forward : expand (Tensor x, IntArray shape) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, IntArray shape)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : expand_grad
-  no_need_buffer : x
-  backward : expand_double_grad
-  composite: expand_grad(x, out_grad, shape, x_grad)
-
 - backward_op : exponential__grad
   forward : exponential_ (Tensor x, float lam) -> Tensor(out)
   args : (Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index b440ef6cd98c77576f86286bc3ff544d0dfdd5c4..8052e0cd47d6d5175c4f05be9872dfa3c4812d2c 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -302,15 +302,6 @@
   kernel :
     func : equal
 
-- op : expand
-  args : (Tensor x, IntArray shape)
-  output : Tensor
-  infer_meta :
-    func : ExpandInferMeta
-  kernel :
-    func : expand
-  backward : expand_grad
-
 - op : exponential_
   args : (Tensor x, float lam)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 301eb88662ac4110783e8a70ca57fdac38559b6f..2628bab6eda40df8c6ed11914c2f1f6fb1b77a6e 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -914,7 +914,7 @@
     attrs : [bool use_mkldnn = false, bool use_cudnn = false]
 
 - op : expand (expand_v2)
-  backward : expand_grad (expand_v2_grad)
+  backward : expand_grad (expand_v2_grad), expand_double_grad(expand_v2_double_grad)
   inputs :
     x : X
   attrs :
@@ -928,6 +928,7 @@
     tensors_name : expand_shapes_tensor
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
+  manual_signature : [expand, expand_grad]
 
 - op : expand_as (expand_as_v2)
   backward : expand_as_grad (expand_as_v2_grad)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 34c41c1d0a2900599f91a79e179e10bfe860ca30..5817122a8cb08430960a92734bb7120903b0cc91 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -762,6 +762,16 @@
   inplace : (x -> out)
   backward : exp_grad
 
+- op : expand
+  args : (Tensor x, IntArray shape = {})
+  output : Tensor(out)
+  infer_meta :
+    func : ExpandInferMeta
+  kernel :
+    func : expand
+    data_type : x
+  backward : expand_grad
+
 - op : expand_as
   args : (Tensor x, Tensor y, int[] target_shape = {})
   output : Tensor(out)
diff --git a/test/prim/process/test_check_inputs.py b/test/prim/process/test_check_inputs.py
index c7392430e5d9362bc35f902f739d33a17c20f930..bf1d3e7a5064fafea3dfc4db2ea0c85be91903f8 100644
--- a/test/prim/process/test_check_inputs.py
+++ b/test/prim/process/test_check_inputs.py
@@ -44,7 +44,7 @@ class TestIntarrayInput(unittest.TestCase):
         tensor_data = paddle.to_tensor(np_data)
         shape = paddle.to_tensor([2, 3, 4])
         net = paddle.jit.to_static(fn)
-        with self.assertRaises(ValueError):
+        with self.assertRaises(NotImplementedError):
             _ = net(tensor_data, shape).numpy()
         core._set_prim_all_enabled(False)
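
Note (not part of the diff): the YAML entries moved into ops.yaml/backward.yaml above define the op's full gradient chain: expand -> expand_grad -> expand_double_grad, where the double grad is simply another expand, which is why backward.yaml can express it with `invoke : expand(grad_x_grad, shape)`. The sketch below illustrates those semantics through the public Python API; it is a minimal illustration assuming Paddle 2.x dygraph behavior, not a test from this PR.

import paddle

x = paddle.ones([2, 3, 1])
x.stop_gradient = False

# Forward (ops.yaml `expand`): broadcast x to the target shape.
out = paddle.expand(x, shape=[2, 3, 4])  # out.shape == [2, 3, 4]

# Backward (backward.yaml `expand_grad`): the incoming gradient is
# reduce-summed back over the expanded dimensions, so x.grad keeps
# x's shape and each element accumulates the sum over the last axis.
out.backward(paddle.ones_like(out))
print(x.grad.shape)            # [2, 3, 1]
print(float(x.grad[0, 0, 0]))  # 4.0 -- summed over the expanded dim of size 4

The test change at the end matches the new behavior surfaced by this move: under prim, passing the target shape as a Tensor to a to_static program is now rejected with NotImplementedError instead of ValueError.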