未验证 提交 0fd1281e 编写于 作者: Y Yi Liu 提交者: GitHub

fix bug of issue #21259 (#21287)

pass the argument `allow_out_of_range` of one_hot op to c++ back end.
上级 319d2ba9
...@@ -5244,16 +5244,16 @@ def one_hot(input, depth, allow_out_of_range=False): ...@@ -5244,16 +5244,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode(): if in_dygraph_mode():
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else: else:
if not isinstance(depth, Variable): if not isinstance(depth, Variable):
# user attribute # user attribute
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else: else:
depth.stop_gradient = True depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth} inputs = {'X': input, 'depth_tensor': depth}
attrs = {} attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op( helper.append_op(
type="one_hot", type="one_hot",
inputs=inputs, inputs=inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册