未验证 提交 98100fd2 编写于 作者: C co63oc 提交者: GitHub

Add fill_constant_batch_size_like tests (#53736)

上级 51ecd933
......@@ -38,11 +38,12 @@ def fill_constant_batch_size_like(
)
class TestFillConstatnBatchSizeLike1(OpTest):
class TestFillConstantBatchSizeLike1(OpTest):
# test basic
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.python_api = fill_constant_batch_size_like
self.init_dtype()
self.init_data()
input = np.zeros(self.shape)
......@@ -59,9 +60,11 @@ class TestFillConstatnBatchSizeLike1(OpTest):
'force_cpu': self.force_cpu,
}
def init_dtype(self):
self.dtype = np.float32
def init_data(self):
self.shape = [10, 10]
self.dtype = np.float32
self.value = 100
self.input_dim_idx = 0
self.output_dim_idx = 0
......@@ -71,11 +74,16 @@ class TestFillConstatnBatchSizeLike1(OpTest):
self.check_output()
class TestFillConstantBatchSizeLikeFP16Op(TestFillConstantBatchSizeLike1):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda() or not core.supports_bfloat16(),
"core is not compiled with CUDA or place do not support bfloat16",
)
class TestFillConstatnBatchSizeLikeBf16(OpTest):
class TestFillConstantBatchSizeLikeBF16Op(OpTest):
# test bf16
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册