提交 89ad8664 编写于 作者: Y yukavio

fix dygraph and input shape

上级 793c74bf
......@@ -20,7 +20,7 @@ from .data_feeder import check_variable_and_dtype, check_dtype
__all__ = ['one_hot', 'embedding']
@deprecated(since='2.0.0', update_to='paddle.functional.one_hot')
def one_hot(input, depth, allow_out_of_range=False):
"""
:alias_main: paddle.nn.functional.one_hot
......
......@@ -5831,6 +5831,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
return loss
@deprecated(since='2.0.0', update_to='paddle.functional.one_hot')
def one_hot(input, depth, allow_out_of_range=False):
"""
......
......@@ -40,7 +40,7 @@ def one_hot(x, num_classes, name=None):
Example 1:
input:
x.shape = [4, 1]
x.shape = [4]
x.data = [1, 1, 3, 0]
num_classes = 4
......@@ -54,7 +54,7 @@ def one_hot(x, num_classes, name=None):
Example 2:
input:
x.shape = [4, 1]
x.shape = [4]
x.data = [1, 1, 5, 0]
num_classes = 4
......@@ -80,15 +80,15 @@ def one_hot(x, num_classes, name=None):
label = fluid.data(name="label", shape=[4, 1], dtype="int64")
one_hot_label = fluid.one_hot(x=label, num_classes=4)
"""
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2')
helper = LayerHelper("one_hot_v2", **locals())
one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
if in_dygraph_mode():
return core.ops.one_hot_v2(x, 'depth', num_classes,
'allow_out_of_range', False)
else:
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2')
helper = LayerHelper("one_hot_v2", **locals())
one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
if not isinstance(num_classes, Variable):
# user attribute
inputs = {'X': x}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册