提交 decaeb1c 编写于 作者: D dengkaipeng

fix style check after conflicts check. test=develop

上级 0b290782
...@@ -7806,7 +7806,6 @@ def grid_sampler(x, grid, name=None): ...@@ -7806,7 +7806,6 @@ def grid_sampler(x, grid, name=None):
out = fluid.layers.grid_sampler(x=x, grid=grid) out = fluid.layers.grid_sampler(x=x, grid=grid)
""" """
helper = LayerHelper("grid_sampler", **locals()) helper = LayerHelper("grid_sampler", **locals())
dtype = helper.input_dtype()
if not isinstance(x, Variable): if not isinstance(x, Variable):
return ValueError("The x should be a Variable") return ValueError("The x should be a Variable")
...@@ -7814,10 +7813,10 @@ def grid_sampler(x, grid, name=None): ...@@ -7814,10 +7813,10 @@ def grid_sampler(x, grid, name=None):
if not isinstance(grid, Variable): if not isinstance(grid, Variable):
return ValueError("The grid should be a Variable") return ValueError("The grid should be a Variable")
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x, 'Grid': grid} ipts = {'X': x, 'Grid': grid}
helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output', out}) helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output': out})
return out return out
......
...@@ -868,12 +868,12 @@ class TestBook(unittest.TestCase): ...@@ -868,12 +868,12 @@ class TestBook(unittest.TestCase):
def test_grid_sampler(self): def test_grid_sampler(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
x = layers.data(name='x', shape=[2, 3, 5, 7], dtype='float32') x = layers.data(name='x', shape=[3, 5, 7], dtype='float32')
grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32') grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32')
out = layers.grid_sampler(x, grid) out = layers.grid_sampler(x, grid)
self.assertIsNotNone(out) self.assertIsNotNone(out)
print(str(program)) print(str(program))
def test_affine_grid(self): def test_affine_grid(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册