Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9d3c9c69
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9d3c9c69
编写于
6月 24, 2020
作者:
H
huangdongrun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify map to C.Map()
上级
cf0eca56
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
27 addition
and
1 deletion
+27
-1
mindspore/_extends/parse/resources.py
mindspore/_extends/parse/resources.py
+1
-1
mindspore/ccsrc/operator/composite/map.cc
mindspore/ccsrc/operator/composite/map.cc
+3
-0
tests/ut/python/pipeline/parse/test_fix_bug.py
tests/ut/python/pipeline/parse/test_fix_bug.py
+23
-0
未找到文件。
mindspore/_extends/parse/resources.py
浏览文件 @
9d3c9c69
...
...
@@ -111,7 +111,7 @@ convert_object_map = {
# system function
T
.
len
:
M
.
ms_len
,
T
.
bool
:
M
.
bool_
,
T
.
map
:
C
.
Hyper
Map
(),
T
.
map
:
C
.
Map
(),
T
.
partial
:
F
.
partial
,
T
.
zip
:
C
.
zip_operation
,
T
.
print
:
F
.
print_
,
...
...
mindspore/ccsrc/operator/composite/map.cc
浏览文件 @
9d3c9c69
...
...
@@ -181,6 +181,9 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
}
AnfNodePtr
Map
::
Make
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
fn_arg
,
const
ArgsPairList
&
arg_pairs
)
{
if
(
arg_pairs
.
size
()
<
1
)
{
MS_EXCEPTION
(
TypeError
)
<<
"map() must have at least two arguments"
;
}
bool
found
=
false
;
TypeId
id
=
kObjectTypeEnd
;
std
::
pair
<
AnfNodePtr
,
TypePtr
>
pair
;
...
...
tests/ut/python/pipeline/parse/test_fix_bug.py
浏览文件 @
9d3c9c69
...
...
@@ -18,6 +18,7 @@ import pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
composite
as
C
from
mindspore.common.api
import
_executor
...
...
@@ -93,3 +94,25 @@ def test_compile_unspported():
net
=
unsupported_method_net
()
with
pytest
.
raises
(
RuntimeError
):
_executor
.
compile
(
net
,
input_me
)
def
test_parser_map_0002
():
class
NetMap0002
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
relu
=
nn
.
ReLU
()
self
.
hypermap
=
C
.
Map
()
def
mul
(
self
,
x
=
2
,
y
=
4
):
return
x
*
y
def
construct
(
self
,
x
):
if
map
(
self
.
mul
)
==
8
:
x
=
self
.
relu
(
x
)
return
x
input_np_x
=
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
)
input_me_x
=
Tensor
(
input_np_x
)
net
=
NetMap0002
()
with
pytest
.
raises
(
TypeError
):
net
(
input_me_x
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录