Commit 248d8bf0 authored by Megvii Engine Team

feat(imperative/ops): improve infer attrs validate function

GitOrigin-RevId: 6fab3b140220709c6edf92bd5a59105c96c2320a
Parent 8ed2077b
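
For context before the hunks: infer_output_attrs_fallible returns the output descriptors plus a flag, and the theme of this commit is making that flag honest — false whenever a layout could not actually be inferred. A minimal sketch of the convention as the hunks below use it; the caller-side names are illustrative, not from this diff:

    // Returns the inferred output descriptors and whether inference fully
    // succeeded; dtype/comp_node may be known even when the shape is not.
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);

    // Illustrative caller: fall back to runtime shapes when not validated.
    auto [descs, validated] = OpDef::infer_output_attrs_fallible(op, input_descs);
    if (!validated) {
        // defer shape-dependent decisions until the kernel has run
    }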
@@ -176,7 +176,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
-        return bool(info->ptr);
+        return static_cast<bool>(info->ptr);
     });
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
@@ -212,7 +212,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
-        return bool(info->ptr);
+        return static_cast<bool>(info->ptr);
     });
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
@@ -232,7 +232,7 @@ void ChannelImpl::close() {
 }

 void ChannelImpl::config_async_level(int level) {
-    mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
+    mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2");
     m_async_level = level;
 }
...
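The two ChannelImpl hunks above are mechanical cleanups (an explicit static_cast instead of a functional-style bool cast, && instead of the alternative token and); the surrounding wait loop is the standard predicate form of std::condition_variable::wait. A self-contained sketch of that idiom, with illustrative names:

    #include <condition_variable>
    #include <mutex>

    std::mutex m;
    std::condition_variable cv;
    void* ready_ptr = nullptr;  // published by a worker thread

    void wait_until_ready() {
        std::unique_lock<std::mutex> lock(m);
        // The predicate is re-evaluated after every wakeup, so spurious
        // wakeups are harmless; ChannelImpl additionally rechecks worker
        // exceptions inside its predicate before testing the pointer.
        cv.wait(lock, [&]() { return static_cast<bool>(ready_ptr); });
    }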
@@ -49,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
         expr_input_descs.push_back(node2attr.at(inp));
     }
-    auto[expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
+    auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
             *expr_op, expr_input_descs);
     validated = validated && expr_validated;
...
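Beyond the spacing fix, this hunk shows why the per-op changes below matter: the graph-level flag is the conjunction of every expression's flag, so a single op that cannot infer its layouts makes the whole backward graph unvalidated. An illustrative reduction (the expression names are hypothetical):

    bool validated = true;
    for (auto&& expr : exprs) {  // hypothetical iteration over graph exprs
        auto [descs, ok] = OpDef::infer_output_attrs_fallible(*expr.op, expr.inputs);
        validated = validated && ok;  // one failure poisons the whole graph
    }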
@@ -54,16 +54,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
-    size_t i = 0;
-    if (!need_stat) {
-        out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
-        i = 2;
-    }
-    for (; i < nr_out-1; ++ i) {
+    // [running_mean, running_var,] save_mean, save_var
+    for (size_t i = 0; i < nr_out-1; ++ i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
+    // output tensor
     out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
-    return {out_shapes, true};
+    return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0};
 }

 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
...
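The BatchNorm hunk drops the !need_stat special case (which filled the first two slots with empty placeholder layouts) and stops hard-coding the flag: every statistics output now takes the layout of inputs[1], and validity is derived from the real output layout. A hedged restatement of the new logic, with the output ordering taken from the comment in the hunk:

    // Statistics outputs ([running_mean, running_var,] save_mean, save_var)
    // follow the parameter layout i1; the last slot is the transformed data
    // tensor, which follows i0. Inference only counts as validated when the
    // data layout is actually known (ndim != 0).
    SmallVector<LogicalTensorDesc> descs(nr_out);
    for (size_t i = 0; i + 1 < nr_out; ++i)
        descs[i] = {i1.layout, i1.comp_node};
    descs[nr_out - 1] = {i0.layout, i0.comp_node};
    bool validated = descs[nr_out - 1].layout.ndim != 0;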
@@ -61,7 +61,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     TensorLayout out_layout = src.layout;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
         out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, true};
+        return {{{out_layout, src.comp_node}}, false};
     }
     mgb_assert(
         tshp.layout.ndim == 1,
@@ -71,7 +71,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     size_t target_ndim = tshp.layout.shape[0];
     out_layout.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
-    for(size_t i=0; i<target_ndim; ++i) {
+    for (size_t i = 0; i < target_ndim; ++i) {
         out_layout.shape[i] = ptr[i];
     }
     mgb_assert(valid_broadcast(src.layout, out_layout),
...
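The first Broadcast hunk makes the unknown-target case report false, since an out_layout with ndim == 0 is a placeholder rather than a result; the second is a formatting touch-up. valid_broadcast itself is not shown in this diff; a plausible sketch of the check it performs, assuming NumPy-style trailing-dimension rules:

    // Assumed semantics: src broadcasts to dst if, aligning dimensions from
    // the right, every src dimension is either 1 or equal to dst's.
    bool valid_broadcast(const TensorLayout& src, const TensorLayout& dst) {
        if (src.ndim > dst.ndim)
            return false;
        for (size_t i = 0; i < src.ndim; ++i) {
            size_t s = src.shape[src.ndim - 1 - i];
            size_t d = dst.shape[dst.ndim - 1 - i];
            if (s != 1 && s != d)
                return false;
        }
        return true;
    }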
@@ -76,7 +76,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{
         {TensorLayout(inputs[0].layout.dtype), cn},
         {TensorLayout(dtype::Int32()), cn}
-    }, true};
+    }, false};
 }

 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
...
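CondTake's flag is now unconditionally false, which matches its semantics: the output length equals the number of elements selected by the mask — a runtime quantity — so only the dtypes (the value dtype plus Int32 indices) can be inferred ahead of time. An illustrative scalar version of why:

    #include <cstddef>
    #include <vector>

    // The output size is only known after inspecting the mask values.
    std::vector<float> cond_take(const std::vector<float>& data,
                                 const std::vector<bool>& mask) {
        std::vector<float> out;
        for (std::size_t i = 0; i < data.size(); ++i)
            if (mask[i]) out.push_back(data[i]);
        return out;
    }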
@@ -60,7 +60,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     TensorLayout out_layout;
     out_layout.ndim = 0;
     out_layout.dtype = out_dt;
-    return {{{out_layout, out_cn}}, true};
+    return {{{out_layout, out_cn}}, false};
 }
 }
...
@@ -59,7 +59,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
-        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
+        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
     if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
...
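GetVarShape now likewise reports false when the input ndim is unknown; once ndim is known, the result is just the shape vector itself, which the DeviceTensorND value path after the hunk computes eagerly. An illustrative reduction of the INVALID_AXIS case (TensorLayout as in the hunks above; the helper name is hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // The shape of an n-dim tensor is a length-n Int32 vector; with
    // INVALID_AXIS the whole vector is returned rather than a single entry.
    std::vector<std::int32_t> var_shape(const TensorLayout& layout) {
        std::vector<std::int32_t> out(layout.ndim);
        for (std::size_t i = 0; i < layout.ndim; ++i)
            out[i] = static_cast<std::int32_t>(layout.shape[i]);
        return out;
    }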