Unverified commit 915341e3 authored by wawltor, committed by GitHub

Add the zeros, ones, ones_like, zeros_like for api 2.0, test=develop (#23471)

Update the new api ops of creation ops to the api 2.0
Parent 56b50c97
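A minimal, hedged usage sketch of the new 2.0 creation ops (not part of the diff; it simply mirrors the static-graph unit tests added further down, and assumes the signatures introduced in this change: `paddle.ones(shape, dtype=None, out=None, device=None)`, `paddle.ones_like(input, dtype=None, device=None, name=None)`, and the matching `zeros`/`zeros_like`):

```python
# Illustrative sketch only; mirrors the unit tests added in this commit.
import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program()):
    data = fluid.data(shape=[10], dtype="float64", name="data")
    ones = paddle.ones(shape=[10], dtype="float64")        # new tensor filled with 1
    zeros_out = paddle.zeros_like(data, device="cpu")      # same shape/dtype as data, filled with 0
    exe = fluid.Executor(fluid.CPUPlace())
    ones_np, zeros_np = exe.run(feed={"data": np.random.rand(10)},
                                fetch_list=[ones, zeros_out])
    print(ones_np, zeros_np)
```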
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_like_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -29,6 +30,25 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
    const auto &data_type = ctx.Attr<int>("dtype");
    if (data_type >= 0) {
      kt.data_type_ = static_cast<framework::proto::VarType::Type>(data_type);
    }
    return kt;
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   expected_kernel_type.place_,
                                   tensor.layout());
  }
};
class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -37,6 +57,10 @@ class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with specified value.");
AddAttr<float>("value", "The filled value").SetDefault(0.0);
AddAttr<int>("dtype",
"Output tensor data type. defalut value is -1,"
"according to the input dtype.")
.SetDefault(-1);
AddComment(R"DOC(
FillAnyLike Operator.
......@@ -47,18 +71,37 @@ The output will have the same shape and dtype as the input.
}
};
class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto out_var_name = ctx->Output("Out").front();
    auto var_data_type = static_cast<framework::proto::VarType::Type>(
        boost::get<int>(ctx->GetAttr("dtype")));
    if (var_data_type < 0) {
      const auto &input_var_name = ctx->Input("X").front();
      ctx->SetDataType(out_var_name, ctx->GetDataType(input_var_name));
    } else {
      ctx->SetDataType(out_var_name, var_data_type);
    }
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_any_like, ops::FillAnyLikeOp,
ops::FillAnyLikeOpMaker);
REGISTER_OPERATOR(
fill_any_like, ops::FillAnyLikeOp, ops::FillAnyLikeOpMaker,
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillAnyLikeVarTypeInference)
REGISTER_OP_CPU_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, bool>);
......@@ -22,6 +22,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, bool>);
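As context for the new `dtype` attribute above (a hedged sketch, not part of the diff): with the default of -1, `FillAnyLikeVarTypeInference` keeps the input's data type, while a non-negative value overrides both the inferred var type and the kernel data type selected in `GetExpectedKernelType`. From the Python API added in this change, that corresponds to the optional `dtype` argument of `ones_like`/`zeros_like`:

```python
# Sketch only: exercising the new "dtype" attribute of fill_any_like via ones_like.
import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program()):
    x = fluid.data(name="x", shape=[4], dtype="float64")
    y = paddle.ones_like(x)                 # dtype attr stays -1 -> output keeps float64
    z = paddle.ones_like(x, dtype="int32")  # dtype attr set     -> output becomes int32
    exe = fluid.Executor(fluid.CPUPlace())
    y_np, z_np = exe.run(feed={"x": np.zeros(4, dtype="float64")},
                         fetch_list=[y, z])
    print(y_np.dtype, z_np.dtype)  # float64 int32
```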
......@@ -42,14 +42,14 @@ import paddle.nn
# from .tensor.creation import crop_.tensor #DEFINE_ALIAS
# from .tensor.creation import diag #DEFINE_ALIAS
# from .tensor.creation import eye #DEFINE_ALIAS
# from .tensor.creation import fill_constant #DEFINE_ALIAS
from .tensor.creation import fill_constant #DEFINE_ALIAS
# from .tensor.creation import get_.tensor_from_selected_rows #DEFINE_ALIAS
from .tensor.creation import linspace #DEFINE_ALIAS
# from .tensor.creation import ones #DEFINE_ALIAS
# from .tensor.creation import ones_like #DEFINE_ALIAS
from .tensor.creation import ones #DEFINE_ALIAS
from .tensor.creation import ones_like #DEFINE_ALIAS
# from .tensor.creation import range #DEFINE_ALIAS
# from .tensor.creation import zeros #DEFINE_ALIAS
# from .tensor.creation import zeros_like #DEFINE_ALIAS
from .tensor.creation import zeros #DEFINE_ALIAS
from .tensor.creation import zeros_like #DEFINE_ALIAS
# from .tensor.creation import arrange #DEFINE_ALIAS
# from .tensor.creation import eye #DEFINE_ALIAS
from .tensor.creation import full #DEFINE_ALIAS
......
......@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from six.moves import reduce
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
from paddle.fluid.framework import Variable, device_guard
from paddle.fluid.framework import device_guard, default_main_program, dygraph_only, _dygraph_tracer
from paddle.fluid.framework import OpProtoHolder, Variable
from paddle.fluid.initializer import Constant
from paddle.fluid.core import VarDesc
from paddle.fluid import core
from paddle.fluid.data_feeder import check_type, check_dtype, convert_dtype
from paddle.fluid.layers import utils
from paddle.fluid.layers import fill_constant
from paddle.fluid import core, dygraph_utils
from paddle.fluid.data_feeder import check_type, check_dtype, check_variable_and_dtype, convert_dtype
from paddle.fluid.layers import fill_constant, utils, scale
from paddle.fluid.layers.layer_function_generator import templatedoc
import paddle.fluid as fluid
import numpy
import warnings
......@@ -14,6 +14,8 @@
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.compat as cpt
import unittest
......@@ -59,6 +61,23 @@ class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
        self.value = 1e-100


class TestFillAnyLikeOpType(TestFillAnyLikeOp):
    def setUp(self):
        self.op_type = "fill_any_like"
        self.dtype = np.int32
        self.value = 0.0
        self.init()
        self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
        self.attrs = {
            'value': self.value,
            'dtype': int(core.VarDesc.VarType.FP32)
        }
        self.outputs = {
            'Out':
            self.value * np.ones_like(self.inputs["X"]).astype(np.float32)
        }


class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp):
    def init(self):
        self.value = 1e100
......@@ -77,5 +96,102 @@ class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
        self.dtype = np.float16


class ApiOnesLikeTest(unittest.TestCase):
    def test_out(self):
        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float64", name="data")
            ones = paddle.ones_like(data, device="cpu")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(feed={"data": np.random.rand(10)},
                              fetch_list=[ones])
            expected_result = np.ones(10, dtype="float64")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float64", name="data")
            ones = paddle.ones_like(data, device="cpu", dtype="float32")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(feed={"data": np.random.rand(10)},
                              fetch_list=[ones])
            expected_result = np.ones(10, dtype="float32")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float64", name="data")
            ones = paddle.ones_like(data)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(feed={"data": np.random.rand(10)},
                              fetch_list=[ones])
            expected_result = np.ones(10, dtype="float32")
            self.assertEqual((result == expected_result).all(), True)


class ApiZerosLikeTest(unittest.TestCase):
    def test_out(self):
        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float64", name="data")
            zeros = paddle.zeros_like(data, device="cpu")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(feed={"data": np.random.rand(10)},
                              fetch_list=[zeros])
            expected_result = np.zeros(10, dtype="float64")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float64", name="data")
            zeros = paddle.zeros_like(data, device="cpu", dtype="float32")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(feed={"data": np.random.rand(10)},
                              fetch_list=[zeros])
            expected_result = np.zeros(10, dtype="float32")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float64", name="data")
            zeros = paddle.zeros_like(data)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(feed={"data": np.random.rand(10)},
                              fetch_list=[zeros])
            expected_result = np.zeros(10, dtype="float32")
            self.assertEqual((result == expected_result).all(), True)


class TestOnesZerosError(unittest.TestCase):
    def test_errors(self):
        def test_device_error1():
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                data = fluid.data(name="data", shape=[10], dtype="float32")
                paddle.ones_like(data, device="opu")

        self.assertRaises(ValueError, test_device_error1)

        def test_device_error2():
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                data = fluid.data(name="data", shape=[10], dtype="float32")
                paddle.ones_like(data, dtype="float")

        self.assertRaises(ValueError, test_device_error2)

        def test_device_error3():
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                data = fluid.data(name="data", shape=[10], dtype="float32")
                paddle.zeros_like(data, device="opu")

        self.assertRaises(ValueError, test_device_error3)

        def test_device_error4():
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                data = fluid.data(name="data", shape=[10], dtype="float32")
                paddle.zeros_like(data, dtype="float")

        self.assertRaises(ValueError, test_device_error4)


if __name__ == "__main__":
    unittest.main()
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
......@@ -81,6 +82,28 @@ class TestFillConstantOp4(OpTest):
        self.check_output()


class TestFillConstantOp5(unittest.TestCase):
    def test_errors(self):
        with fluid.program_guard(fluid.Program()):
            data = fluid.data(name="X", shape=[1], dtype="float32")
            out = paddle.zeros(shape=[1], out=data, dtype="float32")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result = exe.run(feed={"X": np.array([0.1], dtype="float32")},
                             fetch_list=[data, out])
            self.assertEqual(result[0], result[1])

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(name="X", shape=[1], dtype="float32")
            out = paddle.ones(shape=[1], out=data, dtype="float32")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result = exe.run(feed={"X": np.array([0.1], dtype="float32")},
                             fetch_list=[data, out])
            self.assertEqual(result[0], result[1])
class TestFillConstantOpWithSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
......@@ -303,5 +326,74 @@ class TestFillConstantOpError(unittest.TestCase):
        self.assertRaises(TypeError, test_shape_tensor_list_dtype)


class ApiZerosTest(unittest.TestCase):
    def test_out(self):
        with fluid.program_guard(fluid.Program()):
            zeros = paddle.zeros(shape=[10], dtype="float64")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(fetch_list=[zeros])
            expected_result = np.zeros(10, dtype="float64")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            zeros = paddle.zeros(shape=[10], dtype="int64")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(fetch_list=[zeros])
            expected_result = np.zeros(10, dtype="int64")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            zeros = paddle.zeros(shape=[10], dtype="int64", device="cpu")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(fetch_list=[zeros])
            expected_result = np.zeros(10, dtype="int64")
            self.assertEqual((result == expected_result).all(), True)


class ApiOnesTest(unittest.TestCase):
    def test_out(self):
        with fluid.program_guard(fluid.Program()):
            ones = paddle.ones(shape=[10], dtype="float64")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(fetch_list=[ones])
            expected_result = np.ones(10, dtype="float64")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            ones = paddle.ones(shape=[10], dtype="int64")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(fetch_list=[ones])
            expected_result = np.ones(10, dtype="int64")
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            ones = paddle.ones(shape=[10], dtype="int64", device="cpu")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            result, = exe.run(fetch_list=[ones])
            expected_result = np.ones(10, dtype="int64")
            self.assertEqual((result == expected_result).all(), True)


class ApiOnesZerosError(unittest.TestCase):
    def test_errors(self):
        def test_error1():
            with fluid.program_guard(fluid.Program()):
                ones = paddle.ones(shape=10, dtype="int64", device="opu")

        self.assertRaises(ValueError, test_error1)

        def test_error2():
            with fluid.program_guard(fluid.Program()):
                ones = paddle.ones(shape=10, dtype="int64", device="opu")

        self.assertRaises(ValueError, test_error2)


if __name__ == "__main__":
    unittest.main()
......@@ -25,11 +25,11 @@ __all__ = [
# 'fill_constant',
# 'get_tensor_from_selected_rows',
'linspace',
# 'ones',
# 'ones_like',
'ones',
'ones_like',
# 'range',
# 'zeros',
# 'zeros_like',
'zeros',
'zeros_like',
# 'arrange',
# 'eye',
'full',
......@@ -126,6 +126,218 @@ def linspace(start, stop, num, dtype, out=None, device=None, name=None):
    return out


def ones(shape, dtype=None, out=None, device=None):
    """
    This OP creates a tensor of the specified :attr:`shape` and :attr:`dtype`, and fills it with 1.

    Args:
        shape(tuple|list): Shape of the output tensor.
        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor; it supports
            bool, float16, float32, float64, int32 and int64.
        out(Variable, optional): Optional output which can be any created
            Variable that meets the requirements to store the result of the operation.
            If out is None, a new Variable will be created to store the result.
        device(str, optional): Which device to run the operator on. The :attr:`device` must be
            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
            program will be used. Default value is None.

    Returns:
        Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.

    Examples:
        .. code-block:: python

          import paddle
          data = paddle.ones(shape=[3, 2], dtype='float32')  # [[1., 1.], [1., 1.], [1., 1.]]
          data = paddle.ones(shape=[2, 2], dtype='float32', device='cpu')  # [[1., 1.], [1., 1.]]
    """
    check_dtype(dtype, 'create data type',
                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
                'ones')
    if device is not None:
        if device not in ['cpu', 'gpu']:
            raise ValueError(
                "The value of 'device' in ones_op must be cpu or gpu, but received %s."
                % (device))
        with fluid.device_guard(device):
            return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
    return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
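A short, hedged note on the `out` and `device` arguments (illustrative only; it mirrors `TestFillConstantOp5` in the test diff above): when `out` is supplied, `ones`/`zeros` write into that variable and return it, and `device` simply wraps the underlying `fill_constant` call in `fluid.device_guard`:

```python
# Sketch only: reusing an existing variable as the output of paddle.ones.
import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program()):
    data = fluid.data(name="X", shape=[1], dtype="float32")
    out = paddle.ones(shape=[1], out=data, dtype="float32", device="cpu")
    exe = fluid.Executor(fluid.CPUPlace())
    data_np, out_np = exe.run(feed={"X": np.array([0.1], dtype="float32")},
                              fetch_list=[data, out])
    print(data_np == out_np)  # both fetch the value written by ones
```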
def ones_like(input, dtype=None, device=None, name=None):
    """
    This function creates a ones tensor which has the same shape and dtype
    as `input`.

    Args:
        input(Variable): The input tensor which specifies shape and dtype. The dtype of input can be
            float32, float64, int32, int64.
        dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type can be set bool, float32, float64, int32, int64.
            The default value is None, which means the dtype is the same as input.
        device(str, optional): Which device to run the operator on. The :attr:`device` must be
            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
            program will be used. Default value is None.
        name(str, optional): The name of the output variable; normally there is no need for the user to set
            this property. Default value is None, in which case the framework sets the name of the output variable.
    Returns:
        out(Variable): The tensor variable storing the output.

    Examples:
        .. code-block:: python

          import paddle
          import paddle.fluid as fluid

          x = fluid.layers.data(name='x', dtype='float32', shape=[3], append_batch_size=False)
          data = paddle.ones_like(x)  # data=[1.0, 1.0, 1.0]
          data1 = paddle.ones_like(input=x, device="gpu")  # data1=[1.0, 1.0, 1.0]

    """

    helper = LayerHelper("ones_like", **locals())
    attrs = {"value": 1.0}
    var_dtype = None
    if dtype is not None:
        check_dtype(
            dtype, 'create data type',
            ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
            'ones_like')
        var_dtype = convert_np_dtype_to_dtype_(dtype)
        attrs["dtype"] = var_dtype
    else:
        var_dtype = input.dtype
    out = helper.create_variable_for_type_inference(dtype=var_dtype)

    if device is not None:
        if device not in ['cpu', 'gpu']:
            raise ValueError(
                "The value of 'device' in ones_like must be cpu or gpu, but received %s."
                % (device))
        with fluid.device_guard(device):
            helper.append_op(
                type='fill_any_like',
                inputs={'X': [input]},
                attrs=attrs,
                outputs={'Out': [out]})
            return out
    helper.append_op(
        type='fill_any_like',
        inputs={'X': [input]},
        attrs=attrs,
        outputs={'Out': [out]})
    out.stop_gradient = True
    return out


def zeros(shape, dtype, out=None, device=None):
    """
    This OP creates a tensor of the specified :attr:`shape` and :attr:`dtype`, and fills it with 0.

    Args:
        shape(tuple|list): Shape of the output tensor.
        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor; it supports
            bool, float16, float32, float64, int32 and int64.
        out(Variable, optional): Optional output which can be any created
            Variable that meets the requirements to store the result of the operation.
            If out is None, a new Variable will be created to store the result.
        device(str, optional): Which device to run the operator on. The :attr:`device` must be
            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
            program will be used. Default value is None.

    Returns:
        Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.

    Examples:
        .. code-block:: python

          import paddle
          data = paddle.zeros(shape=[3, 2], dtype='float32')  # [[0., 0.], [0., 0.], [0., 0.]]
          data = paddle.zeros(shape=[2, 2], dtype='float32', device='cpu')  # [[0., 0.], [0., 0.]]
    """
    check_dtype(dtype, 'create data type',
                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
                'zeros')
    if device is not None:
        if device not in ['cpu', 'gpu']:
            raise ValueError(
                "The value of 'device' in zeros_op must be cpu or gpu, but received %s."
                % (device))
        with fluid.device_guard(device):
            return fill_constant(value=0.0, shape=shape, dtype=dtype, out=out)
    return fill_constant(value=0.0, shape=shape, dtype=dtype, out=out)


def zeros_like(input, dtype=None, device=None, name=None):
    """
    This function creates a zeros tensor which has the same shape and dtype
    as `input`.

    Args:
        input(Variable): The input tensor which specifies shape and dtype. The dtype of input can be
            bool, float32, float64, int32, int64.
        dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type can be set bool, float32, float64, int32, int64.
            The default value is None, which means the dtype is the same as input.
        device(str, optional): Which device to run the operator on. The :attr:`device` must be
            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
            program will be used. Default value is None.
        name(str, optional): The name of the output variable; normally there is no need for the user to set
            this property. Default value is None, in which case the framework sets the name of the output variable.
    Returns:
        out(Variable): The tensor variable storing the output.

    Examples:
        .. code-block:: python

          import paddle
          import paddle.fluid as fluid

          x = fluid.layers.data(name='x', dtype='float32', shape=[3], append_batch_size=False)
          data = paddle.zeros_like(x)  # data=[0.0, 0.0, 0.0]
          data1 = paddle.zeros_like(input=x, device="gpu")  # data1=[0.0, 0.0, 0.0]

    """

    helper = LayerHelper("zeros_like", **locals())
    attrs = {"value": 0.0}
    var_dtype = None
    if dtype is not None:
        check_dtype(dtype, 'create data type',
                    ['bool', 'float32', 'float64', 'int32', 'int64'],
                    'zeros_like')
        var_dtype = convert_np_dtype_to_dtype_(dtype)
        attrs["dtype"] = var_dtype
    else:
        var_dtype = input.dtype
    out = helper.create_variable_for_type_inference(dtype=var_dtype)

    if device is not None:
        if device not in ['cpu', 'gpu']:
            raise ValueError(
                "The value of 'device' in zeros_like must be cpu or gpu, but received %s."
                % (device))
        with fluid.device_guard(device):
            helper.append_op(
                type='fill_any_like',
                inputs={'X': [input]},
                attrs=attrs,
                outputs={'Out': [out]})
            return out
    helper.append_op(
        type='fill_any_like',
        inputs={'X': [input]},
        attrs=attrs,
        outputs={'Out': [out]})
    out.stop_gradient = True
    return out
def full(shape,
fill_value,
out=None,
......