Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
docs
提交
857c0b52
D
docs
项目概览
MindSpore
/
docs
通知
5
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
docs
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
857c0b52
编写于
5月 20, 2020
作者:
G
gongchen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix custom op tutorial.
上级
6da853cf
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
4 addition
and
4 deletion
+4
-4
tutorials/source_zh_cn/advanced_use/custom_operator.md
tutorials/source_zh_cn/advanced_use/custom_operator.md
+4
-4
未找到文件。
tutorials/source_zh_cn/advanced_use/custom_operator.md
浏览文件 @
857c0b52
...
...
@@ -38,7 +38,7 @@
-
输入输出的名称通过
`init_prim_io_names()`
函数定义。
-
输出Tensor的shape推理方法在
`infer_shape()`
函数中定义,输出Tensor的dtype推理方法在
`infer_dtype()`
函数中定义。
自定义算子与内置算子的唯一区别是需要通过在
`__init__()`
函数中导入算子实现函数(
`from
.
square_impl import CusSquareImpl`
)来将算子实现注册到后端。本用例在
`square_impl.py`
中定义了算子实现和算子信息,将在后文中说明。
自定义算子与内置算子的唯一区别是需要通过在
`__init__()`
函数中导入算子实现函数(
`from square_impl import CusSquareImpl`
)来将算子实现注册到后端。本用例在
`square_impl.py`
中定义了算子实现和算子信息,将在后文中说明。
以Square算子原语
`cus_square.py`
为例,给出如下示例代码。
...
...
@@ -53,7 +53,7 @@ class CusSquare(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'y'
])
from
.
square_impl
import
CusSquareImpl
# Import the entry function of the kernel implementation from relative path or PYTHONPATH.
from
square_impl
import
CusSquareImpl
# Import the entry function of the kernel implementation from relative path or PYTHONPATH.
def
infer_shape
(
self
,
data_shape
):
return
data_shape
...
...
@@ -155,7 +155,7 @@ import mindspore.nn as nn
import
mindspore.context
as
context
from
mindspore
import
Tensor
# Import the definition of the CusSquare primtive.
from
.
cus_square
import
CusSquare
from
cus_square
import
CusSquare
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Net
(
nn
.
Cell
):
...
...
@@ -200,7 +200,7 @@ class CusSquare(PrimitiveWithInfer):
def
__init__
(
self
):
"""init CusSquare"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'y'
])
from
.
square_impl
import
CusSquareImpl
from
square_impl
import
CusSquareImpl
def
infer_shape
(
self
,
data_shape
):
return
data_shape
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录