Unverified commit 241b44db, authored by lilong12, committed by GitHub

[API 2.0] adapt expand op to use shape instead of expand_times (#26206)

* adapt expand op to 2.0 (align with torch.expand), test=develop
Parent commit: cbf8ba15
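With the 2.0 API the target is given as a full output shape (aligned with torch.expand) rather than per-dimension repeat counts. A minimal dygraph sketch of the new call form (illustrative only; it uses paddle.disable_static, paddle.to_variable and paddle.expand as introduced by this change):

import numpy as np
import paddle

paddle.disable_static()
# x has shape [3, 1]
x = paddle.to_variable(np.array([[1], [2], [3]]).astype('int32'))
# 2.0 style: pass the full target shape; -1 keeps a dimension unchanged.
y = paddle.expand(x, shape=[3, 4])  # y has shape [3, 4]
# 1.x style, for comparison, used per-dimension repeat counts:
# y = fluid.layers.expand(x, expand_times=[1, 4])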
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_v2_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class ExpandV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandV2");
auto x_dims = ctx->GetInputDim("X");
auto expand_shape = ctx->Attrs().Get<std::vector<int>>("shape");
if (expand_shape.size() == 0) {
expand_shape = std::vector<int>(x_dims.size(), -1);
}
PADDLE_ENFORCE_GE(
expand_shape.size(), static_cast<size_t>(x_dims.size()),
platform::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"expand_v2 op must be greater than or equal to the rank "
"(%d) of the input.",
expand_shape.size(), static_cast<size_t>(x_dims.size())));
PADDLE_ENFORCE_LE(expand_shape.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"expand_v2 op must not be greater than %d.",
expand_shape.size(), MAX_RANK_SUPPORTED));
PADDLE_ENFORCE_GE(expand_shape.size(), 1,
platform::errors::InvalidArgument(
"The number of elements (%d) of 'shape' for "
"expand_v2 op must be a positive integer.",
expand_shape.size()));
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), expand_shape.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = framework::vectorize<int>(x_dims);
auto diff = expand_shape.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (x_dim_vec[i] == -1) {
out_shape[i] = -1;
} else if (expand_shape[i] == -1) {
out_shape[i] = x_dim_vec[i];
} else {
PADDLE_ENFORCE_GT(
expand_shape[i], 0,
platform::errors::InvalidArgument(
"The %uth element of 'shape' for expand_v2 op must be "
"greater than 0, but the value given is %d.",
i, expand_shape[i]));
out_shape[i] = expand_shape[i];
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class ExpandV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddInput("Shape",
"(Tensor<int>), optional). If provided, expand according to "
"this given Shape. It has a higher priority than "
"expand_shapes_tensor and the shape attribute.")
.AsDispensable();
AddInput("expand_shapes_tensor",
"(Tensor Tensor<int>), epxanded shape for X."
"It has a higher priority than shape attribute, but a lower "
"priority than the input Shape")
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) have the same with Input(X). "
"After expanding, size of each dimension of Output(Out) is equal "
"to size of the corresponding dimension of Input(X) multiplying "
"the corresponding value given by Attr(expand_times).");
AddAttr<std::vector<int>>("shape", "The expanded shape for each dimension.")
.SetDefault({});
AddComment(R"DOC(
Expand the input to the given shape. The rank of X
should be in [1, 6] and the number of elements of 'shape' must also be in [1, 6].
The following is a use case:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(shape): [2, 6, 2]
Output(Out) is a 3-D tensor with shape [2, 6, 2]:
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
)DOC");
}
};
class ExpandV2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandV2Grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "ExpandV2Grad");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_shape = ctx->Attrs().Get<std::vector<int>>("shape");
if (expand_shape.size() == 0) {
expand_shape = std::vector<int>(x_dims.size(), -1);
}
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dim_vec = framework::vectorize<int>(x_dims);
auto diff = expand_shape.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (expand_shape[i] == -1 || x_dim_vec[i] == -1) {
continue;
} else {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
expand_shape[i], out_dims[i],
platform::errors::InvalidArgument(
"The size (%d) of dimension %d of Input(Out@GRAD) should "
"be equal to the corresponding element (%d) of 'shape'.",
out_dims[i], i, expand_shape[i]));
}
}
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "expand_shapes_tensor" || var_name == "Shape") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
template <typename T>
class ExpandV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_v2_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor"));
op->SetInput("Shape", this->Input("Shape"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_v2, ops::ExpandV2Op, ops::ExpandV2OpMaker,
ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp,
ops::ExpandV2GradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand_v2, ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_v2_grad,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_v2_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
expand_v2, ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_v2_grad,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#define MAX_RANK_SUPPORTED 6
#define EXPAND_TEMPLATE(z, n, data) \
case n + 1: { \
Expand<n + 1>(context); \
break; \
}
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_GRAD_CASE(n) \
case n: { \
ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
break; \
}
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
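// REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) expands, via BOOST_PP_REPEAT, to the
// switch cases "case 1: Expand<1>(context); break;" through
// "case 6: Expand<6>(context); break;"; REP_EXPAND_GRAD_TEMPLATE likewise
// generates "case n: ExpandBackward<n>(...)" for every n accepted by COND(n).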
namespace paddle {
namespace operators {
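// Resolve the target shape with the priority documented in ExpandV2OpMaker:
// Input("Shape") first, then the expand_shapes_tensor list, and finally the
// "shape" attribute.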
inline std::vector<int> get_expand_shape(
const framework::ExecutionContext& ctx) {
if (ctx.HasInput("Shape")) {
auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
auto* shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto vec_shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
return vec_shape;
}
auto list_expand_shapes_tensor =
ctx.MultiInput<framework::Tensor>("expand_shapes_tensor");
if (list_expand_shapes_tensor.size() > 0) {
// gather the target shape from a list of 1-element Tensors
std::vector<int> vec_expand_shape;
for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) {
auto tensor = list_expand_shapes_tensor[i];
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_expand_shape.push_back(*temp.data<int32_t>());
} else {
vec_expand_shape.push_back(*tensor->data<int32_t>());
}
}
return vec_expand_shape;
} else {
return ctx.Attr<std::vector<int>>("shape");
}
}
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::To32BitIndex;
template <typename DeviceContext, typename T>
class ExpandV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The rank of the input 'X' for expand_v2 op must be positive, "
"but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'X' for expand_v2 op must be less than "
"or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
auto expand_shape = get_expand_shape(context);
auto shape_size = expand_shape.size();
PADDLE_ENFORCE_GE(
shape_size, rank,
platform::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand_v2 op must be "
"greater than or equal to the rank (%d) of the input 'X'.",
shape_size, rank));
PADDLE_ENFORCE_LE(
shape_size, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand_v2 op must be "
"less than or equal to %d.",
shape_size, MAX_RANK_SUPPORTED));
rank = std::max(rank, static_cast<int>(shape_size));
switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) }
}
protected:
template <int Rank>
void Expand(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto expand_shape = get_expand_shape(context);
auto vec_in_dims = framework::vectorize<int>(in_dims);
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
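// Derive per-dimension repeat counts from the target shape. Illustrative
// example: X with shape [3, 1] and shape = [2, 3, 4] gives padded
// vec_in_dims = [1, 3, 1] and repeat_times = [2, 1, 4].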
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(expand_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
expand_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_v2 op.",
expand_shape[i]));
repeat_times[i] = expand_shape[i];
} else if (expand_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], expand_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand_v2 op.",
vec_in_dims[i], expand_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = expand_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
expand_shape[i], -1,
platform::errors::InvalidArgument(
"When the value in shape is negative for expand_v2 op, "
"only -1 is supported, but the value received is %d.",
expand_shape[i]));
repeat_times[i] = 1;
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
framework::DDim out_dims(new_in_dims);
for (size_t i = 0; i < repeat_times.size(); ++i) {
out_dims[i] *= repeat_times[i];
}
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
} else {
y.device(place) = x.broadcast(bcast_dims);
}
}
};
template <typename DeviceContext, typename T>
class ExpandV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto expand_shape = get_expand_shape(context);
auto x_dims = in0->dims();
auto vec_in_dims = framework::vectorize<int>(x_dims);
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
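// Illustrative example: X with shape [1, 3] expanded to [2, 3] gives
// repeat_times = [2, 1], reshape_dims_vec = [2, 1, 1, 3] and
// reduce_dims_vec = [0, 2]; summing Out@GRAD reshaped to [2, 1, 1, 3] over
// dimensions 0 and 2 recovers an X@GRAD of shape [1, 3].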
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
if (expand_shape[i] < 0) {
repeat_times[i] = 1;
} else {
repeat_times[i] = expand_shape[i] / vec_in_dims[i];
}
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no reduction needed, just copy
if (just_copy) {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) { REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) }
}
}
protected:
template <int Dims>
void ExpandBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import paddle
# Situation 1: shape is a list(without tensor)
class TestExpandV2OpRank1(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.init_data()
self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
self.attrs = {'shape': self.shape}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [100]
self.shape = [100]
self.expand_times = [1]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandV2OpRank2_DimExpanding(TestExpandV2OpRank1):
def init_data(self):
self.ori_shape = [120]
self.shape = [2, 120]
self.expand_times = [2, 1]
class TestExpandV2OpRank2(TestExpandV2OpRank1):
def init_data(self):
self.ori_shape = [1, 140]
self.shape = [12, 140]
self.expand_times = [12, 1]
class TestExpandV2OpRank3_Corner(TestExpandV2OpRank1):
def init_data(self):
self.ori_shape = (2, 10, 5)
self.shape = (2, 10, 5)
self.expand_times = (1, 1, 1)
class TestExpandV2OpRank4(TestExpandV2OpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5, 7)
self.shape = (-1, -1, -1, -1)
self.expand_times = (1, 1, 1, 1)
# Situation 2: shape is a list(with tensor)
class TestExpandV2OpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.init_data()
expand_shapes_tensor = []
for index, ele in enumerate(self.expand_shape):
expand_shapes_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
'X': np.random.random(self.ori_shape).astype("float64"),
'expand_shapes_tensor': expand_shapes_tensor,
}
self.attrs = {"shape": self.infer_expand_shape}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [100]
self.expand_times = [1]
self.expand_shape = [100]
self.infer_expand_shape = [-1]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandV2OpRank2_Corner_tensor_attr(TestExpandV2OpRank1_tensor_attr):
def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [1, 1]
self.expand_shape = [12, 14]
self.infer_expand_shape = [12, -1]
# Situation 3: shape is a tensor
class TestExpandV2OpRank1_tensor(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.init_data()
self.inputs = {
'X': np.random.random(self.ori_shape).astype("float64"),
'Shape': np.array(self.expand_shape).astype("int32"),
}
self.attrs = {}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [100]
self.expand_times = [2, 1]
self.expand_shape = [2, 100]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
# Situation 4: input x is Integer
class TestExpandV2OpInteger(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int32")
}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
# Situation 5: input x is Bool
class TestExpandV2OpBoolean(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
# Situation 6: input x is int64
class TestExpandV2OpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int64")
}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestExpandV2Error(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
shape = [2, 2]
self.assertRaises(TypeError, paddle.tensor.expand, x1, shape)
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, paddle.tensor.expand, x2, shape)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
self.assertRaises(ValueError, paddle.tensor.expand, x3, shape)
# Test python API
class TestExpandV2API(unittest.TestCase):
def test_api(self):
input = np.random.random([12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
positive_2 = fluid.layers.fill_constant([1], "int32", 12)
expand_shape = fluid.layers.data(
name="expand_shape",
shape=[2],
append_batch_size=False,
dtype="int32")
out_1 = paddle.expand(x, shape=[12, 14])
out_2 = paddle.expand(x, shape=[positive_2, 14])
out_3 = paddle.expand(x, shape=expand_shape)
g0 = fluid.backward.calc_gradient(out_2, x)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
"x": input,
"expand_shape":
np.array([12, 14]).astype("int32")
},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.tile(input, (1, 1)))
assert np.array_equal(res_2, np.tile(input, (1, 1)))
assert np.array_equal(res_3, np.tile(input, (1, 1)))
if __name__ == "__main__":
unittest.main()
@@ -60,8 +60,10 @@ class TestRetainGraph(unittest.TestCase):
interpolatesv = fake_data
elif type == 'mixed':
alpha = paddle.rand((real_data.shape[0], 1))
alpha = paddle.expand(
alpha, [1, np.prod(real_data.shape) // real_data.shape[0]])
alpha = paddle.expand(alpha, [
real_data.shape[0],
np.prod(real_data.shape) // real_data.shape[0]
])
alpha = paddle.reshape(alpha, real_data.shape)
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
else:
......
@@ -23,7 +23,6 @@ from ..fluid.layers import utils
import numpy as np
# TODO: define functions to manipulate a tensor
from ..fluid.layers import cast #DEFINE_ALIAS
from ..fluid.layers import expand #DEFINE_ALIAS
from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import reshape #DEFINE_ALIAS
from ..fluid.layers import scatter #DEFINE_ALIAS
@@ -794,7 +793,6 @@ def tile(x, repeat_times, name=None):
"""
:alias_main: paddle.tile
:alias: paddle.tile,paddle.tensor.tile,paddle.tensor.manipulation.tile
Construct a new tensor by repeating ``x`` the number of times given by the parameter ``repeat_times``.
The rank of ``x`` should be less than or equal to 6, and the number of elements in ``repeat_times`` should
be less than or equal to 6.
@@ -807,52 +805,38 @@ def tile(x, repeat_times, name=None):
For example, if the tensor ``x`` has a shape (4, 3, 2, 2) and ``repeat_times`` is a tuple (3, 2),
``repeat_times`` is first promoted to a tuple (1, 1, 3, 2).
The following gives a use case:
.. code-block:: text
Input(x) is a 3-D tensor of shape (2, 3, 1):
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(repeat_times): [1, 2, 2]
Output(out) is a 3-D tensor of shape (2, 6, 2):
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
Args:
x (Tensor): The input tensor, its data type should be bool, float32, float64, int32. The rank of ``x`` should be in [1, 6].
repeat_times (Tensor|tuple|list): The number of repeating times for each dimension of the input ``x``. If repeat_times is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If repeat_times is a Tensor, it should be a 1-D Tensor. The number of its elements should be in [1, 6].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name` .
Returns:
N-D Tensor. The data type is the same as ``x``. After tiling, each dimension of the output is equal to the corresponding dimension of ``x`` multiplied by the corresponding value given by ``repeat_times`` .
Raises:
TypeError: The type of ``repeat_times`` must be list, tuple or Tensor.
ValueError: The elements of ``repeat_times`` cannot be negative.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1:
np_data_1 = np.array([1, 2, 3]).astype('int32')
data_1 = paddle.to_variable(np_data_1)
tiled_1 = paddle.tile(data_1, repeat_times=[2, 1])
# [[1, 2, 3], [1, 2, 3]]
# example 2:
np_repeat_times = np.array([2, 1]).astype("int32")
repeat_times = paddle.to_variable(np_repeat_times)
@@ -907,3 +891,95 @@ def tile(x, repeat_times, name=None):
helper.append_op(
type='tile', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
def expand(x, shape, name=None):
"""
:alias_main: paddle.expand
:alias: paddle.expand,paddle.tensor.expand,paddle.tensor.manipulation.expand
Expand the input tensor to a given shape.
The rank of ``x`` should be less than or equal to 6, and the number of elements in ``shape`` should also be less than or equal to 6. The size of the dimension to expand must be 1.
Args:
x (Tensor): The input Tensor with rank in [1, 6]. The data type is bool, float32, float64 or int32.
shape (list|tuple|Tensor): The result shape after expanding. The data type is int32. If shape is a list or tuple, all elements of
it should be integers or Tensors with shape (1,). If shape is a Tensor, it should be a 1-D Tensor.
The value -1 in shape means keeping the corresponding dimension unchanged.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: A Tensor with the given shape. The data type is the same as ``x``.
Raises:
TypeError: The type of ``shape`` must be list, tuple or Variable.
ValueError: The elements of ``shape`` must be positive or -1.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
# example 1:
np_data_1 = np.array([1, 2, 3]).astype('int32')
data_1 = paddle.to_variable(np_data_1)
expanded_1 = paddle.expand(data_1, shape=[2, 3])
# [[1, 2, 3], [1, 2, 3]]
# example 2:
np_shape = np.array([2, 3]).astype('int32')
shape = paddle.to_variable(np_shape)
expanded_2 = paddle.expand(data_1, shape=shape)
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
if isinstance(shape, (list, tuple)):
expand_shape = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in shape
]
return core.ops.expand_v2(x, 'shape', expand_shape)
inputs = {"X": [x]}
attrs = {}
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand')
check_type(shape, 'shape', (list, tuple, Variable), 'expand')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
raise ValueError("When the data type of input 'x' for expand is bool, "
"you must set its stop_gradient to be False by "
" some_var.stop_gradient = False, supporting "
"some_var as the input.")
helper = LayerHelper('expand', input=x, **locals())
def get_attr_expand_shape(list_expand_shape):
attrs_expand_shape = []
for idx, shape in enumerate(list_expand_shape):
if isinstance(shape, Variable):
attrs_expand_shape.append(-1)
else:
attrs_expand_shape.append(shape)
assert shape > 0 or shape == -1, (
"Every element in shape must be positive or -1.")
return attrs_expand_shape
if isinstance(shape, Variable):
shape.stop_gradient = True
inputs['Shape'] = shape
elif isinstance(shape, (list, tuple)):
attrs['shape'] = get_attr_expand_shape(shape)
if utils._contain_var(shape):
inputs['expand_shapes_tensor'] = utils._convert_to_tensor_list(
shape)
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='expand_v2', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
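The three branches above map onto the inputs registered in ExpandV2OpMaker: a plain list or tuple becomes the 'shape' attribute, Variable elements inside a list are routed through 'expand_shapes_tensor', and a Tensor shape is wired to Input('Shape'). A minimal static-graph sketch of the three call forms, modeled on TestExpandV2API above (illustrative only):

import numpy as np
import paddle
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[12, 14], append_batch_size=False, dtype='float32')
dim0 = fluid.layers.fill_constant([1], 'int32', 12)
shape_tensor = fluid.layers.data(name='shape_tensor', shape=[2], append_batch_size=False, dtype='int32')
out_attr = paddle.expand(x, shape=[12, 14])        # list of ints -> 'shape' attribute
out_mixed = paddle.expand(x, shape=[dim0, 14])     # list with a Variable -> 'expand_shapes_tensor'
out_tensor = paddle.expand(x, shape=shape_tensor)  # Tensor -> Input('Shape')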