提交 248d8bf0 编写于 作者: M Megvii Engine Team

feat(imperative/ops): improve infer attrs validate function

GitOrigin-RevId: 6fab3b140220709c6edf92bd5a59105c96c2320a
上级 8ed2077b
......@@ -176,7 +176,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
m_buffer.enqueue(Flush{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return bool(info->ptr);
return static_cast<bool>(info->ptr);
});
m_waitee = nullptr;
TensorShape ret = info->ptr->layout();
......@@ -212,7 +212,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
m_buffer.enqueue(Flush{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return bool(info->ptr);
return static_cast<bool>(info->ptr);
});
m_waitee = nullptr;
return info->ptr->dev_tensor();
......@@ -232,7 +232,7 @@ void ChannelImpl::close() {
}
void ChannelImpl::config_async_level(int level) {
mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2");
m_async_level = level;
}
......
......@@ -49,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
expr_input_descs.push_back(node2attr.at(inp));
}
auto[expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
*expr_op, expr_input_descs);
validated = validated && expr_validated;
......
......@@ -54,16 +54,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
SmallVector<LogicalTensorDesc> out_shapes(nr_out);
auto&& i0 = inputs[0];
auto&& i1 = inputs[1];
size_t i = 0;
if (!need_stat) {
out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
i = 2;
}
for (; i < nr_out-1; ++ i) {
// [running_mean, running_var,] save_mean, save_var
for (size_t i = 0; i < nr_out-1; ++ i) {
out_shapes[i] = {i1.layout, i1.comp_node};
}
// output tensor
out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
return {out_shapes, true};
return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0};
}
OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
......
......@@ -61,17 +61,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorLayout out_layout = src.layout;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{{out_layout, src.comp_node}}, true};
return {{{out_layout, src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Broadcast expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim == 1,
"target shape of Broadcast expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for(size_t i=0; i<target_ndim; ++i) {
for (size_t i = 0; i < target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
}
mgb_assert(valid_broadcast(src.layout, out_layout),
......
......@@ -76,7 +76,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{
{TensorLayout(inputs[0].layout.dtype), cn},
{TensorLayout(dtype::Int32()), cn}
}, true};
}, false};
}
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
......
......@@ -60,7 +60,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorLayout out_layout;
out_layout.ndim = 0;
out_layout.dtype = out_dt;
return {{{out_layout, out_cn}}, true};
return {{{out_layout, out_cn}}, false};
}
}
......
......@@ -59,7 +59,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
auto&& desc = inputs[0];
if (!desc.layout.ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
DeviceTensorND value;
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册