未验证 提交 da9752fe 编写于 作者: Y Yi Liu 提交者: GitHub

fix bug of issue #21259 (#21331)

* fix bug of issue #21259 (#21287)
pass the argument `allow_out_of_range` of one_hot op to c++ back end.
上级 a91b8014
...@@ -104,16 +104,16 @@ def one_hot(input, depth, allow_out_of_range=False): ...@@ -104,16 +104,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode(): if in_dygraph_mode():
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else: else:
if not isinstance(depth, Variable): if not isinstance(depth, Variable):
# user attribute # user attribute
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else: else:
depth.stop_gradient = True depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth} inputs = {'X': input, 'depth_tensor': depth}
attrs = {} attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op( helper.append_op(
type="one_hot_v2", type="one_hot_v2",
inputs=inputs, inputs=inputs,
......
...@@ -8731,16 +8731,16 @@ def one_hot(input, depth, allow_out_of_range=False): ...@@ -8731,16 +8731,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode(): if in_dygraph_mode():
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else: else:
if not isinstance(depth, Variable): if not isinstance(depth, Variable):
# user attribute # user attribute
inputs = {'X': input} inputs = {'X': input}
attrs = {'depth': depth} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else: else:
depth.stop_gradient = True depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth} inputs = {'X': input, 'depth_tensor': depth}
attrs = {} attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op( helper.append_op(
type="one_hot", type="one_hot",
inputs=inputs, inputs=inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册