未验证 提交 7468bab4 编写于 作者: Y Yichen Zhang 提交者: GitHub

Add full_like composite rule (#50794)

* implement composite full_like and simple unit test

* implement op tests for composite full_like op

* some modification as reviewers suggested
add cinn op test to CMakeLists.txt
fix code style

* fix code style

* modify input args of prim fill_any_like op

* resolve conflicts

* resolve conflicts

* modify python api and unit tests as suggested

* resolve conflicts

* resolve conflicts

* use framework.dtype to convert dtype in Op test
上级 add510b9
...@@ -147,6 +147,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -147,6 +147,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
'int64', 'int64',
'complex64', 'complex64',
'complex128', 'complex128',
'uint16',
], ],
'fill_constant', 'fill_constant',
) )
......
...@@ -1202,8 +1202,14 @@ if($ENV{USE_STANDALONE_EXECUTOR}) ...@@ -1202,8 +1202,14 @@ if($ENV{USE_STANDALONE_EXECUTOR})
PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
endif() endif()
set(TEST_CINN_OPS test_softmax_op test_expand_v2_op test_reduce_op set(TEST_CINN_OPS
test_slice_op test_activation_op) test_softmax_op
test_expand_v2_op
test_reduce_op
test_slice_op
test_activation_op
test_full_like_op
test_fill_any_like_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -19,24 +19,40 @@ from eager_op_test import OpTest, convert_float_to_uint16 ...@@ -19,24 +19,40 @@ from eager_op_test import OpTest, convert_float_to_uint16
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.framework.dtype as dtypes
def fill_any_like_wrapper(x, value, out_dtype=None, name=None):
if isinstance(out_dtype, int):
tmp_dtype = dtypes.dtype(out_dtype)
elif out_dtype == np.complex64:
raise ValueError("Not supported dtype %s" % out_dtype)
else:
tmp_dtype = out_dtype
return paddle.full_like(x, value, tmp_dtype, name)
class TestFillAnyLikeOp(OpTest): class TestFillAnyLikeOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fill_any_like" self.op_type = "fill_any_like"
self.python_api = paddle.full_like self.prim_op_type = "comp"
self.python_api = fill_any_like_wrapper
self.dtype = np.int32 self.dtype = np.int32
self.value = 0.0 self.value = 0.0
self.init() self.init()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)} self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.attrs = {'value': self.value} self.attrs = {'value': self.value}
self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])} self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
self.skip_cinn()
def init(self): def init(self):
pass pass
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_prim=True)
def skip_cinn(self):
pass
class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp): class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
...@@ -44,6 +60,9 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp): ...@@ -44,6 +60,9 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
self.dtype = np.float32 self.dtype = np.float32
self.value = 0.0 self.value = 0.0
def skip_cinn(self):
self.enable_cinn = False
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
...@@ -51,7 +70,8 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp): ...@@ -51,7 +70,8 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
class TestFillAnyLikeOpBfloat16(OpTest): class TestFillAnyLikeOpBfloat16(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fill_any_like" self.op_type = "fill_any_like"
self.python_api = paddle.full_like self.prim_op_type = "comp"
self.python_api = fill_any_like_wrapper
self.dtype = np.uint16 self.dtype = np.uint16
self.value = 0.0 self.value = 0.0
self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)} self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)}
...@@ -61,31 +81,45 @@ class TestFillAnyLikeOpBfloat16(OpTest): ...@@ -61,31 +81,45 @@ class TestFillAnyLikeOpBfloat16(OpTest):
self.value * np.ones_like(self.inputs["X"]) self.value * np.ones_like(self.inputs["X"])
) )
} }
self.skip_cinn()
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place) self.check_output_with_place(place, check_prim=True)
def skip_cinn(self):
self.enable_cinn = False
class TestFillAnyLikeOpValue1(TestFillAnyLikeOp): class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
def init(self): def init(self):
self.value = 1.0 self.value = 1.0
def skip_cinn(self):
self.enable_cinn = False
class TestFillAnyLikeOpValue2(TestFillAnyLikeOp): class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
def init(self): def init(self):
self.value = 1e-10 self.value = 1e-10
def skip_cinn(self):
self.enable_cinn = False
class TestFillAnyLikeOpValue3(TestFillAnyLikeOp): class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
def init(self): def init(self):
self.value = 1e-100 self.value = 1e-100
def skip_cinn(self):
self.enable_cinn = False
class TestFillAnyLikeOpType(TestFillAnyLikeOp): class TestFillAnyLikeOpType(TestFillAnyLikeOp):
def setUp(self): def setUp(self):
self.op_type = "fill_any_like" self.op_type = "fill_any_like"
self.python_api = paddle.full_like self.prim_op_type = "comp"
self.python_api = fill_any_like_wrapper
self.dtype = np.int32 self.dtype = np.int32
self.value = 0.0 self.value = 0.0
self.init() self.init()
...@@ -99,11 +133,19 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp): ...@@ -99,11 +133,19 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp):
* np.ones_like(self.inputs["X"]).astype(np.float32) * np.ones_like(self.inputs["X"]).astype(np.float32)
} }
self.skip_cinn()
def skip_cinn(self):
self.enable_cinn = False
class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp): class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
def init(self): def init(self):
self.dtype = np.float16 self.dtype = np.float16
def skip_cinn(self):
self.enable_cinn = False
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
......
...@@ -19,10 +19,21 @@ from op_test import OpTest ...@@ -19,10 +19,21 @@ from op_test import OpTest
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.framework.dtype as dtypes
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.static import Program, program_guard from paddle.static import Program, program_guard
def fill_any_like_wrapper(x, value, out_dtype=None, name=None):
if isinstance(out_dtype, int):
tmp_dtype = dtypes.dtype(out_dtype)
elif out_dtype == np.complex64:
raise ValueError("Not supported dtype %s" % out_dtype)
else:
tmp_dtype = out_dtype
return paddle.full_like(x, value, tmp_dtype, name)
class TestFullOp(unittest.TestCase): class TestFullOp(unittest.TestCase):
"""Test fill_any_like op(whose API is full_like) for attr out.""" """Test fill_any_like op(whose API is full_like) for attr out."""
...@@ -100,8 +111,10 @@ class TestFullLikeOp1(OpTest): ...@@ -100,8 +111,10 @@ class TestFullLikeOp1(OpTest):
# test basic # test basic
def setUp(self): def setUp(self):
self.op_type = "fill_any_like" self.op_type = "fill_any_like"
self.python_api = paddle.full_like self.prim_op_type = "comp"
self.python_api = fill_any_like_wrapper
self.init_data() self.init_data()
self.skip_cinn()
x = np.zeros(self.shape) x = np.zeros(self.shape)
out = np.full_like(x, self.fill_value, self.dtype) out = np.full_like(x, self.fill_value, self.dtype)
...@@ -119,7 +132,10 @@ class TestFullLikeOp1(OpTest): ...@@ -119,7 +132,10 @@ class TestFullLikeOp1(OpTest):
self.dtype = np.float32 self.dtype = np.float32
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True, check_prim=True)
def skip_cinn(self):
pass
class TestFullLikeOp2(TestFullLikeOp1): class TestFullLikeOp2(TestFullLikeOp1):
...@@ -128,6 +144,9 @@ class TestFullLikeOp2(TestFullLikeOp1): ...@@ -128,6 +144,9 @@ class TestFullLikeOp2(TestFullLikeOp1):
self.shape = [1024, 1024] self.shape = [1024, 1024]
self.dtype = np.float64 self.dtype = np.float64
def skip_cinn(self):
self.enable_cinn = False
class TestFullLikeOp3(TestFullLikeOp1): class TestFullLikeOp3(TestFullLikeOp1):
def init_data(self): def init_data(self):
...@@ -135,6 +154,9 @@ class TestFullLikeOp3(TestFullLikeOp1): ...@@ -135,6 +154,9 @@ class TestFullLikeOp3(TestFullLikeOp1):
self.shape = [5000, 5000] self.shape = [5000, 5000]
self.dtype = np.int64 self.dtype = np.int64
def skip_cinn(self):
self.enable_cinn = False
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
import functools import functools
import operator import operator
import paddle.framework.dtype as dtypes
from paddle.fluid import core from paddle.fluid import core
from .primitives import * # noqa: F403 from .primitives import * # noqa: F403
...@@ -269,3 +270,13 @@ def silu_composite(x): ...@@ -269,3 +270,13 @@ def silu_composite(x):
sum_temp = 1 + exp(-x) sum_temp = 1 + exp(-x)
res = x / sum_temp res = x / sum_temp
return res return res
@REGISTER_COMPOSITE('fill_any_like')
def fill_any_like(x, fill_value, dtype, place=None):
"""define composite rule of op full_like."""
"""op name: full_like op type name: fill_any_like."""
"""arg place is not used, add it here to keep same as python api."""
dtype = dtypes.dtype(dtype)
val = full(x.shape, fill_value, dtype)
return val
...@@ -827,6 +827,7 @@ def full_like(x, fill_value, dtype=None, name=None): ...@@ -827,6 +827,7 @@ def full_like(x, fill_value, dtype=None, name=None):
'int16', 'int16',
'int32', 'int32',
'int64', 'int64',
'uint16',
], ],
'full_like', 'full_like',
) )
...@@ -841,6 +842,7 @@ def full_like(x, fill_value, dtype=None, name=None): ...@@ -841,6 +842,7 @@ def full_like(x, fill_value, dtype=None, name=None):
'int16', 'int16',
'int32', 'int32',
'int64', 'int64',
'uint16',
], ],
'full_like/zeros_like/ones_like', 'full_like/zeros_like/ones_like',
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册