未验证 提交 06850493 编写于 作者: W wawltor 提交者: GitHub

Fix the elementwise_div bug in the broadcast process, test=develop (#24987)

Fix the bug for elementwise_div op, when the first var is scalar; Use the shape 1 replace the -1 in shape
上级 0b3f6265
......@@ -72,17 +72,23 @@ def monkey_patch_variable():
block = current_block(ref_var)
var = create_new_tmp_var(block, dtype)
batch_dim = -1
out_shape = []
for i, d in enumerate(ref_var.shape):
if d < 0:
if batch_dim < 0:
batch_dim = i
break
out_shape.append(d)
else:
out_shape.append(1)
else:
out_shape.append(d)
assert batch_dim != -1
block.append_op(
type='fill_constant_batch_size_like',
outputs={'Out': [var]},
inputs={'Input': [ref_var]},
attrs={
'shape': ref_var.shape,
'shape': out_shape,
'value': value,
'input_dim_idx': batch_dim,
'output_dim_idx': batch_dim
......
......@@ -227,5 +227,18 @@ class TestElementwiseDivOpFp16(ElementwiseDivOp):
['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
class TestElementwiseDivBroadcast(unittest.TestCase):
def test_shape_with_batch_sizes(self):
with fluid.program_guard(fluid.Program()):
x_var = fluid.data(
name='x', dtype='float32', shape=[None, 3, None, None])
one = 2.
out = one / x_var
exe = fluid.Executor(fluid.CPUPlace())
x = np.random.uniform(0.1, 0.6, (1, 3, 32, 32)).astype("float32")
out_result, = exe.run(feed={'x': x}, fetch_list=[out])
self.assertEqual((out_result == (2 / x)).all(), True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册