From 89ad8664124c1872640b68c0a4418c377ad6aa6f Mon Sep 17 00:00:00 2001 From: yukavio Date: Thu, 13 Aug 2020 10:47:38 +0800 Subject: [PATCH] fix dygraph and input shape --- python/paddle/fluid/input.py | 2 +- python/paddle/fluid/layers/nn.py | 1 + python/paddle/nn/functional/input.py | 12 ++++++------ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py index 347927509e..733c0d9526 100644 --- a/python/paddle/fluid/input.py +++ b/python/paddle/fluid/input.py @@ -20,7 +20,7 @@ from .data_feeder import check_variable_and_dtype, check_dtype __all__ = ['one_hot', 'embedding'] - +@deprecated(since='2.0.0', update_to='paddle.functional.one_hot') def one_hot(input, depth, allow_out_of_range=False): """ :alias_main: paddle.nn.functional.one_hot diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1b8df4a098..d417f2e5a5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5831,6 +5831,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): return loss +@deprecated(since='2.0.0', update_to='paddle.functional.one_hot') def one_hot(input, depth, allow_out_of_range=False): """ diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index 50d208073c..e3262d1fdc 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -40,7 +40,7 @@ def one_hot(x, num_classes, name=None): Example 1: input: - x.shape = [4, 1] + x.shape = [4] x.data = [1, 1, 3, 0] num_classes = 4 @@ -54,7 +54,7 @@ def one_hot(x, num_classes, name=None): Example 2: input: - x.shape = [4, 1] + x.shape = [4] x.data = [1, 1, 5, 0] num_classes = 4 @@ -80,15 +80,15 @@ def one_hot(x, num_classes, name=None): label = fluid.data(name="label", shape=[4, 1], dtype="int64") one_hot_label = fluid.one_hot(x=label, num_classes=4) """ - check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2') - helper = LayerHelper("one_hot_v2", **locals()) - - one_hot_out = helper.create_variable_for_type_inference(dtype='float32') if in_dygraph_mode(): return core.ops.one_hot_v2(x, 'depth', num_classes, 'allow_out_of_range', False) else: + check_variable_and_dtype(x, 'input', ['int32', 'int64'], 'one_hot_v2') + helper = LayerHelper("one_hot_v2", **locals()) + + one_hot_out = helper.create_variable_for_type_inference(dtype='float32') if not isinstance(num_classes, Variable): # user attribute inputs = {'X': x} -- GitLab