Unverified commit 638bbb61 authored by lilong12, committed by GitHub

Improve expand as (#26290)

align expand_as op to expand.
Parent 5fdec3ed
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h"
#include <memory>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class ExpandAsV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2");
OP_INOUT_CHECK(ctx->HasInput("target_tensor"), "Input", "target_tensor",
"ExpandAsV2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandAsV2");
auto x_dims = ctx->GetInputDim("X");
auto target_tensor_dims = ctx->GetInputDim("target_tensor");
PADDLE_ENFORCE_GE(
target_tensor_dims.size(), static_cast<size_t>(x_dims.size()),
platform::errors::InvalidArgument(
"The rank of Input(target_tensor) must be greater than or equal "
"to the rank of Input(X). But received Input(X): input "
"rank %u, input shape [%s]; received Input(target_tensor): "
"input rank %u, input shape [%s].",
x_dims.size(), x_dims, target_tensor_dims.size(),
target_tensor_dims));
PADDLE_ENFORCE_LE(
target_tensor_dims.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of Input(target_tensor) must not be less than or equal "
"to %d. But received: input rank %u, input shape [%s].",
MAX_RANK_SUPPORTED, x_dims.size(), x_dims));
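// The concrete output shape follows target_tensor and is only fixed by the
// kernel at run time; here a placeholder with the target rank is set.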
std::vector<int64_t> out_shape(target_tensor_dims.size());
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
};
class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) have the same with Input(X). "
"After expanding, size of each dimension of Output(Out) is equal "
"to size of the corresponding dimension of Input(X) multiplying "
"the corresponding value given by Attr(expand_times).");
AddInput("target_tensor", "Expand tensor's shape for each dimension.");
AddComment(R"DOC(
Expand the input tensor X to the same shape as the input 'target_tensor'.
Both the ranks of X and 'target_tensor' should be in [1, 6], and the rank of
'target_tensor' must not be smaller than the rank of X. Only dimensions of X
whose size is 1 are expanded; every other dimension must already match the
corresponding dimension of 'target_tensor'. The following is an example:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Input(target_tensor) has shape [2, 3, 2].
Output(Out) is a 3-D tensor with shape [2, 3, 2]:
[
[[1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6]]
]
)DOC");
}
};
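A minimal NumPy sketch of the semantics documented above, assuming the example shapes [2, 3, 1] and [2, 3, 2] (illustrative only, not part of the operator source); only size-1 dimensions of X are repeated:
import numpy as np

x = np.array([[[1], [2], [3]],
              [[4], [5], [6]]])          # shape (2, 3, 1), matches Input(X) above
target_shape = (2, 3, 2)                 # stands in for Input(target_tensor)'s shape
out = np.broadcast_to(x, target_shape)   # only the trailing size-1 dim is repeated
# out[0] -> [[1, 1], [2, 2], [3, 3]]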
class ExpandAsV2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2Grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "ExpandAsV2Grad");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class ExpandAsV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_as_v2_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("target_tensor", this->Input("target_tensor"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandAsV2GradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker,
ops::ExpandAsV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp,
ops::ExpandAsV2GradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand_as_v2,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_as_v2_grad,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
expand_as_v2,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_as_v2_grad,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#define MAX_RANK_SUPPORTED 6
#define EXPAND_AS_TEMPLATE(z, n, data) \
case n + 1: { \
ExpandAs<n + 1>(context); \
break; \
}
#define REP_EXPAND_AS_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~)
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_AS_GRAD_CASE(n) \
case n: { \
ExpandAsBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
break; \
}
#define EXPAND_AS_GRAD_TEMPLATE(z, n, data) \
BOOST_PP_IF(COND(n), EXPAND_AS_GRAD_CASE(n), )
#define REP_EXPAND_AS_GRAD_TEMPLATE(n) \
BOOST_PP_REPEAT(n, EXPAND_AS_GRAD_TEMPLATE, ~)
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class ExpandAsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto target_rank = target_tensor->dims().size();
PADDLE_ENFORCE_GE(target_rank, rank,
platform::errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be greater than or equal to "
"the rank (%d) of the input 'x'.",
target_rank, rank));
PADDLE_ENFORCE_GE(rank, 1, platform::errors::InvalidArgument(
"The rank (%d) of the input 'x' for "
"expand_as_v2 op must be positive.",
rank));
PADDLE_ENFORCE_LE(target_rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be less than or equal to %d.",
target_rank, MAX_RANK_SUPPORTED));
switch (target_rank) { REP_EXPAND_AS_TEMPLATE(MAX_RANK_SUPPORTED) }
}
protected:
template <int Rank>
void ExpandAs(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto vec_in_dims = framework::vectorize<int>(in_dims);
auto target_shape = framework::vectorize<int>(target_tensor->dims());
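// Left-pad the input shape with leading 1s so it has the same rank as
// target_tensor; a size-1 dimension is repeated target_shape[i] times and
// any other dimension must already equal the target dimension.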
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(target_shape[i], 0,
platform::errors::InvalidArgument(
"The value of target shape cannot be zero."));
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], target_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in "
"target tensor for expand_as_v2 op.",
vec_in_dims[i], target_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
framework::DDim out_dims = framework::make_ddim(target_shape);
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
}
};
template <typename DeviceContext, typename T>
class ExpandAsV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto x_dims = in0->dims();
auto target_shape = target_tensor->dims();
auto vec_in_dims = framework::vectorize<int>(x_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
repeat_times[i] = target_shape[i] / vec_in_dims[i];
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
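// Pair each dimension up as (repeat_times[i], vec_in_dims[i]): Out@GRAD is
// reshaped to this interleaved form and summed over the repeat axes
// (recorded in reduce_dims_vec) to produce X@GRAD.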
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// No reduction is needed; just copy.
if (just_copy) {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) { REP_EXPAND_AS_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) }
}
}
protected:
template <int Dims>
void ExpandAsBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
}
};
} // namespace operators
} // namespace paddle
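The gradient kernel's reshape-and-reduce strategy can be mirrored in NumPy; the sketch below uses hypothetical shapes (1, 3, 1) and (2, 3, 4) and is not part of the header:
import numpy as np

x_shape = (1, 3, 1)
target_shape = (2, 3, 4)
repeat_times = [t // s for s, t in zip(x_shape, target_shape)]   # [2, 1, 4]

out_grad = np.ones(target_shape)
reshape_dims, reduce_axes = [], []
for r, d in zip(repeat_times, x_shape):
    reduce_axes.append(len(reshape_dims))    # the repeat factor sits at the even position
    reshape_dims.extend([r, d])

# Split every expanded dimension into a (repeat, original) pair, then sum over
# the repeat axes to fold the broadcast back into X's shape.
x_grad = out_grad.reshape(reshape_dims).sum(axis=tuple(reduce_axes)).reshape(x_shape)
# x_grad.shape == (1, 3, 1); every entry equals 2 * 4 = 8 for an all-ones out_grad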
......@@ -101,6 +101,7 @@ from .tensor.logic import equal_all #DEFINE_ALIAS
from .tensor.manipulation import cast #DEFINE_ALIAS
from .tensor.manipulation import concat #DEFINE_ALIAS
from .tensor.manipulation import expand #DEFINE_ALIAS
from .tensor.manipulation import broadcast_to #DEFINE_ALIAS
from .tensor.manipulation import expand_as #DEFINE_ALIAS
from .tensor.manipulation import tile #DEFINE_ALIAS
from .tensor.manipulation import flatten #DEFINE_ALIAS
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
class TestExpandAsOpRank1(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(100).astype("float64")
target_tensor = np.random.rand(2, 100).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [2, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandAsOpRank2(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(10, 12).astype("float64")
target_tensor = np.random.rand(10, 12).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandAsOpRank3(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(2, 3, 20).astype("float64")
target_tensor = np.random.rand(2, 3, 20).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [1, 1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandAsOpRank4(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(1, 1, 7, 16).astype("float64")
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [4, 6, 1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
# Test python API
class TestExpandAPI(unittest.TestCase):
def test_api(self):
input1 = np.random.random([12, 14]).astype("float32")
input2 = np.random.random([2, 12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
y = fluid.layers.data(
name='target_tensor',
shape=[2, 12, 14],
append_batch_size=False,
dtype="float32")
out_1 = paddle.expand_as(x, y=y)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1 = exe.run(fluid.default_main_program(),
feed={"x": input1,
"target_tensor": input2},
fetch_list=[out_1])
assert np.array_equal(res_1[0], np.tile(input1, (2, 1, 1)))
if __name__ == "__main__":
unittest.main()
......@@ -193,7 +193,7 @@ class TestExpandV2Error(unittest.TestCase):
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, paddle.tensor.expand, x2, shape)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
x3.stop_gradient = False
self.assertRaises(ValueError, paddle.tensor.expand, x3, shape)
......
......@@ -22,7 +22,7 @@ import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
# Situation 1: repeat_times is a list(without tensor)
# Situation 1: repeat_times is a list (without tensor)
class TestTileOpRank1(OpTest):
def setUp(self):
self.op_type = "tile"
......@@ -81,7 +81,7 @@ class TestTileOpRank4(TestTileOpRank1):
self.repeat_times = (3, 2, 1, 2)
# Situation 2: repeat_times is a list(with tensor)
# Situation 2: repeat_times is a list (with tensor)
class TestTileOpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "tile"
......@@ -162,7 +162,7 @@ class TestTileOpInteger(OpTest):
self.op_type = "tile"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int32")
10, size=(4, 4, 5)).astype("int32")
}
self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
......@@ -211,38 +211,30 @@ class TestTileError(unittest.TestCase):
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, paddle.tile, x2, repeat_times)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
x3.stop_gradient = False
self.assertRaises(ValueError, paddle.tile, x3, repeat_times)
# Test python API
class TestTileAPI(unittest.TestCase):
def test_api(self):
input = np.random.random([12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
repeat_times = fluid.layers.data(
name="repeat_times", shape=[2], append_batch_size=False)
out_1 = paddle.tile(x, repeat_times=[2, 3])
out_2 = paddle.tile(x, repeat_times=[positive_2, 3])
out_3 = paddle.tile(x, repeat_times=repeat_times)
g0 = fluid.backward.calc_gradient(out_2, x)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
"x": input,
"repeat_times":
np.array([1, 3]).astype("int32")
},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.tile(input, (2, 3)))
assert np.array_equal(res_2, np.tile(input, (2, 3)))
assert np.array_equal(res_3, np.tile(input, (1, 3)))
with fluid.dygraph.guard():
np_x = np.random.random([12, 14]).astype("float32")
x = paddle.to_variable(np_x)
positive_2 = np.array([2]).astype("int32")
positive_2 = paddle.to_variable(positive_2)
repeat_times = np.array([2, 3]).astype("int32")
repeat_times = paddle.to_variable(repeat_times)
out_1 = paddle.tile(x, repeat_times=[2, 3])
out_2 = paddle.tile(x, repeat_times=[positive_2, 3])
out_3 = paddle.tile(x, repeat_times=repeat_times)
assert np.array_equal(out_1.numpy(), np.tile(np_x, (2, 3)))
assert np.array_equal(out_2.numpy(), np.tile(np_x, (2, 3)))
assert np.array_equal(out_3.numpy(), np.tile(np_x, (2, 3)))
if __name__ == "__main__":
......
......@@ -74,6 +74,7 @@ from .logic import equal_all #DEFINE_ALIAS
from .manipulation import cast #DEFINE_ALIAS
from .manipulation import concat #DEFINE_ALIAS
from .manipulation import expand #DEFINE_ALIAS
from .manipulation import broadcast_to #DEFINE_ALIAS
from .manipulation import expand_as #DEFINE_ALIAS
from .manipulation import tile #DEFINE_ALIAS
from .manipulation import flatten #DEFINE_ALIAS
......
......@@ -23,7 +23,6 @@ from ..fluid.layers import utils
import numpy as np
# TODO: define functions to manipulate a tensor
from ..fluid.layers import cast #DEFINE_ALIAS
from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import reshape #DEFINE_ALIAS
from ..fluid.layers import scatter #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS
......@@ -44,6 +43,7 @@ __all__ = [
'cast',
'concat',
'expand',
'broadcast_to',
'expand_as',
'flatten',
'gather',
......@@ -791,80 +791,58 @@ def unbind(input, axis=0):
def tile(x, repeat_times, name=None):
"""
:alias_main: paddle.tile
:alias: paddle.tile,paddle.tensor.tile,paddle.tensor.manipulation.tile
Construct a new tensor by repeating ``x`` the number of times given by the parameter ``repeat_times``.
The rank of ``x`` should be less than or equal to 6, and the size of the shape of ``repeat_times`` should
be less than or equal to 6.
If the size of the parameter ``repeat_times`` is ``d``, and the rank for ``x`` is ``r``, then the number
of dimensions for the result is ``max(d, r)``.
If ``r < d``, ``x`` is first promoted to a d-dimensional tensor by inserting new axes at the beginning.
For example, a tensor ``x`` with the shape (3,) is promoted to a 2-D tensor with the shape (1, 3) if ``d`` is 2
and to a 3-D tensor with the shape (1, 1, 3) if ``d`` is 3.
If ``r > d``, ``repeat_times`` is first promoted by inserting 1's at the beginning.
For example, if the tensor ``x`` has the shape (4, 3, 2, 2) and ``repeat_times`` is a tuple (3, 2),
``repeat_times`` is first promoted to a tuple (1, 1, 3, 2).
The following gives a use case:
.. code-block:: text
Input(x) is a 3-D tensor of shape (2, 3, 1):
[
[[1], [2], [3]],
[[4], [5], [6]]
]
Attr(repeat_times): [1, 2, 2]
Output(out) is a 3-D tensor of shape (2, 6, 2):
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
Construct a new Tensor by repeating ``x`` the number of times given by ``repeat_times``.
After tiling, the size of the i'th dimension of the output is equal to ``x.dims[i] * repeat_times[i]``.
Both the number of dimensions of ``x`` and the number of elements in ``repeat_times`` should be less than or equal to 6.
Args:
x (Tensor): The input tensor, its data type should be bool, float32, float64, int32. The rank of ``x`` should be in [1, 6].
repeat_times (Tensor|tuple|list): The number of repeating times for each dimension of the input ``x``. If repeat_times is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If repeat_times is Tensor, it should be an 1-D Tensor. The size of its shape should be in [1, 6].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name` .
x (Tensor): The input tensor, its data type should be bool, float32, float64, int32 or int64.
repeat_times (Tensor|tuple|list): The number of repeating times. If repeat_times is a list or tuple, all its elements
should be integers or 1-D Tensors with the data type int32. If repeat_times is a Tensor, it should be a 1-D Tensor with the data type int32.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
N-D Tensor. The data type is the same as ``x``. After tiling, each dimension of the output is equal to the corresponding dimension of ``x`` multiplying the corresponding value given by ``repeat_times`` .
Raises:
TypeError: The type of ``repeat_times`` must be list, tuple or Tensor.
ValueError: The elements of ``repeat_times`` cannot be negative.
N-D Tensor. The data type is the same as ``x``.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1:
np_data_1 = np.array([1, 2, 3]).astype('int32')
data_1 = paddle.to_variable(np_data_1)
tiled_1 = paddle.tile(data_1, repeat_times=[2, 1])
np_data = np.array([1, 2, 3]).astype('int32')
data = paddle.to_variable(np_data)
out = paddle.tile(data, repeat_times=[2, 1])
np_out = out.numpy()
# [[1, 2, 3], [1, 2, 3]]
# example 2:
out = paddle.tile(data, repeat_times=[2, 2])
np_out = out.numpy()
# [[1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3]]
np_repeat_times = np.array([2, 1]).astype("int32")
repeat_times = paddle.to_variable(np_repeat_times)
tiled_2 = paddle.tile(data_1, repeat_times=repeat_times)
out = paddle.tile(data, repeat_times=repeat_times)
np_out = out.numpy()
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
if isinstance(repeat_times, (list, tuple)):
repeat_times = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in repeat_times
]
return core.ops.tile(x, 'repeat_times', repeat_times)
inputs = {"X": [x]}
attrs = {}
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'tile')
check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False:
raise ValueError(
"When the date type is bool for the input 'x' of tile op, you "
"must set its stop_gradient to be False by "
"some_var.stop_gradient == True supporting some_var is the input.")
"must set its stop_gradient to be True by "
"some_var.stop_gradient == True supporting some_var as the input.")
helper = LayerHelper('tile', input=x, **locals())
inputs = {"X": [x]}
attrs = {}
def get_attr_repeat_times(list_repeat_times):
attrs_repeat_times = []
for idx, times in enumerate(list_repeat_times):
......@@ -873,13 +851,13 @@ def tile(x, repeat_times, name=None):
else:
attrs_repeat_times.append(times)
assert times > 0, (
"Every element given in repeat_times must be positive.")
"All elements in repeat_times must be positive for tile.")
return attrs_repeat_times
if isinstance(repeat_times, Variable):
repeat_times.stop_gradient = True
inputs['RepeatTimes'] = repeat_times
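# When repeat_times is a Tensor, its values are only known at run time, so the
# 'repeat_times' attribute below is just a placeholder; the op reads the actual
# values from the 'RepeatTimes' input.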
attrs['repeat_times'] = [-1] * len(repeat_times.shape)
attrs['repeat_times'] = [-1]
elif isinstance(repeat_times, (list, tuple)):
attrs['repeat_times'] = get_attr_repeat_times(repeat_times)
if utils._contain_var(repeat_times):
......@@ -893,67 +871,103 @@ def tile(x, repeat_times, name=None):
return out
def expand_as(x, y, name=None):
"""
Expand the input tensor ``x`` to the same shape as the input tensor ``y``.
Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimensions of ``x`` to be expanded must have a value of 1.
Args:
x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64.
y (Tensor): The input tensor that gives the target shape which ``x`` is expanded to.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
N-D Tensor: A Tensor with the same shape as ``y``. The data type is the same as ``x``.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
np_data_x = np.array([1, 2, 3]).astype('int32')
np_data_y = np.array([[1, 2, 3], [4, 5, 6]]).astype('int32')
data_x = paddle.to_variable(np_data_x)
data_y = paddle.to_variable(np_data_y)
out = paddle.expand_as(data_x, data_y)
np_out = out.numpy()
# [[1, 2, 3], [1, 2, 3]]
"""
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand_as')
check_type(y, 'y', Variable, 'expand_as')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False:
raise ValueError(
"When the data type of input 'x' for expand_as is bool, "
"you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting "
"some_var as the input 'x'.")
inputs = {"X": [x], "target_tensor": [y]}
helper = LayerHelper('expand_as', input=x, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type='expand_as_v2', inputs=inputs, outputs={'Out': out})
return out
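A short hedged usage sketch (imperative mode; variable names are illustrative): the rank of ``x`` is left-padded with 1s and only size-1 dimensions are repeated, so a (3, 1) tensor can be expanded to the shape of a (2, 3, 4) target:
import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_variable(np.ones((3, 1)).astype('float32'))
y = paddle.to_variable(np.zeros((2, 3, 4)).astype('float32'))
out = paddle.expand_as(x, y)
# out.shape -> [2, 3, 4]; a non-singleton dimension that differs from the
# target (e.g. x of shape (2,) against this y) raises an error instead.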
def expand(x, shape, name=None):
"""
:alias_main: paddle.expand
:alias: paddle.expand,paddle.tensor.expand,paddle.tensor.manipulation.expand
Expand the input tensor to a given shape.
The rank of ``x`` should be less than or equal to 6, and the number of elements in ``shape`` should also be less than or equal to 6. The size of the dimension to expand must be 1.
Both the number of dimensions of ``x`` and the number of elements in ``shape`` should be less than or equal to 6. The dimensions to expand must have a value of 1.
Args:
x (Tensor): The input Tensor with rank in [1, 6]. The data type is bool, float32, float64 or int32.
shape (list|tuple|Tensor): The result shape after expanding. The data type is int32. If shape is a list or tuple, all elements of
it should be integers or Tensors with shape (1,). If shape is a Tensor, it should be an 1-D Tensor.
The value -1 in shape, means keeping the corresponding dimension unchanged.
x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64.
shape (list|tuple|Tensor): The result shape after expanding. The data type is int32. If shape is a list or tuple, all its elements
should be integers or 1-D Tensors with the data type int32. If shape is a Tensor, it should be a 1-D Tensor with the data type int32.
The value -1 in shape means keeping the corresponding dimension unchanged.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: A Tensor with the given shape. The data type is the same as ``x``.
Raises:
TypeError: The type of ``shape`` must be list, tuple or Variable.
ValueError: The elements of ``shape`` must be positive or -1.
N-D Tensor: A Tensor with the given shape. The data type is the same as ``x``.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
# example 1:
np_data_1 = np.array([1, 2, 3]).astype('int32')
data_1 = paddle.to_variable(np_data_1)
expanded_1 = paddle.expand(data_1, shape=[2, 3])
paddle.disable_static()
np_data = np.array([1, 2, 3]).astype('int32')
data = paddle.to_variable(np_data)
out = paddle.expand(data, shape=[2, 3])
out = out.numpy()
# [[1, 2, 3], [1, 2, 3]]
# example 2:
np_shape = np.array([2, 3]).astype('int32')
shape = paddle.to_variable(np_shape)
expanded_2 = paddle.expand(data_1, shape=shape)
out = paddle.expand(data, shape=shape)
out = out.numpy()
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
if isinstance(shape, (list, tuple)):
expand_shape = [
item.numpy()[0] if isinstance(item, Variable) else item
for item in shape
]
return core.ops.expand_v2(x, 'shape', expand_shape)
inputs = {"X": [x]}
attrs = {}
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand')
check_type(shape, 'shape', (list, tuple, Variable), 'expand')
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
inputs = {"X": [x]}
attrs = {}
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False:
raise ValueError("When the data type of input 'x' for expand is bool, "
"you must set its stop_gradient to be False by "
" some_var.stop_gradient = False, supporting "
"some_var.stop_gradient = True, supporting "
"some_var as the input.")
helper = LayerHelper('expand', input=x, **locals())
......@@ -966,7 +980,7 @@ def expand(x, shape, name=None):
else:
attrs_expand_shape.append(shape)
assert shape > 0 or shape == -1, (
"Every element in shape must be positive or -1.")
"All elements in shape of expand must be positive or -1.")
return attrs_expand_shape
if isinstance(shape, Variable):
......@@ -983,3 +997,6 @@ def expand(x, shape, name=None):
helper.append_op(
type='expand_v2', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
broadcast_to = expand
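Since ``broadcast_to`` is defined as an alias of ``expand`` just above, it takes the same ``(x, shape)`` signature; a minimal hedged example:
import numpy as np
import paddle

paddle.disable_static()
data = paddle.to_variable(np.array([1, 2, 3]).astype('int32'))
out = paddle.broadcast_to(data, shape=[2, 3])
# out.numpy() -> [[1, 2, 3], [1, 2, 3]]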