Unverified commit 515efe42, authored by F furnace, committed by GitHub

add empty_like op (Python API and unit test), using the C++ implementation of the empty op (#27287)

and optimize the C++ implementation of the empty op per the review comments on PR#26659,
and add bool support to the shape op.
Parent f36b9a7f
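For quick context, a minimal dygraph sketch of the paddle.empty_like API this commit adds (the output holds uninitialized memory, so the values are placeholders, not reproducible):

    import paddle

    paddle.disable_static()  # imperative mode, as in the unit tests below
    x = paddle.randn([2, 3], 'float32')
    out = paddle.empty_like(x)           # same shape and dtype as x, data uninitialized
    mask = paddle.empty_like(x, 'bool')  # dtype can be overridden, as the tests cover
    print(out.shape)  # [2, 3]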
@@ -55,31 +55,38 @@ class EmptyOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "empty");
     if (context->HasInput("ShapeTensor")) {
-      auto dims = context->GetInputDim("ShapeTensor");
+      auto shape_dims = context->GetInputDim("ShapeTensor");
       int num_ele = 1;
-      for (int i = 0; i < dims.size(); ++i) {
-        num_ele *= dims[i];
+      for (int i = 0; i < shape_dims.size(); ++i) {
+        num_ele *= shape_dims[i];
       }
-      context->SetOutputDim("Out", framework::make_ddim({num_ele}));
+      auto vec_dims = std::vector<int>(num_ele, -1);
+      context->SetOutputDim("Out", framework::make_ddim(vec_dims));
     } else if (context->HasInputs("ShapeTensorList")) {
       std::vector<int> out_dims;
       auto dims_list = context->GetInputsDim("ShapeTensorList");
       for (size_t i = 0; i < dims_list.size(); ++i) {
         auto& dims = dims_list[i];
-        PADDLE_ENFORCE_EQ(
-            dims, framework::make_ddim({1}),
-            "ShapeError: The shape of Tensor in list must be [1]. "
-            "But received the shape "
-            "is [%s]",
-            dims);
-        out_dims.push_back(dims[0]);
+        PADDLE_ENFORCE_EQ(dims, framework::make_ddim({1}),
+                          platform::errors::InvalidArgument(
+                              "The shape of Tensor in list must be [1]. "
+                              "But received the shape is [%s]",
+                              dims));
+        out_dims.push_back(-1);
       }
       context->SetOutputDim("Out", framework::make_ddim(out_dims));
     } else {
       auto& shape = context->Attrs().Get<std::vector<int64_t>>("shape");
+      for (size_t i = 0; i < shape.size(); ++i) {
+        PADDLE_ENFORCE_GE(
+            shape[i], 0,
+            platform::errors::InvalidArgument(
+                "Each value of attribute 'shape' is expected to be no less "
+                "than 0. But received: shape[%u] = %d; shape = [%s].",
+                i, shape[i], framework::make_ddim(shape)));
+      }
       context->SetOutputDim("Out", framework::make_ddim(shape));
     }
   }
......
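The substance of the hunk above: when the output shape comes from a ShapeTensor input (or a ShapeTensorList), the shape values are unknown until runtime, so InferShape now reports every output dimension as -1 (unknown) instead of a fabricated concrete shape. A sketch of the observable effect in a static graph (assuming paddle.empty accepts a Tensor-valued shape, as the ShapeTensor input implies; the exact compile-time display may vary by version):

    import paddle
    from paddle.static import program_guard, Program

    paddle.enable_static()
    main, startup = Program(), Program()
    with program_guard(main, startup):
        x = paddle.static.data('x', shape=[-1, 200, 3], dtype='float32')
        # The shape is carried by a tensor, so it is only known at runtime;
        # after this change the compile-time dims of `out` are all -1.
        out = paddle.empty(shape=paddle.shape(x), dtype='float32')
        print(out.shape)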
@@ -68,6 +68,6 @@ REGISTER_OPERATOR(
     shape, ops::ShapeOp, ops::ShapeOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int32_t>,
+REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>,
                        ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
                        ops::ShapeKernel<double>);
@@ -15,8 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/shape_op.h"
 REGISTER_OP_CUDA_KERNEL(
-    shape, paddle::operators::ShapeKernel<int>,
-    paddle::operators::ShapeKernel<int32_t>,
+    shape, paddle::operators::ShapeKernel<bool>,
+    paddle::operators::ShapeKernel<int>,
     paddle::operators::ShapeKernel<int64_t>,
     paddle::operators::ShapeKernel<float>,
     paddle::operators::ShapeKernel<double>,
......
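With bool now registered for both the CPU and CUDA kernels, paddle.shape accepts boolean tensors. A small sketch (assuming paddle.to_tensor is available on this branch; any bool tensor works):

    import numpy as np
    import paddle

    paddle.disable_static()
    mask = paddle.to_tensor(np.zeros((3, 100, 100), dtype='bool'))
    print(paddle.shape(mask).numpy())  # [  3 100 100]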
@@ -77,6 +77,7 @@ from .tensor.creation import triu #DEFINE_ALIAS
 from .tensor.creation import tril #DEFINE_ALIAS
 from .tensor.creation import meshgrid #DEFINE_ALIAS
 from .tensor.creation import empty #DEFINE_ALIAS
+from .tensor.creation import empty_like #DEFINE_ALIAS
 from .tensor.linalg import matmul #DEFINE_ALIAS
 from .tensor.linalg import dot #DEFINE_ALIAS
 # from .tensor.linalg import einsum #DEFINE_ALIAS
......
@@ -11229,7 +11229,7 @@ def shape(input):
             input.shape = [3, 2]
     Args:
-        input (Variable): The input can be N-D Tensor or SelectedRows with data type float16, float32, float64, int32, int64.
+        input (Variable): The input can be N-D Tensor or SelectedRows with data type bool, float16, float32, float64, int32, int64.
           If the input variable is of type SelectedRows, returns the shape of its inner tensor.
     Returns:
@@ -11253,8 +11253,8 @@ def shape(input):
             print(res) # [array([ 3, 100, 100], dtype=int32)]
     """
     check_variable_and_dtype(
-        input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
-        'shape')
+        input, 'input',
+        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape')
     helper = LayerHelper('shape', **locals())
     out = helper.create_variable_for_type_inference(dtype='int32')
     helper.append_op(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.data_feeder import convert_dtype
import paddle.fluid.core as core
from paddle.static import program_guard, Program
class TestEmptyLikeAPICommon(unittest.TestCase):
def __check_out__(self, out):
data_type = convert_dtype(out.dtype)
        self.assertEqual(data_type, self.dst_dtype,
                         'dtype should be %s, but got %s' %
                         (self.dst_dtype, data_type))
        shape = out.shape
        self.assertTupleEqual(shape, self.dst_shape,
                              'shape should be %s, but got %s' %
                              (self.dst_shape, shape))
if data_type in ['float32', 'float64', 'int32', 'int64']:
max_value = np.nanmax(out)
min_value = np.nanmin(out)
            always_non_full_zero = max_value > min_value
            always_full_zero = max_value == 0.0 and min_value == 0.0
            self.assertTrue(always_full_zero or always_non_full_zero,
                            'the output should be all zeros or contain '
                            'varying (uninitialized) values.')
elif data_type in ['bool']:
total_num = out.size
true_num = np.sum(out == True)
false_num = np.sum(out == False)
self.assertTrue(total_num == true_num + false_num,
'The value should always be True or False.')
else:
self.assertTrue(False, 'invalid data type')
class TestEmptyLikeAPI(TestEmptyLikeAPICommon):
def setUp(self):
self.init_config()
def test_dygraph_api_out(self):
paddle.disable_static()
out = paddle.empty_like(self.x, self.dtype)
self.__check_out__(out.numpy())
paddle.enable_static()
def init_config(self):
self.x = np.random.random((200, 3)).astype("float32")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI2(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("float64")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI3(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI4(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int64")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI5(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("bool")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI6(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("float64")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI7(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI8(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int64")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI9(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("bool")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI10(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("float32")
self.dtype = "bool"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI_Static(TestEmptyLikeAPICommon):
def setUp(self):
self.init_config()
def test_static_graph(self):
dtype = 'float32'
train_program = Program()
startup_program = Program()
with program_guard(train_program, startup_program):
x = np.random.random(self.x_shape).astype(dtype)
data_x = paddle.static.data(
'x', shape=self.data_x_shape, dtype=dtype)
out = paddle.empty_like(data_x)
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
res = exe.run(train_program, feed={'x': x}, fetch_list=[out])
self.dst_dtype = dtype
self.dst_shape = x.shape
self.__check_out__(res[0])
def init_config(self):
self.x_shape = (200, 3)
self.data_x_shape = [200, 3]
class TestEmptyLikeAPI_Static2(TestEmptyLikeAPI_Static):
def init_config(self):
self.x_shape = (3, 200, 3)
self.data_x_shape = [-1, 200, 3]
class TestEmptyError(unittest.TestCase):
def test_attr(self):
def test_dtype():
x = np.random.random((200, 3)).astype("float64")
dtype = 'uint8'
result = paddle.empty_like(x, dtype=dtype)
self.assertRaises(TypeError, test_dtype)
if __name__ == '__main__':
unittest.main()
@@ -41,6 +41,7 @@ from .creation import triu #DEFINE_ALIAS
 from .creation import tril #DEFINE_ALIAS
 from .creation import meshgrid #DEFINE_ALIAS
 from .creation import empty #DEFINE_ALIAS
+from .creation import empty_like #DEFINE_ALIAS
 from .io import save #DEFINE_ALIAS
 from .io import load #DEFINE_ALIAS
 from .linalg import matmul #DEFINE_ALIAS
......
@@ -49,6 +49,7 @@ __all__ = [
     'full',
     'full_like',
     'empty',
+    'empty_like',
     'triu',
     'tril',
     'meshgrid'
@@ -1068,3 +1069,70 @@ def empty(shape, dtype=None, name=None):
stop_gradient=True)
out.stop_gradient = True
return out
def empty_like(x, dtype=None, name=None):
"""
    This Op returns a Tensor with uninitialized data, with the same shape as ``x`` and data type ``dtype``.
    If ``dtype`` is None, the data type of the output Tensor is the same as that of ``x``.
Args:
x(Tensor): The input tensor which specifies shape and data type. The data type can be bool, float16, float32, float64, int32, int64.
dtype(np.dtype|str, optional): The data type of output. The data type can be one
of bool, float16, float32, float64, int32, int64. The default value is None, which means the output
data type is the same as input.
        name(str, optional): The default value is None. Normally there is no need for the user
            to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static() # Now we are in imperative mode
paddle.set_device("cpu") # and use cpu device
x = paddle.randn([2, 3], 'float32')
output = paddle.empty_like(x)
#[[1.8491974e+20 1.8037303e+28 1.7443726e+28] # uninitialized
# [4.9640171e+28 3.0186127e+32 5.6715899e-11]] # uninitialized
"""
if dtype is None:
dtype = x.dtype
dtype = convert_dtype(dtype)
if in_dygraph_mode():
out = core.ops.empty('shape', x.shape, 'dtype',
convert_np_dtype_to_dtype_(dtype))
out.stop_gradient = True
return out
helper = LayerHelper("empty_like", **locals())
check_variable_and_dtype(
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'empty_like')
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'empty_like')
out = helper.create_variable_for_type_inference(dtype=dtype)
inputs = {}
attrs = {}
attrs['dtype'] = convert_np_dtype_to_dtype_(dtype)
shape = paddle.shape(x)
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='empty_like')
helper.append_op(
type='empty',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs,
stop_gradient=True)
out.stop_gradient = True
return out
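Because the static-graph path feeds paddle.shape(x) into the empty op's ShapeTensor input, empty_like also works when some compile-time dimensions are unknown (-1), which is what TestEmptyLikeAPI_Static2 above exercises. A condensed sketch using only APIs that already appear in the test file:

    import numpy as np
    import paddle
    from paddle.static import program_guard, Program

    paddle.enable_static()
    main, startup = Program(), Program()
    with program_guard(main, startup):
        # First dim unknown at compile time; resolved from the fed array at runtime.
        data_x = paddle.static.data('x', shape=[-1, 200, 3], dtype='float32')
        out = paddle.empty_like(data_x)

    exe = paddle.static.Executor(paddle.CPUPlace())
    x = np.random.random((3, 200, 3)).astype('float32')
    res = exe.run(main, feed={'x': x}, fetch_list=[out])
    print(res[0].shape)  # (3, 200, 3)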