未验证 提交 53e294ca 编写于 作者: Z zhulei 提交者: GitHub

[RC22] Fix linear with matmul_op replace (#35445)

* [RC22] Fix linear with matmul_op replace

* [RC22] Fix linear with matmul_op replace

* [RC22] Fix linear with matmul_op replace

* [RC22] Fix linear with matmul_op replace

* [RC22] Fix linear with matmul_op replace
上级 fb65268c
...@@ -93,8 +93,11 @@ class TestImperativeOutSclae(unittest.TestCase): ...@@ -93,8 +93,11 @@ class TestImperativeOutSclae(unittest.TestCase):
conv2d_count, matmul_count = 0, 0 conv2d_count, matmul_count = 0, 0
conv2d_skip_count, matmul_skip_count = 0, 0 conv2d_skip_count, matmul_skip_count = 0, 0
find_conv2d = False
find_matmul = False
for i, op in enumerate(model_ops): for i, op in enumerate(model_ops):
if op.type == 'conv2d': if op.type == 'conv2d':
find_conv2d = True
if op.has_attr("skip_quant"): if op.has_attr("skip_quant"):
conv2d_skip_count += 1 conv2d_skip_count += 1
if conv2d_count > 0: if conv2d_count > 0:
...@@ -106,6 +109,7 @@ class TestImperativeOutSclae(unittest.TestCase): ...@@ -106,6 +109,7 @@ class TestImperativeOutSclae(unittest.TestCase):
conv2d_count += 1 conv2d_count += 1
if op.type == 'matmul': if op.type == 'matmul':
find_matmul = True
if op.has_attr("skip_quant"): if op.has_attr("skip_quant"):
matmul_skip_count += 1 matmul_skip_count += 1
if matmul_count > 0: if matmul_count > 0:
...@@ -116,8 +120,10 @@ class TestImperativeOutSclae(unittest.TestCase): ...@@ -116,8 +120,10 @@ class TestImperativeOutSclae(unittest.TestCase):
'fake_quantize_dequantize' not in model_ops[i - 1].type) 'fake_quantize_dequantize' not in model_ops[i - 1].type)
matmul_count += 1 matmul_count += 1
self.assertTrue(conv2d_skip_count == 1) if find_conv2d:
self.assertTrue(matmul_skip_count == 1) self.assertTrue(conv2d_skip_count == 1)
if find_matmul:
self.assertTrue(matmul_skip_count == 1)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -104,7 +104,8 @@ class TestDistTraning(unittest.TestCase): ...@@ -104,7 +104,8 @@ class TestDistTraning(unittest.TestCase):
ops = main_program.global_block().ops ops = main_program.global_block().ops
ops = [op.type for op in ops] ops = [op.type for op in ops]
self.assertEqual( self.assertEqual(
ops, ['c_identity', 'matmul', 'elementwise_add', 'c_concat']) ops,
['c_identity', 'matmul_v2', 'elementwise_add', 'c_concat'])
weight = model_a.parallel_linear.weight weight = model_a.parallel_linear.weight
bias = model_a.parallel_linear.bias bias = model_a.parallel_linear.bias
...@@ -127,7 +128,7 @@ class TestDistTraning(unittest.TestCase): ...@@ -127,7 +128,7 @@ class TestDistTraning(unittest.TestCase):
ops = [op.type for op in ops] ops = [op.type for op in ops]
self.assertEqual( self.assertEqual(
ops, ops,
['c_split', 'matmul', 'c_allreduce_sum', 'elementwise_add']) ['c_split', 'matmul_v2', 'c_allreduce_sum', 'elementwise_add'])
weight = model_a.parallel_linear.weight weight = model_a.parallel_linear.weight
bias = model_a.parallel_linear.bias bias = model_a.parallel_linear.bias
......
...@@ -74,7 +74,7 @@ class LinearTestCase(unittest.TestCase): ...@@ -74,7 +74,7 @@ class LinearTestCase(unittest.TestCase):
np.testing.assert_array_almost_equal(res_nn, res_np) np.testing.assert_array_almost_equal(res_nn, res_np)
def test_error_dummy_input(self, place=paddle.CPUPlace()): def test_error_dummy_input(self, place=paddle.CPUPlace()):
with self.assertRaises(ValueError): with self.assertRaises(RuntimeError):
x_arr = np.array([], dtype=np.float32) x_arr = np.array([], dtype=np.float32)
x = paddle.to_tensor( x = paddle.to_tensor(
np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32') np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32')
......
...@@ -664,7 +664,7 @@ class TestFusedMomentumWithDecayAPI(unittest.TestCase): ...@@ -664,7 +664,7 @@ class TestFusedMomentumWithDecayAPI(unittest.TestCase):
self.assertEqual(ops[-3].type, 'sum') self.assertEqual(ops[-3].type, 'sum')
self.assertEqual(ops[-4].type, 'scale') self.assertEqual(ops[-4].type, 'scale')
self.assertEqual(ops[-5].type, 'sign') self.assertEqual(ops[-5].type, 'sign')
self.assertEqual(ops[-6].type, 'matmul_grad') self.assertEqual(ops[-6].type, 'matmul_v2_grad')
if 'weight' in ops[-1].input('Param'): if 'weight' in ops[-1].input('Param'):
self.assertEqual(ops[-1].attr('regularization_method'), '') self.assertEqual(ops[-1].attr('regularization_method'), '')
self.assertEqual(ops[-1].attr('regularization_coeff'), 0) self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
......
...@@ -1467,9 +1467,8 @@ def linear(x, weight, bias=None, name=None): ...@@ -1467,9 +1467,8 @@ def linear(x, weight, bias=None, name=None):
# [2.1077576 2.1077576 2.1077576 2.1077576 ]] # [2.1077576 2.1077576 2.1077576 2.1077576 ]]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
pre_bias = _varbase_creator(dtype=x.dtype) pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y',
_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False, 'transpose_Y', False)
False, "alpha", 1)
if bias is None: if bias is None:
return pre_bias return pre_bias
...@@ -1484,14 +1483,10 @@ def linear(x, weight, bias=None, name=None): ...@@ -1484,14 +1483,10 @@ def linear(x, weight, bias=None, name=None):
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
inputs = {'X': [x], 'Y': [weight]} inputs = {'X': [x], 'Y': [weight]}
attrs = { attrs = {'trans_x': False, 'trans_y': False}
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
tmp = helper.create_variable_for_type_inference(dtype) tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='matmul', inputs=inputs, outputs={'Out': tmp}, attrs=attrs) type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
if bias is not None: if bias is not None:
res = helper.create_variable_for_type_inference(dtype) res = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册