提交 df9df609 编写于 作者: L leilei_snow

fix index error

上级 5971e313
......@@ -438,8 +438,9 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
void CompileGraph::Push(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (slots_.count(node) > 0) {
MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString()
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString()
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
return;
}
MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
<< " is parameter: " << node->isa<Parameter>();
......
......@@ -341,13 +341,15 @@ void FinalVM::InstSwitchLayer(const VectorRef &args) {
if (!backend_->GetIndex(index, &idx_value)) {
MS_LOG(EXCEPTION) << "Not supported type to be casted to int.";
}
auto ori_value = idx_value;
if (idx_value < 0) {
// Add support negative index range [-size, -1].
idx_value += size;
}
if (idx_value < 0 || idx_value >= size) {
MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range. Please make sure the value "
<< "of index in [" << -size << ", " << size << "), and the type is int32.";
MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value
<< " out of range. Please make sure the value "
<< "of index in [" << -size << ", " << size << "), and the type is int32.";
}
Push(branches[idx_value]);
MS_LOG(DEBUG) << "End";
......
......@@ -52,5 +52,5 @@ def test_switch_layer():
assert ret
idx3 = Tensor(3, mstype.int32)
with pytest.raises(RuntimeError):
with pytest.raises(IndexError):
value = net(data, idx3, idx2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册