提交 48f90eb7 编写于 作者: D dengwentao

add custom op st to ci

上级 9bda080b
...@@ -24,7 +24,7 @@ class CusSquare(PrimitiveWithInfer): ...@@ -24,7 +24,7 @@ class CusSquare(PrimitiveWithInfer):
def __init__(self): def __init__(self):
"""init CusSquare""" """init CusSquare"""
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
from .square_impl import CusSquare from square_impl import CusSquare
def vm_impl(self, x): def vm_impl(self, x):
x = x.asnumpy() x = x.asnumpy()
......
...@@ -16,7 +16,7 @@ import numpy as np ...@@ -16,7 +16,7 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from .cus_square import CusSquare from cus_square import CusSquare
import pytest import pytest
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
...@@ -32,6 +32,7 @@ class Net(nn.Cell): ...@@ -32,6 +32,7 @@ class Net(nn.Cell):
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_net(): def test_net():
x = np.array([1.0, 4.0, 9.0]).astype(np.float32) x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册