提交 857c0b52 编写于 作者: G gongchen

fix custom op tutorial.

上级 6da853cf
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
- 输入输出的名称通过`init_prim_io_names()`函数定义。 - 输入输出的名称通过`init_prim_io_names()`函数定义。
- 输出Tensor的shape推理方法在`infer_shape()`函数中定义,输出Tensor的dtype推理方法在`infer_dtype()`函数中定义。 - 输出Tensor的shape推理方法在`infer_shape()`函数中定义,输出Tensor的dtype推理方法在`infer_dtype()`函数中定义。
自定义算子与内置算子的唯一区别是需要通过在`__init__()`函数中导入算子实现函数(`from .square_impl import CusSquareImpl`)来将算子实现注册到后端。本用例在`square_impl.py`中定义了算子实现和算子信息,将在后文中说明。 自定义算子与内置算子的唯一区别是需要通过在`__init__()`函数中导入算子实现函数(`from square_impl import CusSquareImpl`)来将算子实现注册到后端。本用例在`square_impl.py`中定义了算子实现和算子信息,将在后文中说明。
以Square算子原语`cus_square.py`为例,给出如下示例代码。 以Square算子原语`cus_square.py`为例,给出如下示例代码。
...@@ -53,7 +53,7 @@ class CusSquare(PrimitiveWithInfer): ...@@ -53,7 +53,7 @@ class CusSquare(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
from .square_impl import CusSquareImpl # Import the entry function of the kernel implementation from relative path or PYTHONPATH. from square_impl import CusSquareImpl # Import the entry function of the kernel implementation from relative path or PYTHONPATH.
def infer_shape(self, data_shape): def infer_shape(self, data_shape):
return data_shape return data_shape
...@@ -155,7 +155,7 @@ import mindspore.nn as nn ...@@ -155,7 +155,7 @@ import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
# Import the definition of the CusSquare primtive. # Import the definition of the CusSquare primtive.
from .cus_square import CusSquare from cus_square import CusSquare
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
...@@ -200,7 +200,7 @@ class CusSquare(PrimitiveWithInfer): ...@@ -200,7 +200,7 @@ class CusSquare(PrimitiveWithInfer):
def __init__(self): def __init__(self):
"""init CusSquare""" """init CusSquare"""
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
from .square_impl import CusSquareImpl from square_impl import CusSquareImpl
def infer_shape(self, data_shape): def infer_shape(self, data_shape):
return data_shape return data_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册