未验证 提交 5464b5c4 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports cast and relu6 (#55442)

上级 a771c1e1
...@@ -525,6 +525,9 @@ class TestUnaryOp(OpTest): ...@@ -525,6 +525,9 @@ class TestUnaryOp(OpTest):
create_unit_test(TestUnaryOp, "tanh", paddle.tanh, "builder.tanh") create_unit_test(TestUnaryOp, "tanh", paddle.tanh, "builder.tanh")
create_unit_test(TestUnaryOp, "relu", paddle.nn.functional.relu, "builder.relu") create_unit_test(TestUnaryOp, "relu", paddle.nn.functional.relu, "builder.relu")
create_unit_test(
TestUnaryOp, "relu6", paddle.nn.functional.relu6, "builder.relu6"
)
create_unit_test(TestUnaryOp, "gelu", paddle.nn.functional.gelu, "builder.gelu") create_unit_test(TestUnaryOp, "gelu", paddle.nn.functional.gelu, "builder.gelu")
create_unit_test( create_unit_test(
TestUnaryOp, "sigmoid", paddle.nn.functional.sigmoid, "builder.sigmoid" TestUnaryOp, "sigmoid", paddle.nn.functional.sigmoid, "builder.sigmoid"
...@@ -630,6 +633,44 @@ class TestScaleOp(OpTest): ...@@ -630,6 +633,44 @@ class TestScaleOp(OpTest):
self.check_outputs_and_grads() self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestCastOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = ()
def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.cast(x, 'int32')
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("cast_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.cast(x, "int32")
prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
@OpTestTool.skip_if( @OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册