提交 2e423f9e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!141 doc(custom_op): delete comment in code block.

Merge pull request !141 from gongchen/fix_customop1
...@@ -83,13 +83,19 @@ class CusSquare(PrimitiveWithInfer): ...@@ -83,13 +83,19 @@ class CusSquare(PrimitiveWithInfer):
算子信息是指导后端选择算子实现的关键信息,同时也指导后端为算子插入合适的类型和格式转换。它通过`TBERegOp`接口定义,通过`op_info_register`装饰器将算子信息与算子实现入口函数绑定。当算子实现py文件被导入时,`op_info_register`装饰器会将算子信息注册到后端的算子信息库中。更多关于算子信息的使用方法请参考`TBERegOp`的成员方法的注释说明。 算子信息是指导后端选择算子实现的关键信息,同时也指导后端为算子插入合适的类型和格式转换。它通过`TBERegOp`接口定义,通过`op_info_register`装饰器将算子信息与算子实现入口函数绑定。当算子实现py文件被导入时,`op_info_register`装饰器会将算子信息注册到后端的算子信息库中。更多关于算子信息的使用方法请参考`TBERegOp`的成员方法的注释说明。
> 算子信息中定义输入输出信息的个数和顺序、算子实现入口函数的参数中的输入输出信息的个数和顺序、算子原语中输入输出名称列表的个数和顺序,三者要完全一致。 > - 算子信息中定义输入输出信息的个数和顺序、算子实现入口函数的参数中的输入输出信息的个数和顺序、算子原语中输入输出名称列表的个数和顺序,三者要完全一致。
> - 算子如果带属性,在算子信息中需要用`attr()`描述属性信息,属性的名称与算子原语定义中的属性名称要一致。
> 算子如果带属性,在算子信息中需要用`attr()`描述属性信息,属性的名称与算子原语定义中的属性名称要一致。
### 示例 ### 示例
下面以`Square`算子的TBE实现`square_impl.py`为例进行介绍。`square_compute`是算子实现的计算函数,通过调到`te.lang.cce`提供的API描述了`x * x`的计算逻辑。`cus_square_op_info `是算子信息,通过`TBERegOp`来定义。`TBERegOp`中的`dtype_format`是用来描述算子支持的数据类型,下面示例中注册了两项说明该算子支持两种数据类型,而每一项需按照输入和输出的顺序依次描述支持的格式。第一个`dtype_format`说明支持的第一种数据类型是input0为F32_Default格式,output0为F32_Default格式。第二个`dtype_format`说明支持的第二种数据类型是input0为F16_Default格式,output0为F16_Default格式。 下面以`Square`算子的TBE实现`square_impl.py`为例进行介绍。`square_compute`是算子实现的计算函数,通过调到`te.lang.cce`提供的API描述了`x * x`的计算逻辑。`cus_square_op_info `是算子信息,通过`TBERegOp`来定义。
`TBERegOp`中:
- `TBERegOp("CusSquare")`中算子注册名称`CusSquare`需要与算子名称一致。
- `fusion_type("OPAQUE")``OPAQUE`是说明自定义算子采取不融合策略。
- `kernel_name("CusSquareImpl")`中"CusSquareImpl"需要与算子入口函数名称一致。
- `dtype_format`是用来描述算子支持的数据类型,下面示例中注册了两项说明该算子支持两种数据类型,而每一项需按照输入和输出的顺序依次描述支持的格式。第一个`dtype_format`说明支持的第一种数据类型是input0为F32_Default格式,output0为F32_Default格式。第二个`dtype_format`说明支持的第二种数据类型是input0为F16_Default格式,output0为F16_Default格式。
```python ```python
from __future__ import absolute_import from __future__ import absolute_import
...@@ -107,13 +113,13 @@ def square_compute(input_x, output_y): ...@@ -107,13 +113,13 @@ def square_compute(input_x, output_y):
return res return res
# Define the kernel info of CusSquare. # Define the kernel info of CusSquare.
cus_square_op_info = TBERegOp("CusSquare") \ # The registered op name should be same with primitive name. cus_square_op_info = TBERegOp("CusSquare") \
.fusion_type("OPAQUE") \ # Setting kernel fusion strategy. The default is not infusible. .fusion_type("OPAQUE") \
.partial_flag(True) \ .partial_flag(True) \
.async_flag(False) \ .async_flag(False) \
.binfile_name("square.so") \ .binfile_name("square.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("CusSquareImpl") \ # The kernel name should be same with the name of the entry function. .kernel_name("CusSquareImpl") \
.input(0, "x", False, "required", "all") \ .input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册