未验证 提交 da9752fe 编写于 作者: Y Yi Liu 提交者: GitHub

fix bug of issue #21259 (#21331)

* fix bug of issue #21259 (#21287)
pass the argument `allow_out_of_range` of one_hot op to c++ back end.
上级 a91b8014
......@@ -104,16 +104,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode():
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
if not isinstance(depth, Variable):
# user attribute
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {}
attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op(
type="one_hot_v2",
inputs=inputs,
......
......@@ -8731,16 +8731,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode():
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
if not isinstance(depth, Variable):
# user attribute
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {}
attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op(
type="one_hot",
inputs=inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册