From decaeb1c6d9b9bc8a0d7634c542373c098c463a7 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 2 Nov 2018 13:47:04 +0800 Subject: [PATCH] fix style check after conflicts check. test=develop --- python/paddle/fluid/layers/nn.py | 5 ++--- python/paddle/fluid/tests/unittests/test_layers.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3f5b0bcd7b..d66a5b083a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7806,7 +7806,6 @@ def grid_sampler(x, grid, name=None): out = fluid.layers.grid_sampler(x=x, grid=grid) """ helper = LayerHelper("grid_sampler", **locals()) - dtype = helper.input_dtype() if not isinstance(x, Variable): return ValueError("The x should be a Variable") @@ -7814,10 +7813,10 @@ def grid_sampler(x, grid, name=None): if not isinstance(grid, Variable): return ValueError("The grid should be a Variable") - out = helper.create_variable_for_type_inference(dtype) + out = helper.create_variable_for_type_inference(x.dtype) ipts = {'X': x, 'Grid': grid} - helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output', out}) + helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output': out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index f85beee9be..c4ecc2c2c2 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -868,12 +868,12 @@ class TestBook(unittest.TestCase): def test_grid_sampler(self): program = Program() with program_guard(program): - x = layers.data(name='x', shape=[2, 3, 5, 7], dtype='float32') - grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32') + x = layers.data(name='x', shape=[3, 5, 7], dtype='float32') + grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32') out = layers.grid_sampler(x, grid) self.assertIsNotNone(out) print(str(program)) - + def test_affine_grid(self): program = Program() with program_guard(program): -- GitLab