未验证 提交 9a2c1aed 编写于 作者: W wawltor 提交者: GitHub

Fix the elementwise_div op broadcast failed in some shape

Fix the bug for elementwise_div op, when the first var is scalar; Use the shape 1 replace the -1 in shape.
上级 97708add
...@@ -98,17 +98,23 @@ def monkey_patch_variable(): ...@@ -98,17 +98,23 @@ def monkey_patch_variable():
block = current_block(ref_var) block = current_block(ref_var)
var = create_new_tmp_var(block, dtype) var = create_new_tmp_var(block, dtype)
batch_dim = -1 batch_dim = -1
out_shape = []
for i, d in enumerate(ref_var.shape): for i, d in enumerate(ref_var.shape):
if d < 0: if d < 0:
if batch_dim < 0:
batch_dim = i batch_dim = i
break out_shape.append(d)
else:
out_shape.append(1)
else:
out_shape.append(d)
assert batch_dim != -1 assert batch_dim != -1
block.append_op( block.append_op(
type='fill_constant_batch_size_like', type='fill_constant_batch_size_like',
outputs={'Out': [var]}, outputs={'Out': [var]},
inputs={'Input': [ref_var]}, inputs={'Input': [ref_var]},
attrs={ attrs={
'shape': ref_var.shape, 'shape': out_shape,
'value': value, 'value': value,
'input_dim_idx': batch_dim, 'input_dim_idx': batch_dim,
'output_dim_idx': batch_dim 'output_dim_idx': batch_dim
......
...@@ -227,6 +227,19 @@ class TestElementwiseDivOpFp16(ElementwiseDivOp): ...@@ -227,6 +227,19 @@ class TestElementwiseDivOpFp16(ElementwiseDivOp):
['X'], 'Out', max_relative_error=1, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
class TestElementwiseDivBroadcast(unittest.TestCase):
def test_shape_with_batch_sizes(self):
with fluid.program_guard(fluid.Program()):
x_var = fluid.data(
name='x', dtype='float32', shape=[None, 3, None, None])
one = 2.
out = one / x_var
exe = fluid.Executor(fluid.CPUPlace())
x = np.random.uniform(0.1, 0.6, (1, 3, 32, 32)).astype("float32")
out_result, = exe.run(feed={'x': x}, fetch_list=[out])
self.assertEqual((out_result == (2 / x)).all(), True)
class TestDivOp(unittest.TestCase): class TestDivOp(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册