提交 89ad8664 编写于 作者: Y yukavio

fix dygraph and input shape

上级 793c74bf
...@@ -20,7 +20,7 @@ from .data_feeder import check_variable_and_dtype, check_dtype ...@@ -20,7 +20,7 @@ from .data_feeder import check_variable_and_dtype, check_dtype
__all__ = ['one_hot', 'embedding'] __all__ = ['one_hot', 'embedding']
@deprecated(since='2.0.0', update_to='paddle.functional.one_hot')
def one_hot(input, depth, allow_out_of_range=False): def one_hot(input, depth, allow_out_of_range=False):
""" """
:alias_main: paddle.nn.functional.one_hot :alias_main: paddle.nn.functional.one_hot
......
...@@ -5831,6 +5831,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): ...@@ -5831,6 +5831,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
return loss return loss
@deprecated(since='2.0.0', update_to='paddle.functional.one_hot')
def one_hot(input, depth, allow_out_of_range=False): def one_hot(input, depth, allow_out_of_range=False):
""" """
......
...@@ -40,7 +40,7 @@ def one_hot(x, num_classes, name=None): ...@@ -40,7 +40,7 @@ def one_hot(x, num_classes, name=None):
Example 1: Example 1:
input: input:
x.shape = [4, 1] x.shape = [4]
x.data = [1, 1, 3, 0] x.data = [1, 1, 3, 0]
num_classes = 4 num_classes = 4
...@@ -54,7 +54,7 @@ def one_hot(x, num_classes, name=None): ...@@ -54,7 +54,7 @@ def one_hot(x, num_classes, name=None):
Example 2: Example 2:
input: input:
x.shape = [4, 1] x.shape = [4]
x.data = [1, 1, 5, 0] x.data = [1, 1, 5, 0]
num_classes = 4 num_classes = 4
...@@ -80,15 +80,15 @@ def one_hot(x, num_classes, name=None): ...@@ -80,15 +80,15 @@ def one_hot(x, num_classes, name=None):
label = fluid.data(name="label", shape=[4, 1], dtype="int64") label = fluid.data(name="label", shape=[4, 1], dtype="int64")
one_hot_label = fluid.one_hot(x=label, num_classes=4) one_hot_label = fluid.one_hot(x=label, num_classes=4)
""" """
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2')
helper = LayerHelper("one_hot_v2", **locals())
one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.one_hot_v2(x, 'depth', num_classes, return core.ops.one_hot_v2(x, 'depth', num_classes,
'allow_out_of_range', False) 'allow_out_of_range', False)
else: else:
check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2')
helper = LayerHelper("one_hot_v2", **locals())
one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
if not isinstance(num_classes, Variable): if not isinstance(num_classes, Variable):
# user attribute # user attribute
inputs = {'X': x} inputs = {'X': x}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册