From 9203aaf185ad01f8abe9e9ac364322266737f0bc Mon Sep 17 00:00:00 2001 From: songyouwei Date: Tue, 17 Mar 2020 19:24:05 +0800 Subject: [PATCH] fix unittest for coverage (#23007) test=develop --- python/paddle/fluid/layers/nn.py | 32 +++++++++++----- python/paddle/fluid/layers/tensor.py | 21 +++++----- .../fluid/tests/unittests/test_layers.py | 38 ++++++++++++++++++- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 17abacc235f..584aad997c5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4335,12 +4335,13 @@ def split(input, num_or_sections, dim=-1, name=None): if in_dygraph_mode(): inputs = {'X': [input]} attrs = {} - if isinstance(dim, int): - dim = (len(input.shape) + dim) if dim < 0 else dim - attrs['axis'] = dim - else: - dim.stop_gradient = True - inputs['AxisTensor'] = [dim] + if isinstance(dim, Variable): + dim = dim.numpy() + assert dim.shape == (1, + ), "dim of type Variable should have shape [1]" + dim = dim[0] + dim = (len(input.shape) + dim) if dim < 0 else dim + attrs['axis'] = dim if isinstance(num_or_sections, int): num = num_or_sections @@ -4717,17 +4718,23 @@ def topk(input, k, name=None): """ inputs = {"X": [input]} attrs = {} - if isinstance(k, Variable): - inputs['K'] = [k] - else: - attrs = {'k': k} if in_dygraph_mode(): + if isinstance(k, Variable): + k = k.numpy() + assert k.shape == (1, ), "k of type Variable should have shape [1]" + k = k[0] + attrs = {'k': k} outs = core.ops.top_k(inputs, attrs) outs['Out'][0].stop_gradient = True outs['Indices'][0].stop_gradient = True return outs['Out'][0], outs['Indices'][0] + if isinstance(k, Variable): + inputs['K'] = [k] + else: + attrs = {'k': k} + helper = LayerHelper("top_k", **locals()) values = helper.create_variable_for_type_inference(dtype=input.dtype) indices = helper.create_variable_for_type_inference(dtype="int64") @@ -5401,6 +5408,11 @@ def one_hot(input, depth, allow_out_of_range=False): one_hot_label = fluid.layers.one_hot(input=label, depth=4) """ if in_dygraph_mode(): + if isinstance(depth, Variable): + depth = depth.numpy() + assert depth.shape == ( + 1, ), "depth of type Variable should have shape [1]" + depth = depth[0] inputs = {'X': [input]} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range} outs = core.ops.one_hot(inputs, attrs) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index dc4eb727c44..a9785808681 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -256,9 +256,11 @@ def concat(input, axis=0, name=None): if in_dygraph_mode(): inputs = {'X': input} - if not isinstance(axis, int): - raise TypeError( - "Input 'axis' in concat must be int in Dygraph mode.") + if isinstance(axis, Variable): + axis = axis.numpy() + assert axis.shape == ( + 1, ), "axis of type Variable should have shape [1]" + axis = axis[0] attrs = {'axis': axis} outs = core.ops.concat(inputs, attrs) return outs['Out'][0] @@ -579,15 +581,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): if in_dygraph_mode(): if isinstance(shape, (list, tuple)): - if utils._contain_var(shape): - raise TypeError( - "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but " - "received %s, which contains Variable." % type(shape)) - attrs['shape'] = shape + shape = list( + map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x, + shape)) else: - raise TypeError( - "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but " - "received %s." % type(shape)) + shape = list(shape.numpy().astype(int)) + attrs['shape'] = shape if out is None: out = _varbase_creator(dtype=dtype) attrs['dtype'] = out.dtype diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e94c352da4a..77b8b90f9ff 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -876,7 +876,8 @@ class TestLayer(LayerTest): emb_rlt = emb(words[i]) embs3.append(emb_rlt) - embs3 = layers.concat(input=embs3, axis=1) + embs3 = layers.concat( + input=embs3, axis=fluid.dygraph.to_variable(np.array([1]))) nce = nn.NCE(num_total_classes=dict_size, dim=embs3.shape[1], num_neg_samples=2, @@ -903,7 +904,9 @@ class TestLayer(LayerTest): for i in range(window_size): words.append(base.to_variable(inp_word[i])) sample_weights = layers.fill_constant( - shape=[5, 1], dtype='float32', value=1) + shape=fluid.dygraph.to_variable(np.array([5, 1])), + dtype='float32', + value=1) emb = nn.Embedding( size=[dict_size, 32], param_attr='emb.w', is_sparse=False) @@ -955,6 +958,37 @@ class TestLayer(LayerTest): self.assertTrue( np.array_equal(nce1.bias.numpy(), nce2.bias.numpy())) + def test_one_hot(self): + with self.dynamic_graph(): + label = fluid.dygraph.to_variable(np.array([[1], [1], [3], [0]])) + one_hot_label1 = fluid.layers.one_hot(input=label, depth=4) + one_hot_label2 = fluid.layers.one_hot( + input=label, depth=fluid.dygraph.to_variable(np.array([4]))) + self.assertTrue( + np.array_equal(one_hot_label1.numpy(), one_hot_label2.numpy())) + + def test_split(self): + with self.dynamic_graph(): + input = fluid.dygraph.to_variable(np.random.random((3, 8, 5))) + x0, x1 = fluid.layers.split(input, num_or_sections=2, dim=1) + x00, x11 = fluid.layers.split( + input, + num_or_sections=2, + dim=fluid.dygraph.to_variable(np.array([1]))) + self.assertTrue(np.array_equal(x0.numpy(), x00.numpy())) + self.assertTrue(np.array_equal(x1.numpy(), x11.numpy())) + + def test_topk(self): + with self.dynamic_graph(): + input = fluid.dygraph.to_variable(np.random.random((13, 11))) + top5_values1, top5_indices1 = layers.topk(input, k=5) + top5_values2, top5_indices2 = layers.topk( + input, k=fluid.dygraph.to_variable(np.array([5]))) + self.assertTrue( + np.array_equal(top5_values1.numpy(), top5_values2.numpy())) + self.assertTrue( + np.array_equal(top5_indices1.numpy(), top5_indices2.numpy())) + def test_conv3d(self): with self.static_graph(): images = layers.data( -- GitLab