From 241b44db14456b2e140d831a6a3df58e76b675f2 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Sun, 16 Aug 2020 12:24:38 +0800 Subject: [PATCH] [API 2.0] adaptive expand op to use shape instead of expand_times (#26206) * adaptive expand op to 2.0 (align to torch.expand) , test=develop --- paddle/fluid/operators/expand_v2_op.cc | 255 +++++++++++++++ paddle/fluid/operators/expand_v2_op.cu | 32 ++ paddle/fluid/operators/expand_v2_op.h | 296 ++++++++++++++++++ .../tests/unittests/test_expand_v2_op.py | 234 ++++++++++++++ .../tests/unittests/test_retain_graph.py | 6 +- python/paddle/tensor/manipulation.py | 108 ++++++- 6 files changed, 913 insertions(+), 18 deletions(-) create mode 100644 paddle/fluid/operators/expand_v2_op.cc create mode 100644 paddle/fluid/operators/expand_v2_op.cu create mode 100644 paddle/fluid/operators/expand_v2_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_expand_v2_op.py diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc new file mode 100644 index 0000000000..359d512c34 --- /dev/null +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -0,0 +1,255 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/expand_v2_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class ExpandV2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandV2"); + auto x_dims = ctx->GetInputDim("X"); + auto expand_shape = ctx->Attrs().Get>("shape"); + + if (expand_shape.size() == 0) { + expand_shape = std::vector(x_dims.size(), -1); + } + + PADDLE_ENFORCE_GE( + expand_shape.size(), static_cast(x_dims.size()), + platform::errors::InvalidArgument( + "The number of elements (%d) of 'shape' for " + "expand_v2 op must be greater than or equal to the rank " + "(%d) of the input.", + expand_shape.size(), static_cast(x_dims.size()))); + PADDLE_ENFORCE_LE(expand_shape.size(), MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The number of elements (%d) of 'shape' for " + "must not be greater than %d.", + expand_shape.size(), MAX_RANK_SUPPORTED)); + PADDLE_ENFORCE_GE(expand_shape.size(), 1, + platform::errors::InvalidArgument( + "The number of elements (%d) of 'shape' for " + "must be a positive integer.", + expand_shape.size())); + + auto out_rank = + std::max(static_cast(x_dims.size()), expand_shape.size()); + std::vector out_shape(out_rank); + auto x_dim_vec = framework::vectorize(x_dims); + auto diff = expand_shape.size() - x_dim_vec.size(); + x_dim_vec.insert(x_dim_vec.begin(), diff, -1); + for (size_t i = 0; i < expand_shape.size(); ++i) { + if (x_dims[i] == -1) { + out_shape[i] = -1; + } else if (expand_shape[i] == -1) { + out_shape[i] = 
x_dims[i];
+      } else {
+        PADDLE_ENFORCE_GT(
+            expand_shape[i], 0,
+            platform::errors::InvalidArgument(
+                "The %uth element of 'shape' for expand_v2 op must be "
+                "greater than 0, but the value given is %d.",
+                i, expand_shape[i]));
+        out_shape[i] = expand_shape[i];
+      }
+    }
+
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+    if (out_shape[0] == x_dims[0]) {
+      ctx->ShareLoD("X", "Out");
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+};
+
+class ExpandV2OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
+             "X is the input to be expanded.");
+    AddInput("Shape",
+             "(Tensor<int>, optional). If provided, expand X according to "
+             "this given shape. It has a higher priority than the "
+             "expand_shapes_tensor input and the shape attribute.")
+        .AsDispensable();
+    AddInput("expand_shapes_tensor",
+             "(Tensor Tensor<int>), the expanded shape for X. "
+             "It has a higher priority than the shape attribute, but a lower "
+             "priority than the input Shape.")
+        .AsDuplicable()
+        .AsDispensable();
+    AddOutput("Out",
+              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
+              "The rank of Output(Out) is the same as that of Input(X). "
+              "After expanding, the size of each dimension of Output(Out) is "
+              "equal to the corresponding value given by Attr(shape), or to "
+              "the size of the corresponding dimension of Input(X) when that "
+              "value is -1.");
+    AddAttr<std::vector<int>>("shape",
+                              "The expanded shape for each dimension.")
+        .SetDefault({});
+    AddComment(R"DOC(
+Expand the input to the given shape. The rank of X
+should be in [1, 6] and the number of elements of 'shape' must also be in [1, 6].
+Following is a using case: + +Input(X) is a 3-D tensor with shape [2, 3, 1]: + + [ + [[1], [2], [3]], + [[4], [5], [6]] + ] + +Attr(shape): [2, 6, 2] + +Output(Out) is a 3-D tensor with shape [2, 6, 2]: + + [ + [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]], + [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]] + ] + +)DOC"); + } +}; + +class ExpandV2GradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2Grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "ExpandV2Grad"); + + auto x_dims = ctx->GetInputDim("X"); + std::vector expand_shape = ctx->Attrs().Get>("shape"); + if (expand_shape.size() == 0) { + expand_shape = std::vector(x_dims.size(), -1); + } + + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_dim_vec = framework::vectorize(x_dims); + auto diff = expand_shape.size() - x_dim_vec.size(); + x_dim_vec.insert(x_dim_vec.begin(), diff, -1); + + for (size_t i = 0; i < expand_shape.size(); ++i) { + if (expand_shape[i] == -1 || x_dim_vec[i] == -1) { + continue; + } else { + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + expand_shape[i], out_dims[i], + platform::errors::InvalidArgument( + "The size (%d) of the dimension %d of Input(Out@GRAD) should " + "be equal to the crroresponding dimension size of shape(%d).", + out_dims[i], i, expand_shape[i])); + } + } + } + auto x_grad_name = framework::GradVarName("X"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "expand_shapes_tensor" || var_name == "Shape") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +template +class ExpandV2GradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("expand_v2_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor")); + op->SetInput("Shape", this->Input("Shape")); + op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(expand_v2, ops::ExpandV2Op, ops::ExpandV2OpMaker, + ops::ExpandV2GradOpMaker, + ops::ExpandV2GradOpMaker); +REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp, + ops::ExpandV2GradNoNeedBufVarsInferer); +REGISTER_OP_CPU_KERNEL( + expand_v2, ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel); +REGISTER_OP_CPU_KERNEL( + expand_v2_grad, + 
ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel); diff --git a/paddle/fluid/operators/expand_v2_op.cu b/paddle/fluid/operators/expand_v2_op.cu new file mode 100644 index 0000000000..e096dbc27f --- /dev/null +++ b/paddle/fluid/operators/expand_v2_op.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/expand_v2_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + expand_v2, ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel, + ops::ExpandV2Kernel); +REGISTER_OP_CUDA_KERNEL( + expand_v2_grad, + ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel, + ops::ExpandV2GradKernel); diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h new file mode 100644 index 0000000000..ec9c6e62f2 --- /dev/null +++ b/paddle/fluid/operators/expand_v2_op.h @@ -0,0 +1,296 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +#define MAX_RANK_SUPPORTED 6 + +#define EXPAND_TEMPLATE(z, n, data) \ + case n + 1: { \ + Expand(context); \ + break; \ + } +#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~) +#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED)) +#define EXPAND_GRAD_CASE(n) \ + case n: { \ + ExpandBackward(context, reshape_dims_vec, reduce_dims_vec); \ + break; \ + } +#define EXPAND_GRAD_TEMPLATE(z, n, data) \ + BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), ) +#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~) + +namespace paddle { +namespace operators { +inline std::vector get_expand_shape( + const framework::ExecutionContext& ctx) { + if (ctx.HasInput("Shape")) { + auto* shape_tensor = ctx.Input("Shape"); + auto* shape_data = shape_tensor->data(); + framework::Tensor cpu_shape_tensor; + if (platform::is_gpu_place(shape_tensor->place())) { + TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); + shape_data = cpu_shape_tensor.data(); + } + auto vec_shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + return vec_shape; + } + + auto list_expand_shapes_tensor = + ctx.MultiInput("expand_shapes_tensor"); + if (list_expand_shapes_tensor.size() > 0) { + // get tensor from + std::vector vec_epxand_shape; + for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) { + auto tensor = list_expand_shapes_tensor[i]; + if (platform::is_gpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_epxand_shape.push_back(*temp.data()); + } else { + vec_epxand_shape.push_back(*tensor->data()); + } + } + return vec_epxand_shape; + } else { + return ctx.Attr>("shape"); + } +} + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; +template +using EigenTensor = framework::EigenTensor; +using framework::To32BitIndex; + +template +class ExpandV2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto rank = context.Input("X")->dims().size(); + PADDLE_ENFORCE_GE( + rank, 1, + platform::errors::InvalidArgument( + "The rank of the input 'X' for expand_v2 op must be positive, " + "but the value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank of the input 'X' for expand_v2 op must be less than " + "or equal to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, rank)); + auto expand_shape = get_expand_shape(context); + auto shape_size = expand_shape.size(); + PADDLE_ENFORCE_GE( + shape_size, rank, + platform::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand_v2 op must be " + "greater than or equal to the rank (%d) of the input 'X'.", + shape_size, rank)); + PADDLE_ENFORCE_LE( + shape_size, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand_v2 op must be " + "less than or equal to %d.", + shape_size, MAX_RANK_SUPPORTED)); + rank = std::max(rank, static_cast(shape_size)); + switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) } + } + + protected: + template + void Expand(const framework::ExecutionContext& context) const { + auto* in0 = 
context.Input("X"); + + auto in_dims = in0->dims(); + auto expand_shape = get_expand_shape(context); + auto vec_in_dims = framework::vectorize(in_dims); + auto diff = expand_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + PADDLE_ENFORCE_NE(expand_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size cannot be zero.")); + if (i < diff) { + PADDLE_ENFORCE_GT( + expand_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand_v2 op.", + expand_shape[i])); + repeat_times[i] = expand_shape[i]; + } else if (expand_shape[i] > 0) { + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], expand_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand_v2 op.", + vec_in_dims[i], expand_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = expand_shape[i]; + } + } else { + PADDLE_ENFORCE_EQ( + expand_shape[i], -1, + platform::errors::InvalidArgument( + "When the value in shape is negative for expand_v2 op, " + "only -1 is supported, but the value received is %d.", + expand_shape[i])); + repeat_times[i] = 1; + } + } + + auto* out0 = context.Output("Out"); + Eigen::DSizes bcast_dims; + for (size_t i = 0; i < repeat_times.size(); ++i) { + bcast_dims[i] = repeat_times[i]; + } + + framework::DDim new_in_dims = framework::make_ddim(vec_in_dims); + framework::DDim out_dims(new_in_dims); + for (size_t i = 0; i < repeat_times.size(); ++i) { + out_dims[i] *= repeat_times[i]; + } + + out0->Resize(out_dims); + auto x = EigenTensor::From(*in0, new_in_dims); + out0->mutable_data(context.GetPlace()); + auto y = EigenTensor::From(*out0, out_dims); + auto& place = + *context.template device_context().eigen_device(); + // use 32-bit index to speed up + bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); + if (use_32bit_index) { + To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims); + } else { + y.device(place) = x.broadcast(bcast_dims); + } + } +}; + +template +class ExpandV2GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto expand_shape = get_expand_shape(context); + auto x_dims = in0->dims(); + auto vec_in_dims = framework::vectorize(x_dims); + auto diff = expand_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + // 1. reshape_dims_vec is the broadcast parameter. + // 2. reduce_dims_vec is the dimension parameter to compute gradients. For + // each dimension expanded, the gradients should be summed to original + // size. 
+ std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + if (expand_shape[i] < 0) { + repeat_times[i] = 1; + } else { + repeat_times[i] = expand_shape[i] / vec_in_dims[i]; + } + } + std::vector reshape_dims_vec; + std::vector reduce_dims_vec; + for (size_t i = 0; i < repeat_times.size(); ++i) { + reduce_dims_vec.push_back(reshape_dims_vec.size()); + reshape_dims_vec.push_back(repeat_times[i]); + reshape_dims_vec.push_back(vec_in_dims[i]); + } + + int dims = reduce_dims_vec.size(); + + bool just_copy = true; + for (size_t i = 0; i < repeat_times.size(); i++) { + if (repeat_times[i] != 1) { + just_copy = false; + break; + } + } + // no need reduce, just copy + if (just_copy) { + auto* in0 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + out0->mutable_data(context.GetPlace()); + framework::TensorCopy(*in0, context.GetPlace(), context.device_context(), + out0); + } else { + PADDLE_ENFORCE_GE(dims, 1, + platform::errors::InvalidArgument( + "The rank of the input 'Out@GRAD' for " + "expand_v2_grad op must be greater than or " + "equal to 1, but the value received is %d.", + dims)); + PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank of the input 'Out@GRAD' for " + "expand_v2_grad op must be less than or equal " + "to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, dims)); + switch (dims) { REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) } + } + } + + protected: + template + void ExpandBackward(const framework::ExecutionContext& context, + const std::vector& reshape_dims_vec, + const std::vector& reduce_dims_vec) const { + size_t reshape_size = reshape_dims_vec.size(); + size_t reduce_size = reduce_dims_vec.size(); + auto* in0 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + out0->mutable_data(context.GetPlace()); + auto x_grad = EigenVector::Flatten(*out0); + Eigen::DSizes reshape_dims; + for (size_t i = 0; i < reshape_size; ++i) { + reshape_dims[i] = reshape_dims_vec[i]; + } + Eigen::DSizes reduce_dims; + for (size_t i = 0; i < reduce_size; ++i) { + reduce_dims[i] = reduce_dims_vec[i]; + } + auto out_grad = EigenVector::Flatten(*in0); + x_grad.device( + *context.template device_context().eigen_device()) = + out_grad.reshape(reshape_dims) + .sum(reduce_dims) + .reshape(x_grad.dimensions()); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py new file mode 100644 index 0000000000..94669bc28f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -0,0 +1,234 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +import paddle + + +# Situation 1: shape is a list(without tensor) +class TestExpandV2OpRank1(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.init_data() + + self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")} + self.attrs = {'shape': self.shape} + output = np.tile(self.inputs['X'], self.expand_times) + self.outputs = {'Out': output} + + def init_data(self): + self.ori_shape = [100] + self.shape = [100] + self.expand_times = [1] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestExpandV2OpRank2_DimExpanding(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [120] + self.shape = [2, 120] + self.expand_times = [2, 1] + + +class TestExpandV2OpRank2(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [1, 140] + self.shape = [12, 140] + self.expand_times = [12, 1] + + +class TestExpandV2OpRank3_Corner(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = (2, 10, 5) + self.shape = (2, 10, 5) + self.expand_times = (1, 1, 1) + + +class TestExpandV2OpRank4(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = (2, 4, 5, 7) + self.shape = (-1, -1, -1, -1) + self.expand_times = (1, 1, 1, 1) + + +# Situation 2: shape is a list(with tensor) +class TestExpandV2OpRank1_tensor_attr(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.init_data() + expand_shapes_tensor = [] + for index, ele in enumerate(self.expand_shape): + expand_shapes_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = { + 'X': np.random.random(self.ori_shape).astype("float64"), + 'expand_shapes_tensor': expand_shapes_tensor, + } + self.attrs = {"shape": self.infer_expand_shape} + output = np.tile(self.inputs['X'], self.expand_times) + self.outputs = {'Out': output} + + def init_data(self): + self.ori_shape = [100] + self.expand_times = [1] + self.expand_shape = [100] + self.infer_expand_shape = [-1] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestExpandV2OpRank2_Corner_tensor_attr(TestExpandV2OpRank1_tensor_attr): + def init_data(self): + self.ori_shape = [12, 14] + self.expand_times = [1, 1] + self.expand_shape = [12, 14] + self.infer_expand_shape = [12, -1] + + +# Situation 3: shape is a tensor +class TestExpandV2OpRank1_tensor(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.init_data() + + self.inputs = { + 'X': np.random.random(self.ori_shape).astype("float64"), + 'Shape': np.array(self.expand_shape).astype("int32"), + } + self.attrs = {} + output = np.tile(self.inputs['X'], self.expand_times) + self.outputs = {'Out': output} + + def init_data(self): + self.ori_shape = [100] + self.expand_times = [2, 1] + self.expand_shape = [2, 100] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +# Situation 4: input x is Integer +class TestExpandV2OpInteger(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.inputs = { + 'X': np.random.randint( + 10, size=(2, 4, 5)).astype("int32") + } + self.attrs = {'shape': [2, 4, 5]} + output = np.tile(self.inputs['X'], (1, 1, 1)) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + 
+# Situation 5: input x is Bool +class TestExpandV2OpBoolean(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")} + self.attrs = {'shape': [2, 4, 5]} + output = np.tile(self.inputs['X'], (1, 1, 1)) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +# Situation 56: input x is Integer +class TestExpandV2OpInt64_t(OpTest): + def setUp(self): + self.op_type = "expand_v2" + self.inputs = { + 'X': np.random.randint( + 10, size=(2, 4, 5)).astype("int64") + } + self.attrs = {'shape': [2, 4, 5]} + output = np.tile(self.inputs['X'], (1, 1, 1)) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestExpandV2Error(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + shape = [2, 2] + self.assertRaises(TypeError, paddle.tensor.expand, x1, shape) + x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") + self.assertRaises(TypeError, paddle.tensor.expand, x2, shape) + x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool") + x3.stop_gradient = True + self.assertRaises(ValueError, paddle.tensor.expand, x3, shape) + + +# Test python API +class TestExpandV2API(unittest.TestCase): + def test_api(self): + input = np.random.random([12, 14]).astype("float32") + x = fluid.layers.data( + name='x', shape=[12, 14], append_batch_size=False, dtype="float32") + + positive_2 = fluid.layers.fill_constant([1], "int32", 12) + expand_shape = fluid.layers.data( + name="expand_shape", + shape=[2], + append_batch_size=False, + dtype="int32") + + out_1 = paddle.expand(x, shape=[12, 14]) + out_2 = paddle.expand(x, shape=[positive_2, 14]) + out_3 = paddle.expand(x, shape=expand_shape) + + g0 = fluid.backward.calc_gradient(out_2, x) + + exe = fluid.Executor(place=fluid.CPUPlace()) + res_1, res_2, res_3 = exe.run(fluid.default_main_program(), + feed={ + "x": input, + "expand_shape": + np.array([12, 14]).astype("int32") + }, + fetch_list=[out_1, out_2, out_3]) + assert np.array_equal(res_1, np.tile(input, (1, 1))) + assert np.array_equal(res_2, np.tile(input, (1, 1))) + assert np.array_equal(res_3, np.tile(input, (1, 1))) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_retain_graph.py b/python/paddle/fluid/tests/unittests/test_retain_graph.py index db4b922afc..360a2de1df 100644 --- a/python/paddle/fluid/tests/unittests/test_retain_graph.py +++ b/python/paddle/fluid/tests/unittests/test_retain_graph.py @@ -60,8 +60,10 @@ class TestRetainGraph(unittest.TestCase): interpolatesv = fake_data elif type == 'mixed': alpha = paddle.rand((real_data.shape[0], 1)) - alpha = paddle.expand( - alpha, [1, np.prod(real_data.shape) // real_data.shape[0]]) + alpha = paddle.expand(alpha, [ + real_data.shape[0], + np.prod(real_data.shape) // real_data.shape[0] + ]) alpha = paddle.reshape(alpha, real_data.shape) interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) else: diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 7d0f7ec625..9e2b7286ba 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -23,7 +23,6 @@ from ..fluid.layers import utils import numpy as np # TODO: define functions to manipulate a tensor from ..fluid.layers import cast #DEFINE_ALIAS -from ..fluid.layers import expand #DEFINE_ALIAS from ..fluid.layers 
import expand_as #DEFINE_ALIAS from ..fluid.layers import reshape #DEFINE_ALIAS from ..fluid.layers import scatter #DEFINE_ALIAS @@ -794,7 +793,6 @@ def tile(x, repeat_times, name=None): """ :alias_main: paddle.tile :alias: paddle.tile,paddle.tensor.tile,paddle.tensor.manipulation.tile - Construct a new tensor by repeating ``x`` the number of times given by the parameter ``repeat_times``. The rank of ``x`` should be less than or equal to 6, and the size of the shape of ``repeat_times`` should be less than or equal to 6. @@ -807,52 +805,38 @@ def tile(x, repeat_times, name=None): For example, if the tensor ``x`` is with a shape(4, 3, 2, 2) and ``repeat_times`` is a tuple (3, 2), ``repeat_times`` is first promoted to a tuple (1, 1, 3, 2). The following gives an using case: - - .. code-block:: text - Input(x) is a 3-D tensor of shape (2, 3, 1): - [ [[1], [2], [3]], [[4], [5], [6]] ] - Attr(repeat_times): [1, 2, 2] - Output(out) is a 3-D tensor of shape (2, 6, 2): - [ [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]] ] - Args: x (Tensor): The input tensor, its data type should be bool, float32, float64, int32. The rank of ``x`` should be in [1, 6]. repeat_times (Tensor|tuple|list): The number of repeating times for each dimension of the input ``x``. If repeat_times is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If repeat_times is Tensor, it should be an 1-D Tensor. The size of its shape should be in [1, 6]. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name` . - Returns: N-D Tensor. The data type is the same as ``x``. After tiling, each dimension of the output is equal to the corresponding dimension of ``x`` multiplying the corresponding value given by ``repeat_times`` . - Raises: TypeError: The type of ``repeat_times`` must be list, tuple or Tensor. ValueError: The elements of ``repeat_times`` cannot be negative. - Examples: .. code-block:: python - import paddle import numpy as np paddle.disable_static() - # example 1: np_data_1 = np.array([1, 2, 3]).astype('int32') data_1 = paddle..to_variable(np_data_1) tiled_1 = paddle.tile(data_1, repeat_times=[2, 1]) # [[1, 2, 3], [1, 2, 3]] - # example 2: np_repeat_times = np.array([2, 1]).astype("int32") repeat_times = paddle.to_variable(np_repeat_times) @@ -907,3 +891,95 @@ def tile(x, repeat_times, name=None): helper.append_op( type='tile', inputs=inputs, outputs={'Out': out}, attrs=attrs) return out + + +def expand(x, shape, name=None): + """ + :alias_main: paddle.expand + :alias: paddle.expand,paddle.tensor.expand,paddle.tensor.manipulation.expand + + Expand the input tensor to a given shape. + + The rank of ``x`` should be less than or equal to 6, and the number of elements in ``shape`` should also be less than or equal to 6. The size of the dimension to expand must be 1. + + + Args: + x (Tensor): The input Tensor with rank in [1, 6]. The data type is bool, float32, float64 or int32. + shape (list|tuple|Tensor): The result shape after expanding. The data type is int32. If shape is a list or tuple, all elements of + it should be integers or Tensors with shape (1,). If shape is a Tensor, it should be an 1-D Tensor. + The value -1 in shape, means keeping the corresponding dimension unchanged. + name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . 
+
+    Returns:
+        Tensor: A Tensor with the given shape. The data type is the same as ``x``.
+
+    Raises:
+        TypeError: The type of ``shape`` must be list, tuple or Variable.
+        ValueError: The elements of ``shape`` must be positive or -1.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            paddle.disable_static()
+
+            # example 1:
+            np_data_1 = np.array([1, 2, 3]).astype('int32')
+            data_1 = paddle.to_variable(np_data_1)
+            expanded_1 = paddle.expand(data_1, shape=[2, 3])
+            # [[1, 2, 3], [1, 2, 3]]
+
+            # example 2:
+            np_shape = np.array([2, 3]).astype('int32')
+            shape = paddle.to_variable(np_shape)
+            expanded_2 = paddle.expand(data_1, shape=shape)
+            # [[1, 2, 3], [1, 2, 3]]
+    """
+    if in_dygraph_mode():
+        if isinstance(shape, (list, tuple)):
+            expand_shape = [
+                item.numpy()[0] if isinstance(item, Variable) else item
+                for item in shape
+            ]
+
+            return core.ops.expand_v2(x, 'shape', expand_shape)
+
+    inputs = {"X": [x]}
+    attrs = {}
+    check_variable_and_dtype(
+        x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand')
+    check_type(shape, 'shape', (list, tuple, Variable), 'expand')
+    if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
+        raise ValueError("When the data type of input 'x' for expand is bool, "
+                         "you must set its stop_gradient to be False by "
+                         "some_var.stop_gradient = False, where some_var is "
+                         "the input.")
+
+    helper = LayerHelper('expand', input=x, **locals())
+
+    def get_attr_expand_shape(list_expand_shape):
+        attrs_expand_shape = []
+        for idx, shape in enumerate(list_expand_shape):
+            if isinstance(shape, Variable):
+                attrs_expand_shape.append(-1)
+            else:
+                attrs_expand_shape.append(shape)
+                assert shape > 0 or shape == -1, (
+                    "Every element in shape must be positive or -1.")
+        return attrs_expand_shape
+
+    if isinstance(shape, Variable):
+        shape.stop_gradient = True
+        inputs['Shape'] = shape
+    elif isinstance(shape, (list, tuple)):
+        attrs['shape'] = get_attr_expand_shape(shape)
+        if utils._contain_var(shape):
+            inputs['expand_shapes_tensor'] = utils._convert_to_tensor_list(
+                shape)
+
+    dtype = helper.input_dtype(input_param_name='x')
+    out = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(
+        type='expand_v2', inputs=inputs, outputs={'Out': out}, attrs=attrs)
+    return out
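
As a quick illustration of the new API defined in this patch (a minimal sketch, assuming a Paddle build that contains expand_v2; the tensor values and shapes below are illustrative and are not taken from the PR's tests):

    import numpy as np
    import paddle

    paddle.disable_static()

    # A [3, 1] tensor expanded to [2, 3, 4]: a new leading dimension is added
    # and the size-1 trailing dimension is broadcast, as described by
    # expand_v2's InferShape and kernel above.
    np_x = np.array([[1], [2], [3]]).astype('float32')
    x = paddle.to_variable(np_x)

    y = paddle.expand(x, shape=[2, 3, 4])
    print(y.shape)  # [2, 3, 4]

    # The result should agree with numpy broadcasting on the same shape.
    print(np.array_equal(y.numpy(), np.broadcast_to(np_x, (2, 3, 4))))  # True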
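The gradient path in expand_v2_op.h folds Out@GRAD back to X's shape by viewing it as interleaved (repeat, original) dimension pairs and summing over the repeat axes. The NumPy sketch below models that reduction rule only; the function and variable names are illustrative and do not appear in the patch:

    import numpy as np

    def fold_expand_grad(out_grad, x_shape, expand_shape):
        """Reference for how expand_v2_grad reduces Out@GRAD to X's shape.

        Mirrors reshape_dims_vec / reduce_dims_vec in expand_v2_op.h: out_grad
        is viewed as (repeat_0, dim_0, repeat_1, dim_1, ...) and summed over
        the repeat axes.
        """
        # Pad x_shape with leading 1s, as the kernel does for new dimensions.
        x_dims = [1] * (len(expand_shape) - len(x_shape)) + list(x_shape)
        repeat_times = [
            1 if s < 0 else s // d for s, d in zip(expand_shape, x_dims)
        ]
        reshape_dims, reduce_axes = [], []
        for r, d in zip(repeat_times, x_dims):
            reduce_axes.append(len(reshape_dims))
            reshape_dims.append(r)
            reshape_dims.append(d)
        return (out_grad.reshape(reshape_dims)
                        .sum(axis=tuple(reduce_axes))
                        .reshape(x_shape))

    # expand [1, 3] -> [2, 3]: each input element is used twice, so a gradient
    # of all ones folds back to 2 for every element of x.
    g = np.ones((2, 3))
    print(fold_expand_grad(g, (1, 3), [2, 3]))  # [[2. 2. 2.]]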