MegEngine 天元 / MegEngine · Commit 248d8bf0

Authored Dec 15, 2020 by Megvii Engine Team

feat(imperative/ops): improve infer attrs validate function

GitOrigin-RevId: 6fab3b140220709c6edf92bd5a59105c96c2320a

Parent: 8ed2077b
Showing 7 changed files with 15 additions and 18 deletions (+15 -18)
imperative/src/impl/interpreter_impl.cpp    +3 -3
imperative/src/impl/ops/backward_graph.cpp  +1 -1
imperative/src/impl/ops/batch_norm.cpp      +4 -7
imperative/src/impl/ops/broadcast.cpp       +4 -4
imperative/src/impl/ops/cond_take.cpp       +1 -1
imperative/src/impl/ops/elemwise.cpp        +1 -1
imperative/src/impl/ops/tensor_manip.cpp    +1 -1
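Note: every file in this commit touches OpDef::infer_output_attrs_fallible, whose signature, visible in the hunk headers below, is std::tuple<SmallVector<LogicalTensorDesc>, bool>. The edits suggest the boolean marks whether the inferred output descriptors are fully determined. The following self-contained sketch uses simplified stand-in types (not MegEngine's) to illustrate that convention; it is illustrative only, not repository code.

#include <cstddef>
#include <cstdio>
#include <tuple>
#include <vector>

// Simplified stand-in for a logical tensor descriptor.
struct Desc {
    std::size_t ndim = 0;  // 0 is used here as "shape not known yet"
};

// Hypothetical op: the output has the same shape as its single input.
std::tuple<std::vector<Desc>, bool> infer_output_attrs_fallible(const std::vector<Desc>& inputs) {
    std::vector<Desc> outs(1);
    if (inputs.empty() || inputs[0].ndim == 0) {
        // Input shape unknown: return a placeholder descriptor and report
        // validated == false, mirroring the `return {..., false}` edits below.
        return {outs, false};
    }
    outs[0].ndim = inputs[0].ndim;
    return {outs, true};
}

int main() {
    auto [descs, validated] = infer_output_attrs_fallible({Desc{0}});
    std::printf("ndim=%zu validated=%d\n", descs[0].ndim, int(validated));
    return 0;
}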
imperative/src/impl/interpreter_impl.cpp
@@ -176,7 +176,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
-        return bool(info->ptr);
+        return static_cast<bool>(info->ptr);
     });
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
@@ -212,7 +212,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
-        return bool(info->ptr);
+        return static_cast<bool>(info->ptr);
     });
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
@@ -232,7 +232,7 @@ void ChannelImpl::close() {
 }

 void ChannelImpl::config_async_level(int level) {
-    mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
+    mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2");
     m_async_level = level;
 }
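Note: the interpreter changes are stylistic tightenings. The standalone snippet below (not repository code) shows that the functional cast bool(p) and static_cast<bool>(p) are equivalent for a smart pointer, and that `and` is the ISO C++ alternative token for `&&`; the commit standardizes on the more explicit cast and the conventional operator spelling.

#include <memory>

int main() {
    auto p = std::make_shared<int>(42);

    // Equivalent conversions; the commit prefers the explicit static_cast form.
    bool a = bool(p);
    bool b = static_cast<bool>(p);

    // `and` is an alternative token for `&&`; the commit switches to `&&`.
    bool c = (a and b);
    bool d = (a && b);

    return (c == d) ? 0 : 1;
}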
imperative/src/impl/ops/backward_graph.cpp
@@ -49,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
         expr_input_descs.push_back(node2attr.at(inp));
     }
-    auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
+    auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
         *expr_op, expr_input_descs);
     validated = validated && expr_validated;
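Note: the context line `validated = validated && expr_validated;` folds each sub-expression's flag into the graph-level flag. A minimal standalone illustration of that fold, with simplified data and no MegEngine types:

#include <cstdio>
#include <vector>

int main() {
    // One flag per sub-expression of a hypothetical internal graph.
    std::vector<bool> expr_validated_flags = {true, false, true};

    bool validated = true;
    for (bool expr_validated : expr_validated_flags) {
        // Same accumulation as in backward_graph.cpp: a single unvalidated
        // sub-expression makes the whole graph unvalidated.
        validated = validated && expr_validated;
    }
    std::printf("validated=%d\n", int(validated));  // prints validated=0
    return 0;
}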
imperative/src/impl/ops/batch_norm.cpp
@@ -54,16 +54,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
-    size_t i = 0;
-    if (!need_stat) {
-        out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
-        i = 2;
-    }
-    for (; i < nr_out-1; ++i) {
+    // [running_mean, running_var,] save_mean, save_var
+    for (size_t i = 0; i < nr_out-1; ++i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
+    // output tensor
     out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
-    return {out_shapes, true};
+    return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0};
 }

 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
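Note: the key behavioral change in batch_norm.cpp is the return value: instead of an unconditional true, the validity flag is now derived from whether the main output layout is actually known (ndim != 0). A minimal sketch of that check with a simplified layout type, not MegEngine's TensorLayout:

#include <cstddef>
#include <cstdio>

struct Layout {
    std::size_t ndim = 0;  // ndim == 0 doubles as the "shape unknown" marker
};

// Mirrors the new `out_shapes[nr_out-1].layout.ndim != 0` expression.
bool output_attrs_validated(const Layout& out) {
    return out.ndim != 0;
}

int main() {
    std::printf("known shape   -> %d\n", int(output_attrs_validated(Layout{4})));
    std::printf("unknown shape -> %d\n", int(output_attrs_validated(Layout{0})));
    return 0;
}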
imperative/src/impl/ops/broadcast.cpp
@@ -61,17 +61,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     TensorLayout out_layout = src.layout;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
         out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, true};
+        return {{{out_layout, src.comp_node}}, false};
     }
     mgb_assert(
-        tshp.layout.ndim == 1,
-        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+        tshp.layout.ndim == 1,
+        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
         tshp.layout.ndim);
     size_t target_ndim = tshp.layout.shape[0];
     out_layout.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
-    for (size_t i = 0; i < target_ndim; ++i) {
+    for (size_t i = 0; i < target_ndim; ++i) {
         out_layout.shape[i] = ptr[i];
     }
     mgb_assert(valid_broadcast(src.layout, out_layout),
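Note: besides flipping true to false when the target shape value is not yet available, the hunk keeps the existing mgb_assert(valid_broadcast(src.layout, out_layout), ...). The body of valid_broadcast is not shown in this diff; the sketch below only assumes the usual NumPy-style compatibility rule (dimensions aligned from the right, each source dimension equal to the target or 1) and is not MegEngine's implementation.

#include <cstddef>
#include <cstdio>
#include <vector>

// Assumed NumPy-style broadcast compatibility check; MegEngine's actual
// valid_broadcast() is not part of this diff and may differ in details.
bool valid_broadcast_sketch(const std::vector<std::size_t>& src,
                            const std::vector<std::size_t>& dst) {
    if (src.size() > dst.size()) return false;
    for (std::size_t i = 0; i < src.size(); ++i) {
        std::size_t s = src[src.size() - 1 - i];
        std::size_t d = dst[dst.size() - 1 - i];
        if (s != d && s != 1) return false;  // mismatch that is not a broadcastable 1
    }
    return true;
}

int main() {
    std::printf("%d\n", int(valid_broadcast_sketch({1, 3}, {2, 4, 3})));  // 1: broadcastable
    std::printf("%d\n", int(valid_broadcast_sketch({2, 3}, {2, 4, 3})));  // 0: dim 2 vs 4
    return 0;
}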
imperative/src/impl/ops/cond_take.cpp
@@ -76,7 +76,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{
         {TensorLayout(inputs[0].layout.dtype), cn},
         {TensorLayout(dtype::Int32()), cn}
-    }, true};
+    }, false};
 }

 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
imperative/src/impl/ops/elemwise.cpp
@@ -60,7 +60,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         TensorLayout out_layout;
         out_layout.ndim = 0;
         out_layout.dtype = out_dt;
-        return {{{out_layout, out_cn}}, true};
+        return {{{out_layout, out_cn}}, false};
     }
 }
imperative/src/impl/ops/tensor_manip.cpp
@@ -59,7 +59,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
-        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
+        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
     if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
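Note: the cond_take.cpp, elemwise.cpp and tensor_manip.cpp hunks all make the same change: a dtype-only placeholder layout (ndim == 0) used to be reported with validated == true and is now reported with false. A standalone illustration of that placeholder pattern, using simplified types rather than MegEngine's:

#include <cstdio>

enum class DType { Float32, Int32 };

// Simplified stand-in for a logical tensor descriptor.
struct Layout {
    DType dtype;
    unsigned ndim = 0;  // 0: only the dtype is known, the shape is not
};

int main() {
    Layout placeholder{DType::Int32};        // e.g. GetVarShape's Int32 output
    bool validated = placeholder.ndim != 0;  // false: matches the new behaviour
    std::printf("validated=%d\n", int(validated));
    return 0;
}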