未验证 提交 4ec1251a 编写于 作者: L Leo Chen 提交者: GitHub

Refine squeeze, test=develop (#25281)

* refine squeeze, test=develop

* update squeeze, test=develop

* refine compile-time infershape, test=develop

* add more unittest, test=develop

* follow comments, test=develop

* add update_api, test=develop

* follow comments, test=develop
上级 28064c2d
...@@ -13,15 +13,73 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,73 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/squeeze_op.h" #include "paddle/fluid/operators/squeeze_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Computes the output shape of squeeze.
//
// squeeze_dims: axes to squeeze; empty means "squeeze every dim of size 1".
//               Negative axes count from the end of the shape.
// in_dims:      shape of the input tensor.
// is_runtime:   at run time only dims of exactly 1 may be squeezed; at
//               compile time a dim of -1 (unknown) is also allowed, since it
//               may turn out to be 1 at run time.
framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                               const framework::DDim &in_dims,
                               bool is_runtime) {
  size_t num_squeeze_dims = squeeze_dims.size();
  std::vector<bool> should_squeeze(in_dims.size(), false);

  // Mark dimensions that need to be squeezed.
  if (num_squeeze_dims == 0) {
    for (int i = 0; i < in_dims.size(); ++i) {
      if (in_dims[i] == 1) {
        should_squeeze[i] = true;
      }
    }
  } else {
    for (size_t i = 0; i < num_squeeze_dims; ++i) {
      // Normalize negative axes into [0, rank).
      int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims.size()
                                        : squeeze_dims[i];
      // NOTE: the adjacent string literals need an explicit separator,
      // otherwise the message reads "...[%d, %d]But current axis...".
      PADDLE_ENFORCE_GE(
          current, 0,
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(), in_dims.size() - 1, current, in_dims));
      PADDLE_ENFORCE_LT(
          current, in_dims.size(),
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(), in_dims.size() - 1, current, in_dims));

      if (!should_squeeze[current]) {
        if (is_runtime) {
          // At run time, only a dim of 1 is allowed to squeeze.
          if (in_dims[current] == 1) {
            should_squeeze[current] = true;
          }
        } else {
          // At compile time, a dim of -1 or 1 is allowed to squeeze.
          if (in_dims[current] == 1 || in_dims[current] == -1) {
            should_squeeze[current] = true;
          }
        }
      }
    }
  }

  // Make output dimensions from the dims that were not squeezed.
  std::vector<int64_t> output_shape;
  output_shape.reserve(in_dims.size());
  for (int i = 0; i < in_dims.size(); ++i) {
    if (!should_squeeze[i]) {
      output_shape.push_back(in_dims[i]);
    }
  }
  return framework::make_ddim(output_shape);
}
class SqueezeOp : public framework::OperatorWithKernel { class SqueezeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -40,7 +98,7 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -40,7 +98,7 @@ class SqueezeOp : public framework::OperatorWithKernel {
x_dims.size(), x_dims)); x_dims.size(), x_dims));
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims, false);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X) // Only pass LoD when the first dimension of output and Input(X)
...@@ -49,56 +107,6 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -49,56 +107,6 @@ class SqueezeOp : public framework::OperatorWithKernel {
} }
} }
// Computes the squeezed output shape at compile time.
// An explicitly listed axis is always squeezed here (its size may still be
// -1, i.e. unknown, at compile time); with no axes given, every dim of
// size 1 is squeezed.
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                                      const framework::DDim &in_dims) {
  size_t num_squeeze_dims = squeeze_dims.size();
  int cnt_squeezed_dims = 0;
  // One flag per input dim, sized dynamically: a fixed bool[9] would
  // silently overflow if the supported max rank ever exceeded 9.
  std::vector<bool> should_squeeze(in_dims.size(), false);

  // Determines number of dimensions of output tensor after squeeze.
  // Mark and count the dimensions need to be squeezed.
  if (num_squeeze_dims == 0) {
    for (int idx = 0; idx < in_dims.size(); ++idx) {
      if (in_dims[idx] == 1) {
        should_squeeze[idx] = true;
        ++cnt_squeezed_dims;
      }
    }
  } else {
    for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
      // Normalize negative axes into [0, rank).
      int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
                                          : squeeze_dims[idx];
      PADDLE_ENFORCE_GE(
          current, 0,
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(), in_dims.size() - 1, current, in_dims));
      PADDLE_ENFORCE_LT(
          current, in_dims.size(),
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(), in_dims.size() - 1, current, in_dims));

      // Count each axis only once even if it was listed twice.
      if (!(should_squeeze[current])) {
        ++cnt_squeezed_dims;
      }
      should_squeeze[current] = true;
    }
  }

  // Make output dimensions from the dims that were not squeezed.
  std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
  for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
    if (!should_squeeze[in_idx]) {
      output_shape[out_idx++] = in_dims[in_idx];
    }
  }
  return framework::make_ddim(output_shape);
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
...@@ -183,7 +191,7 @@ class Squeeze2Op : public framework::OperatorWithKernel { ...@@ -183,7 +191,7 @@ class Squeeze2Op : public framework::OperatorWithKernel {
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims, false);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X) // Only pass LoD when the first dimension of output and Input(X)
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -24,6 +25,9 @@ limitations under the License. */ ...@@ -24,6 +25,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims, bool is_runtime);
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SqueezeKernel : public framework::OpKernel<T> { class SqueezeKernel : public framework::OpKernel<T> {
public: public:
...@@ -33,7 +37,7 @@ class SqueezeKernel : public framework::OpKernel<T> { ...@@ -33,7 +37,7 @@ class SqueezeKernel : public framework::OpKernel<T> {
auto &axes = context.Attr<std::vector<int>>("axes"); auto &axes = context.Attr<std::vector<int>>("axes");
auto x_dims = in->dims(); auto x_dims = in->dims();
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims, true);
out->mutable_data(context.GetPlace(), in->type()); out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy( framework::TensorCopy(
...@@ -41,64 +45,6 @@ class SqueezeKernel : public framework::OpKernel<T> { ...@@ -41,64 +45,6 @@ class SqueezeKernel : public framework::OpKernel<T> {
context.template device_context<platform::DeviceContext>(), out); context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims); out->Resize(out_dims);
} }
// Computes the squeezed output shape at run time.
// Unlike the compile-time variant, every explicitly listed axis is required
// to have size exactly 1 (enforced below); with no axes given, every dim of
// size 1 is squeezed.
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                                      const framework::DDim &in_dims) {
  size_t num_squeeze_dims = squeeze_dims.size();
  int cnt_squeezed_dims = 0;
  // One flag per input dim, sized dynamically: a fixed bool[9] would
  // silently overflow if the supported max rank ever exceeded 9.
  std::vector<bool> should_squeeze(in_dims.size(), false);

  // Determines number of dimensions of output tensor after squeeze.
  // Mark and count the dimensions need to be squeezed.
  if (num_squeeze_dims == 0) {
    for (int idx = 0; idx < in_dims.size(); ++idx) {
      if (in_dims[idx] == 1) {
        should_squeeze[idx] = true;
        ++cnt_squeezed_dims;
      }
    }
  } else {
    for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
      // Normalize negative axes into [0, rank).
      int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
                                          : squeeze_dims[idx];
      PADDLE_ENFORCE_GE(
          current, 0,
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(), in_dims.size() - 1, current, in_dims));
      PADDLE_ENFORCE_LT(
          current, in_dims.size(),
          platform::errors::InvalidArgument(
              "Each axis in Attr(axes) should be in the range of [%d, %d]. "
              "But current axis is:%d, input tensor's shape = [%s].",
              -in_dims.size(), in_dims.size() - 1, current, in_dims));
      // At run time the squeezed axis must actually have size 1.
      PADDLE_ENFORCE_EQ(in_dims[current], 1,
                        platform::errors::InvalidArgument(
                            "The size of axis that will be squeezed "
                            "should be equal to 1. But current axis = %d,"
                            "input tensor's shape = [%s].",
                            in_dims[current], in_dims));

      // Count each axis only once even if it was listed twice.
      if (!(should_squeeze[current])) {
        ++cnt_squeezed_dims;
      }
      should_squeeze[current] = true;
    }
  }

  // Make output dimensions from the dims that were not squeezed.
  std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
  for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
    if (!should_squeeze[in_idx]) {
      output_shape[out_idx++] = in_dims[in_idx];
    }
  }
  return framework::make_ddim(output_shape);
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -126,8 +72,7 @@ class Squeeze2Kernel : public framework::OpKernel<T> { ...@@ -126,8 +72,7 @@ class Squeeze2Kernel : public framework::OpKernel<T> {
auto &axes = context.Attr<std::vector<int>>("axes"); auto &axes = context.Attr<std::vector<int>>("axes");
auto x_dims = in->dims(); auto x_dims = in->dims();
auto out_dims = auto out_dims = GetOutputShape(axes, x_dims, true);
SqueezeKernel<DeviceContext, T>::GetOutputShape(axes, x_dims);
out->mutable_data(context.GetPlace(), in->type()); out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy( framework::TensorCopy(
......
...@@ -6203,6 +6203,10 @@ def squeeze(input, axes, name=None): ...@@ -6203,6 +6203,10 @@ def squeeze(input, axes, name=None):
y = layers.squeeze(input=x, axes=[2]) # y.shape=[None, 5, 10] y = layers.squeeze(input=x, axes=[2]) # y.shape=[None, 5, 10]
""" """
if in_dygraph_mode():
out, _ = core.ops.squeeze2(input, 'axes', axes)
return out
helper = LayerHelper("squeeze", **locals()) helper = LayerHelper("squeeze", **locals())
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', input, 'input',
......
...@@ -70,6 +70,14 @@ class TestSqueezeOp3(TestSqueezeOp): ...@@ -70,6 +70,14 @@ class TestSqueezeOp3(TestSqueezeOp):
self.new_shape = (6, 5, 1, 4) self.new_shape = (6, 5, 1, 4)
# Correct: The dimension of the given axis that is not of size 1 remains unchanged.
class TestSqueezeOp4(TestSqueezeOp):
    def init_test_case(self):
        # Axis 1 has size 1 and is squeezed; axis 2 has size 5 (not 1), so
        # it is left unchanged instead of raising an error.
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, 2)
        self.new_shape = (6, 5, 1, 4, 1)
class TestSqueezeOpError(unittest.TestCase): class TestSqueezeOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -90,7 +98,7 @@ class API_TestSqueeze(unittest.TestCase): ...@@ -90,7 +98,7 @@ class API_TestSqueeze(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data( data1 = fluid.layers.data(
'data1', shape=[-1, 1, 10], dtype='float64') 'data1', shape=[-1, 1, 10], dtype='float64')
result_squeeze = paddle.squeeze(data1, axes=[1]) result_squeeze = paddle.squeeze(data1, axis=[1])
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input1 = np.random.random([5, 1, 10]).astype('float64') input1 = np.random.random([5, 1, 10]).astype('float64')
...@@ -105,7 +113,25 @@ class API_TestDygraphSqueeze(unittest.TestCase): ...@@ -105,7 +113,25 @@ class API_TestDygraphSqueeze(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
input_1 = np.random.random([5, 1, 10]).astype("int32") input_1 = np.random.random([5, 1, 10]).astype("int32")
input = fluid.dygraph.to_variable(input_1) input = fluid.dygraph.to_variable(input_1)
output = paddle.squeeze(input, axes=[1]) output = paddle.squeeze(input, axis=[1])
out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
def test_axis_not_list(self):
    """paddle.squeeze should accept a bare int axis, not only a list."""
    with fluid.dygraph.guard():
        input_1 = np.random.random([5, 1, 10]).astype("int32")
        input = fluid.dygraph.to_variable(input_1)
        # axis passed as a plain int instead of [1].
        output = paddle.squeeze(input, axis=1)
        out_np = output.numpy()
        expected_out = np.squeeze(input_1, axis=1)
        self.assertTrue(np.allclose(expected_out, out_np))
def test_dimension_not_1(self):
with fluid.dygraph.guard():
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = fluid.dygraph.to_variable(input_1)
output = paddle.squeeze(input, axis=(1, 2))
out_np = output.numpy() out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1) expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np)) self.assertTrue(np.allclose(expected_out, out_np))
......
...@@ -40,6 +40,7 @@ from ..fluid.layers import scatter_nd_add #DEFINE_ALIAS ...@@ -40,6 +40,7 @@ from ..fluid.layers import scatter_nd_add #DEFINE_ALIAS
from ..fluid.layers import scatter_nd #DEFINE_ALIAS from ..fluid.layers import scatter_nd #DEFINE_ALIAS
from ..fluid.layers import shard_index #DEFINE_ALIAS from ..fluid.layers import shard_index #DEFINE_ALIAS
from ..fluid.layers import unique_with_counts #DEFINE_ALIAS from ..fluid.layers import unique_with_counts #DEFINE_ALIAS
from ..fluid import layers
__all__ = [ __all__ = [
'cast', 'concat', 'expand', 'expand_as', 'flatten', 'gather', 'gather_nd', 'cast', 'concat', 'expand', 'expand_as', 'flatten', 'gather', 'gather_nd',
...@@ -442,83 +443,81 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -442,83 +443,81 @@ def split(input, num_or_sections, dim=-1, name=None):
return outs return outs
def squeeze(input, axes, out=None, name=None): def squeeze(x, axis=None, name=None):
""" """
:alias_main: paddle.squeeze :alias_main: paddle.squeeze
:alias: paddle.squeeze,paddle.tensor.squeeze,paddle.tensor.manipulation.squeeze :alias: paddle.squeeze, paddle.tensor.squeeze, paddle.tensor.manipulation.squeeze
This OP will squeeze single-dimensional entries of input tensor's shape. If axes is provided, will This OP will squeeze the dimension(s) of size 1 of input tensor x's shape.
remove the dims by axes, the dims selected by axes should be one. If not provide axes, all dims equal
to one will be deleted.
If axis is provided, it will remove the dimension(s) by given axis that of size 1.
If the dimension of given axis is not of size 1, the dimension remain unchanged.
If axis is not provided, all dims equal of size 1 will be removed.
.. code-block:: text .. code-block:: text
Case1: Case1:
Input: Input:
X.shape = (1, 3, 1, 5) x.shape = [1, 3, 1, 5] # If axis is not provided, all dims equal of size 1 will be removed.
axes = [0] axis = None
Output: Output:
Out.shape = (3, 1, 5) out.shape = [3, 5]
Case2: Case2:
Input: Input:
X.shape = (1, 3, 1, 5) x.shape = [1, 3, 1, 5] # If axis is provided, it will remove the dimension(s) by given axis that of size 1.
axes = [] axis = 0
Output:
out.shape = [3, 1, 5]
Case4:
Input:
x.shape = [1, 3, 1, 5] # If the dimension of one given axis (3) is not of size 1, the dimension remain unchanged.
axis = [0, 2, 3]
Output: Output:
Out.shape = (3, 5) out.shape = [3, 5]
Case3: Case4:
Input: Input:
X.shape = [1,3,1,5] x.shape = [1, 3, 1, 5] # If axis is negative, axis = axis + ndim (number of dimensions in x).
axes = [-2] axis = [-2]
Output: Output:
Out.shape = [1,3,5] out.shape = [1, 3, 5]
Args: Args:
input (Variable): The input Tensor. Support data type: float32, float64, int8, int32, int64. input (Tensor): The input Tensor. Support data type: float32, float64, int8, int32, int64.
axes (list): One integer or List of integers, indicating the dimensions to be squeezed. axis (int|list|tuple, optional): An integer or list of integers, indicating the dimensions to be squeezed. Default is None.
Axes range is :math:`[-rank(input), rank(input))`. The range of axis is :math:`[-ndim(input), ndim(input))`.
If axes is negative, :math:`axes=axes+rank(input)`. If axis is negative, :math:`axis = axis + ndim(input)`.
If axis is None, all the dimensions of input of size 1 will be removed.
name (str, optional): Please refer to :ref:`api_guide_Name`, Default None. name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.
Returns: Returns:
Variable: Output squeezed Tensor. Data type is same as input Tensor. Tensor: Output squeezed Tensor. Data type is same as input Tensor.
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np
import paddle import paddle
import paddle.fluid as fluid
with fluid.dygraph.guard(): paddle.enable_imperative()
input_1 = np.random.random([5, 1, 10]).astype("int32")
# input is a variable which shape is [5, 1, 10] x = paddle.rand([5, 1, 10])
input = fluid.dygraph.to_variable(input_1) output = paddle.squeeze(x, axis=1)
# output.shape [5, 10]
output = paddle.squeeze(input, axes=[1])
# output.shape [5, 10]
""" """
if axis is None:
axis = []
elif isinstance(axis, int):
axis = [axis]
elif isinstance(axis, tuple):
axis = list(axis)
helper = LayerHelper("squeeze", **locals()) return layers.squeeze(x, axis, name)
check_variable_and_dtype(input, 'input',
['float32', 'float64', 'int8', 'int32', 'int64'],
'squeeze')
check_type(axes, 'axes', list, 'squeeze')
out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type="squeeze2",
inputs={"X": input},
attrs={"axes": axes},
outputs={"Out": out,
"XShape": x_shape})
return out
def unsqueeze(input, axes, out=None, name=None): def unsqueeze(input, axes, out=None, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册