Unverified · Commit 7468bab4, authored by Yichen Zhang, committed by GitHub

Add full_like composite rule (#50794)

* implement composite full_like and simple unit test

* implement op tests for composite full_like op

* some modifications as reviewers suggested:
  add cinn op test to CMakeLists.txt
  fix code style

* fix code style

* modify input args of prim fill_any_like op

* resolve conflicts

* resolve conflicts

* modify python api and unit tests as suggested

* resolve conflicts

* resolve conflicts

* use framework.dtype to convert dtype in Op test
Parent add510b9
@@ -147,6 +147,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
             'int64',
             'complex64',
             'complex128',
+            'uint16',
         ],
         'fill_constant',
     )
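For context: in Paddle op tests, bfloat16 data is stored bitwise in `uint16` arrays, which is why the dtype whitelist gains a `'uint16'` entry. Below is a minimal sketch of the float32-to-bfloat16 bit conversion that test helpers such as `convert_float_to_uint16` perform; `float32_to_bfloat16_bits` is a hypothetical re-implementation for illustration (truncation only, no rounding), not the library's actual helper.

```python
import numpy as np

def float32_to_bfloat16_bits(x):
    # bfloat16 keeps the upper 16 bits of an IEEE float32, so an op test
    # can store it in a plain uint16 array (truncating, without rounding).
    return np.right_shift(x.astype(np.float32).view(np.uint32), 16).astype(np.uint16)

x = np.random.random((2, 3)).astype(np.float32)
print(float32_to_bfloat16_bits(x).dtype)  # uint16
```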
@@ -1202,8 +1202,14 @@ if($ENV{USE_STANDALONE_EXECUTOR})
                        PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
 endif()
 
-set(TEST_CINN_OPS test_softmax_op test_expand_v2_op test_reduce_op
-    test_slice_op test_activation_op)
+set(TEST_CINN_OPS
+    test_softmax_op
+    test_expand_v2_op
+    test_reduce_op
+    test_slice_op
+    test_activation_op
+    test_full_like_op
+    test_fill_any_like_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
@@ -19,24 +19,40 @@ from eager_op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid.core as core
+import paddle.framework.dtype as dtypes
+
+
+def fill_any_like_wrapper(x, value, out_dtype=None, name=None):
+    if isinstance(out_dtype, int):
+        tmp_dtype = dtypes.dtype(out_dtype)
+    elif out_dtype == np.complex64:
+        raise ValueError("Not supported dtype %s" % out_dtype)
+    else:
+        tmp_dtype = out_dtype
+    return paddle.full_like(x, value, tmp_dtype, name)
+
+
 class TestFillAnyLikeOp(OpTest):
     def setUp(self):
         self.op_type = "fill_any_like"
-        self.python_api = paddle.full_like
+        self.prim_op_type = "comp"
+        self.python_api = fill_any_like_wrapper
         self.dtype = np.int32
         self.value = 0.0
         self.init()
         self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
         self.attrs = {'value': self.value}
         self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
+        self.skip_cinn()
 
     def init(self):
         pass
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
+
+    def skip_cinn(self):
+        pass
 
 
 class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
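A quick usage sketch of the wrapper above, with names assumed in scope; the integer branch mirrors how the C++ op attribute delivers the dtype as a framework enum, and the value 5 is assumed to denote FP32 in Paddle's VarType proto (worth verifying against framework.proto):

```python
import numpy as np
import paddle

x = paddle.ones([2, 3], dtype='float32')

# A NumPy dtype falls through to the final `else` branch unchanged.
y = fill_any_like_wrapper(x, 3.0, np.float32)

# An integer enum is first converted via paddle.framework.dtype.dtype(...);
# 5 is assumed to map to FP32 here.
z = fill_any_like_wrapper(x, 3.0, 5)
```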
@@ -44,6 +60,9 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
         self.dtype = np.float32
         self.value = 0.0
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
@@ -51,7 +70,8 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
 class TestFillAnyLikeOpBfloat16(OpTest):
     def setUp(self):
         self.op_type = "fill_any_like"
-        self.python_api = paddle.full_like
+        self.prim_op_type = "comp"
+        self.python_api = fill_any_like_wrapper
         self.dtype = np.uint16
         self.value = 0.0
         self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)}
@@ -61,31 +81,45 @@ class TestFillAnyLikeOpBfloat16(OpTest):
                 self.value * np.ones_like(self.inputs["X"])
             )
         }
+        self.skip_cinn()
 
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(place, check_prim=True)
+
+    def skip_cinn(self):
+        self.enable_cinn = False
 
 
 class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
     def init(self):
         self.value = 1.0
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
     def init(self):
         self.value = 1e-10
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
     def init(self):
         self.value = 1e-100
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 class TestFillAnyLikeOpType(TestFillAnyLikeOp):
     def setUp(self):
         self.op_type = "fill_any_like"
-        self.python_api = paddle.full_like
+        self.prim_op_type = "comp"
+        self.python_api = fill_any_like_wrapper
         self.dtype = np.int32
         self.value = 0.0
         self.init()
@@ -99,11 +133,19 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp):
             * np.ones_like(self.inputs["X"]).astype(np.float32)
         }
+        self.skip_cinn()
+
+    def skip_cinn(self):
+        self.enable_cinn = False
 
 
 class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
     def init(self):
         self.dtype = np.float16
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 if __name__ == "__main__":
     paddle.enable_static()
@@ -19,10 +19,21 @@ from op_test import OpTest
 import paddle
 import paddle.fluid.core as core
+import paddle.framework.dtype as dtypes
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
 from paddle.static import Program, program_guard
 
 
+def fill_any_like_wrapper(x, value, out_dtype=None, name=None):
+    if isinstance(out_dtype, int):
+        tmp_dtype = dtypes.dtype(out_dtype)
+    elif out_dtype == np.complex64:
+        raise ValueError("Not supported dtype %s" % out_dtype)
+    else:
+        tmp_dtype = out_dtype
+    return paddle.full_like(x, value, tmp_dtype, name)
+
+
 class TestFullOp(unittest.TestCase):
     """Test fill_any_like op (whose API is full_like) for attr out."""
@@ -100,8 +111,10 @@ class TestFullLikeOp1(OpTest):
     # test basic
     def setUp(self):
         self.op_type = "fill_any_like"
-        self.python_api = paddle.full_like
+        self.prim_op_type = "comp"
+        self.python_api = fill_any_like_wrapper
         self.init_data()
+        self.skip_cinn()
 
         x = np.zeros(self.shape)
         out = np.full_like(x, self.fill_value, self.dtype)
@@ -119,7 +132,10 @@ class TestFullLikeOp1(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output(check_eager=True)
+        self.check_output(check_eager=True, check_prim=True)
+
+    def skip_cinn(self):
+        pass
 
 
 class TestFullLikeOp2(TestFullLikeOp1):
@@ -128,6 +144,9 @@ class TestFullLikeOp2(TestFullLikeOp1):
         self.shape = [1024, 1024]
         self.dtype = np.float64
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 class TestFullLikeOp3(TestFullLikeOp1):
     def init_data(self):
@@ -135,6 +154,9 @@ class TestFullLikeOp3(TestFullLikeOp1):
         self.shape = [5000, 5000]
         self.dtype = np.int64
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
@@ -20,6 +20,7 @@
 import functools
 import operator
 
+import paddle.framework.dtype as dtypes
 from paddle.fluid import core
 
 from .primitives import *  # noqa: F403
@@ -269,3 +270,13 @@ def silu_composite(x):
     sum_temp = 1 + exp(-x)
     res = x / sum_temp
     return res
+
+
+@REGISTER_COMPOSITE('fill_any_like')
+def fill_any_like(x, fill_value, dtype, place=None):
+    """Define composite rule of op full_like (op type name: fill_any_like).
+
+    The arg `place` is unused; it is kept to match the Python API."""
+    dtype = dtypes.dtype(dtype)
+    val = full(x.shape, fill_value, dtype)
+    return val
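The rule simply lowers `fill_any_like` to a `full` over `x.shape`. A quick dynamic-graph sanity check of that identity (a sketch for illustration, not part of the patch):

```python
import paddle

x = paddle.rand([2, 3])

# full_like(x, v) should equal full(x.shape, v) for a fixed dtype,
# which is exactly what the composite rule above emits.
a = paddle.full_like(x, 1.5, dtype='float32')
b = paddle.full(x.shape, 1.5, dtype='float32')
assert bool((a == b).all())
```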
@@ -827,6 +827,7 @@ def full_like(x, fill_value, dtype=None, name=None):
             'int16',
             'int32',
             'int64',
+            'uint16',
         ],
         'full_like',
     )
@@ -841,6 +842,7 @@ def full_like(x, fill_value, dtype=None, name=None):
             'int16',
             'int32',
             'int64',
+            'uint16',
         ],
         'full_like/zeros_like/ones_like',
     )
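With `'uint16'` (Paddle's storage dtype for bfloat16) added to both whitelists, a static-graph `full_like` with that dtype should now pass the check. A minimal sketch, assuming a build where `'uint16'` converts to the bfloat16 VarType:

```python
import paddle
from paddle.static import Program, program_guard

paddle.enable_static()
with program_guard(Program()):
    x = paddle.static.data('x', [2, 3], dtype='float32')
    # 'uint16' is accepted by the extended check_variable_and_dtype list.
    y = paddle.full_like(x, 0.0, dtype='uint16')
```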