未验证 提交 fbd4d3cc 编写于 作者: L lilong12 提交者: GitHub

[API 2.0] add paddle.tile op (#26245)

* add tile_op, test=develop
上级 e4033a06
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tile_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class TileOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Tile");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Tile");
auto x_dims = ctx->GetInputDim("X");
auto repeat_times = ctx->Attrs().Get<std::vector<int>>("repeat_times");
if (repeat_times.size() == 0) {
repeat_times = std::vector<int>(x_dims.size(), -1);
}
PADDLE_ENFORCE_LE(
x_dims.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'x' for tile op "
"must not be greater than %d, but the value received is %d.",
MAX_RANK_SUPPORTED, x_dims.size()));
PADDLE_ENFORCE_LE(
repeat_times.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must not be greater than %d, but the value received is %d.",
MAX_RANK_SUPPORTED, repeat_times.size()));
PADDLE_ENFORCE_GE(
repeat_times.size(), 1,
platform::errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must be positive integers, but the value received is %d.",
repeat_times.size()));
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), repeat_times.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = framework::vectorize<int>(x_dims);
if (x_dim_vec.size() > repeat_times.size()) {
auto diff = x_dim_vec.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, -1);
} else {
auto diff = repeat_times.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
}
for (size_t i = 0; i < repeat_times.size(); ++i) {
if (x_dim_vec[i] == -1 || repeat_times[i] == -1) {
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
repeat_times[i], 0,
platform::errors::InvalidArgument(
"Every element of the input 'repeat_times' for tile op must be "
"greater than 0, but the value given is %d.",
repeat_times[i]));
out_shape[i] = x_dim_vec[i] * repeat_times[i];
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "repeat_times_tensor" || var_name == "RepeatTimes") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class TileOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). X is the input to be titled.");
AddInput(
"RepeatTimes",
"(Tensor<int>, optional). If provided, it is the number of repeat times"
" along specific axis. It has a higher priority than "
"repeat_times_tensor and the repeat_times attribute.")
.AsDispensable();
AddInput("repeat_times_tensor",
"(Tensor Tensor<int>), repeat times for X."
"It has a higher priority than repeat_times, but a lower priority "
"than RepeatTimes")
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"After tiling, size of each dimension of Output(Out) is equal "
"to size of the corresponding dimension of Input(X) multiplying "
"the corresponding value given by Attr(repeat_times).");
AddAttr<std::vector<int>>("repeat_times",
"The number of repeat times for each dimension.")
.SetDefault({});
AddComment(R"DOC(
Tile operator repeats the input by given times number. You should set times
number for each dimension by providing attribute 'repeat_times'. The rank of X
should be in [1, 6]. Please note that size of 'repeat_times' must be the same
with X's rank. Following is a using case:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(repeat_times): [1, 2, 2]
Output(Out) is a 3-D tensor with shape [2, 6, 2]:
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
)DOC");
}
};
class TileGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TileGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "TileGrad");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> repeat_times =
ctx->Attrs().Get<std::vector<int>>("repeat_times");
if (repeat_times.size() == 0) {
repeat_times = std::vector<int>(x_dims.size(), -1);
}
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dim_vec = framework::vectorize<int>(x_dims);
if (x_dim_vec.size() > repeat_times.size()) {
auto diff = x_dim_vec.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, -1);
} else {
auto diff = repeat_times.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
}
for (size_t i = 0; i < repeat_times.size(); ++i) {
if (repeat_times[i] == -1 || x_dim_vec[i] == -1) {
continue;
} else {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
x_dim_vec[i] * repeat_times[i], out_dims[i],
platform::errors::InvalidArgument(
"The size (%d) of the dimension %d of Input(Out@GRAD) should "
"be equal to the multiplication of the crroresponding "
"dimension size of Input(X) (%d) and repeat_times (%d).",
out_dims[i], i, x_dim_vec[i], repeat_times[i]));
}
}
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "repeat_times_tensor" || var_name == "RepeatTimes") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
template <typename T>
class TileGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("tile_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetInput("repeat_times_tensor", this->Input("repeat_times_tensor"));
op->SetInput("RepeatTimes", this->Input("RepeatTimes"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TileGradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(tile, ops::TileOp, ops::TileOpMaker,
ops::TileGradOpMaker<paddle::framework::OpDesc>,
ops::TileGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tile_grad, ops::TileGradOp,
ops::TileGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
tile, ops::TileKernel<paddle::platform::CPUDeviceContext, float>,
ops::TileKernel<paddle::platform::CPUDeviceContext, double>,
ops::TileKernel<paddle::platform::CPUDeviceContext, int>,
ops::TileKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TileKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
tile_grad, ops::TileGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TileGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TileGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::TileGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tile_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
tile, ops::TileKernel<paddle::platform::CUDADeviceContext, float>,
ops::TileKernel<paddle::platform::CUDADeviceContext, double>,
ops::TileKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::TileKernel<paddle::platform::CUDADeviceContext, int>,
ops::TileKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TileKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
tile_grad, ops::TileGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#define MAX_RANK_SUPPORTED 6
#define TILE_TEMPLATE(z, n, data) \
case n + 1: { \
Tile<n + 1>(context); \
break; \
}
#define REP_TILE_TEMPLATE(n) BOOST_PP_REPEAT(n, TILE_TEMPLATE, ~)
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define TILE_GRAD_CASE(n) \
case n: { \
TileBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
break; \
}
#define TILE_GRAD_TEMPLATE(z, n, data) BOOST_PP_IF(COND(n), TILE_GRAD_CASE(n), )
#define REP_TILE_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, TILE_GRAD_TEMPLATE, ~)
namespace paddle {
namespace operators {
inline std::vector<int> get_repeat_times(
const framework::ExecutionContext& ctx) {
if (ctx.HasInput("RepeatTimes")) {
auto* repeat_tensor = ctx.Input<framework::LoDTensor>("RepeatTimes");
auto* repeat_data = repeat_tensor->data<int>();
framework::Tensor cpu_repeat_tensor;
if (platform::is_gpu_place(repeat_tensor->place())) {
TensorCopySync(*repeat_tensor, platform::CPUPlace(), &cpu_repeat_tensor);
repeat_data = cpu_repeat_tensor.data<int>();
}
auto vec_repeat_times =
std::vector<int>(repeat_data, repeat_data + repeat_tensor->numel());
return vec_repeat_times;
}
auto list_repeat_times_tensor =
ctx.MultiInput<framework::Tensor>("repeat_times_tensor");
if (list_repeat_times_tensor.size() > 0) {
// get tensor from
std::vector<int> vec_repeat_times;
for (size_t i = 0; i < list_repeat_times_tensor.size(); ++i) {
auto tensor = list_repeat_times_tensor[i];
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_repeat_times.push_back(*temp.data<int32_t>());
} else {
vec_repeat_times.push_back(*tensor->data<int32_t>());
}
}
return vec_repeat_times;
} else {
return ctx.Attr<std::vector<int>>("repeat_times");
}
}
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::To32BitIndex;
template <typename DeviceContext, typename T>
class TileKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1, platform::errors::InvalidArgument(
"The rank of the input 'x' for tile op must be a positive "
"integer, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'x' for tile op "
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
auto repeat_times = get_repeat_times(context);
int repeat_times_size = repeat_times.size();
PADDLE_ENFORCE_GE(
repeat_times_size, 1,
platform::errors::InvalidArgument(
"The number of elements of the input 'repeat_times' for tile "
"op must be positive, but the value received is %d.",
repeat_times_size));
PADDLE_ENFORCE_LE(
repeat_times_size, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number of elements of the input 'repeat_times' for tile op "
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, repeat_times_size));
rank = std::max(rank, repeat_times_size);
switch (rank) { REP_TILE_TEMPLATE(MAX_RANK_SUPPORTED) }
}
protected:
template <int Rank>
void Tile(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto repeat_times = get_repeat_times(context);
for (size_t i = 0; i < repeat_times.size(); ++i) {
PADDLE_ENFORCE_GT(
repeat_times[i], 0,
platform::errors::InvalidArgument(
"All elements of the input 'repeat_times' for tile op must "
"be positive integers, but the value received is %d.",
repeat_times[i]));
}
auto vec_in_dims = framework::vectorize<int>(in_dims);
if (repeat_times.size() < vec_in_dims.size()) {
int diff = vec_in_dims.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, 1);
} else {
int diff = repeat_times.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
}
PADDLE_ENFORCE_EQ(
repeat_times.size(), vec_in_dims.size(),
platform::errors::InvalidArgument(
"The rank (%d) of the input 'x' and the rank (%d) of the input "
"'repeat_times' for tile op must match after promotion.",
vec_in_dims.size(), repeat_times.size()));
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
framework::DDim out_dims(new_in_dims);
for (size_t i = 0; i < repeat_times.size(); ++i) {
out_dims[i] *= repeat_times[i];
}
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
} else {
y.device(place) = x.broadcast(bcast_dims);
}
}
};
template <typename DeviceContext, typename T>
class TileGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto repeat_times = get_repeat_times(context);
auto x_dims = in0->dims();
auto vec_in_dims = framework::vectorize<int>(x_dims);
if (repeat_times.size() < vec_in_dims.size()) {
int diff = vec_in_dims.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, 1);
} else {
int diff = repeat_times.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
}
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"Th rank of the input 'Out@GRAD' for tile_grad op "
" must be greater than or equal to 1, but "
"the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for tile_grad op "
"must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) { REP_TILE_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) }
}
}
protected:
template <int Dims>
void TileBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
}
};
} // namespace operators
} // namespace paddle
......@@ -101,6 +101,7 @@ from .tensor.manipulation import cast #DEFINE_ALIAS
from .tensor.manipulation import concat #DEFINE_ALIAS
from .tensor.manipulation import expand #DEFINE_ALIAS
from .tensor.manipulation import expand_as #DEFINE_ALIAS
from .tensor.manipulation import tile #DEFINE_ALIAS
from .tensor.manipulation import flatten #DEFINE_ALIAS
from .tensor.manipulation import gather #DEFINE_ALIAS
from .tensor.manipulation import gather_nd #DEFINE_ALIAS
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
# Situation 1: repeat_times is a list(without tensor)
class TestTileOpRank1(OpTest):
def setUp(self):
self.op_type = "tile"
self.init_data()
self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
self.attrs = {'repeat_times': self.repeat_times}
output = np.tile(self.inputs['X'], self.repeat_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [100]
self.repeat_times = [2]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
# with dimension expanding
class TestTileOpRank2Expanding(TestTileOpRank1):
def init_data(self):
self.ori_shape = [120]
self.repeat_times = [2, 2]
class TestTileOpRank2(TestTileOpRank1):
def init_data(self):
self.ori_shape = [12, 14]
self.repeat_times = [2, 3]
class TestTileOpRank3_Corner(TestTileOpRank1):
def init_data(self):
self.ori_shape = (2, 10, 5)
self.repeat_times = (1, 1, 1)
class TestTileOpRank3_Corner2(TestTileOpRank1):
def init_data(self):
self.ori_shape = (2, 10, 5)
self.repeat_times = (2, 2)
class TestTileOpRank3(TestTileOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 15)
self.repeat_times = (2, 1, 4)
class TestTileOpRank4(TestTileOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5, 7)
self.repeat_times = (3, 2, 1, 2)
# Situation 2: repeat_times is a list(with tensor)
class TestTileOpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "tile"
self.init_data()
repeat_times_tensor = []
for index, ele in enumerate(self.repeat_times):
repeat_times_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
'X': np.random.random(self.ori_shape).astype("float64"),
'repeat_times_tensor': repeat_times_tensor,
}
self.attrs = {"repeat_times": self.infer_repeat_times}
output = np.tile(self.inputs['X'], self.repeat_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [100]
self.repeat_times = [2]
self.infer_repeat_times = [-1]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestTileOpRank2_Corner_tensor_attr(TestTileOpRank1_tensor_attr):
def init_data(self):
self.ori_shape = [12, 14]
self.repeat_times = [1, 1]
self.infer_repeat_times = [1, -1]
class TestTileOpRank2_attr_tensor(TestTileOpRank1_tensor_attr):
def init_data(self):
self.ori_shape = [12, 14]
self.repeat_times = [2, 3]
self.infer_repeat_times = [-1, 3]
# Situation 3: repeat_times is a tensor
class TestTileOpRank1_tensor(OpTest):
def setUp(self):
self.op_type = "tile"
self.init_data()
self.inputs = {
'X': np.random.random(self.ori_shape).astype("float64"),
'RepeatTimes': np.array(self.repeat_times).astype("int32"),
}
self.attrs = {}
output = np.tile(self.inputs['X'], self.repeat_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [100]
self.repeat_times = [2]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestTileOpRank2_tensor(TestTileOpRank1_tensor):
def init_data(self):
self.ori_shape = [12, 14]
self.repeat_times = [2, 3]
# Situation 4: input x is Integer
class TestTileOpInteger(OpTest):
def setUp(self):
self.op_type = "tile"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int32")
}
self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
# Situation 5: input x is Bool
class TestTileOpBoolean(OpTest):
def setUp(self):
self.op_type = "tile"
self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")}
self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
# Situation 56: input x is Integer
class TestTileOpInt64_t(OpTest):
def setUp(self):
self.op_type = "tile"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int64")
}
self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestTileError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
repeat_times = [2, 2]
self.assertRaises(TypeError, paddle.tile, x1, repeat_times)
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, paddle.tile, x2, repeat_times)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
self.assertRaises(ValueError, paddle.tile, x3, repeat_times)
# Test python API
class TestTileAPI(unittest.TestCase):
def test_api(self):
input = np.random.random([12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
repeat_times = fluid.layers.data(
name="repeat_times", shape=[2], append_batch_size=False)
out_1 = paddle.tile(x, repeat_times=[2, 3])
out_2 = paddle.tile(x, repeat_times=[positive_2, 3])
out_3 = paddle.tile(x, repeat_times=repeat_times)
g0 = fluid.backward.calc_gradient(out_2, x)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
"x": input,
"repeat_times":
np.array([1, 3]).astype("int32")
},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.tile(input, (2, 3)))
assert np.array_equal(res_2, np.tile(input, (2, 3)))
assert np.array_equal(res_3, np.tile(input, (1, 3)))
if __name__ == "__main__":
unittest.main()
......@@ -77,6 +77,7 @@ from .manipulation import cast #DEFINE_ALIAS
from .manipulation import concat #DEFINE_ALIAS
from .manipulation import expand #DEFINE_ALIAS
from .manipulation import expand_as #DEFINE_ALIAS
from .manipulation import tile #DEFINE_ALIAS
from .manipulation import flatten #DEFINE_ALIAS
from .manipulation import gather #DEFINE_ALIAS
from .manipulation import gather_nd #DEFINE_ALIAS
......
......@@ -68,6 +68,7 @@ __all__ = [
'flip',
'unbind',
'roll',
'tile',
]
......@@ -787,3 +788,122 @@ def unbind(input, axis=0):
outputs={"Out": outs},
attrs={"axis": axis})
return outs
def tile(x, repeat_times, name=None):
"""
:alias_main: paddle.tile
:alias: paddle.tile,paddle.tensor.tile,paddle.tensor.manipulation.tile
Construct a new tensor by repeating ``x`` the number of times given by the parameter ``repeat_times``.
The rank of ``x`` should be less than or equal to 6, and the size of the shape of ``repeat_times`` should
be less than or equal to 6.
If the size of the parameter ``repeat_times`` is ``d``, and the rank for ``x`` is ``r``, then the number
of dimensions for the result is ``max(d, r)``.
If ``r < d``, ``x`` if first promoted to a d-dimensional tensor by inserting new axes at the beginning.
For example, a tensor ``x`` with the shape(3,) is promoted to a 2-D tensor with the shape (1, 3) if ``d`` is 2
and a 3-D tensor with the shape(1, 1, 3) if ``d`` is 3.
If ``r > d``, ``repeat_times`` is first promoted by inserting 1's at the beginning.
For example, if the tensor ``x`` is with a shape(4, 3, 2, 2) and ``repeat_times`` is a tuple (3, 2),
``repeat_times`` is first promoted to a tuple (1, 1, 3, 2).
The following gives an using case:
.. code-block:: text
Input(x) is a 3-D tensor of shape (2, 3, 1):
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(repeat_times): [1, 2, 2]
Output(out) is a 3-D tensor of shape (2, 6, 2):
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
Args:
x (Tensor): The input tensor, its data type should be bool, float32, float64, int32. The rank of ``x`` should be in [1, 6].
repeat_times (Tensor|tuple|list): The number of repeating times for each dimension of the input ``x``. If repeat_times is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If repeat_times is Tensor, it should be an 1-D Tensor. The size of its shape should be in [1, 6].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name` .
Returns:
N-D Tensor. The data type is the same as ``x``. After tiling, each dimension of the output is equal to the corresponding dimension of ``x`` multiplying the corresponding value given by ``repeat_times`` .
Raises:
TypeError: The type of ``repeat_times`` must be list, tuple or Tensor.
ValueError: The elements of ``repeat_times`` cannot be negative.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1:
np_data_1 = np.array([1, 2, 3]).astype('int32')
data_1 = paddle..to_variable(np_data_1)
tiled_1 = paddle.tile(data_1, repeat_times=[2, 1])
# [[1, 2, 3], [1, 2, 3]]
# example 2:
np_repeat_times = np.array([2, 1]).astype("int32")
repeat_times = paddle.to_variable(np_repeat_times)
tiled_2 = paddle.tile(data_1, repeat_times=repeat_times)
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
if isinstance(repeat_times, (list, tuple)):
repeat_times = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in repeat_times
]
return core.ops.tile(x, 'repeat_times', repeat_times)
inputs = {"X": [x]}
attrs = {}
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'tile')
check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
raise ValueError(
"When the date type is bool for the input 'x' of tile op, you "
"must set its stop_gradient to be False by "
"some_var.stop_gradient == True supporting some_var is the input.")
helper = LayerHelper('tile', input=x, **locals())
def get_attr_repeat_times(list_repeat_times):
attrs_repeat_times = []
for idx, times in enumerate(list_repeat_times):
if isinstance(times, Variable):
attrs_repeat_times.append(-1)
else:
attrs_repeat_times.append(times)
assert times > 0, (
"Every element given in repeat_times must be positive.")
return attrs_repeat_times
if isinstance(repeat_times, Variable):
repeat_times.stop_gradient = True
inputs['RepeatTimes'] = repeat_times
attrs['repeat_times'] = [-1] * len(repeat_times.shape)
elif isinstance(repeat_times, (list, tuple)):
attrs['repeat_times'] = get_attr_repeat_times(repeat_times)
if utils._contain_var(repeat_times):
inputs['repeat_times_tensor'] = utils._convert_to_tensor_list(
repeat_times)
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='tile', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册