提交 889d7b6e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1180 [pre_activate]Fix pylint warning

Merge pull request !1180 from YuJianfeng/master
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import Primitive from mindspore.ops import Primitive
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import Primitive from mindspore.ops import Primitive
......
...@@ -61,8 +61,8 @@ def test_tbe_eltwise_fusion_1(tag): ...@@ -61,8 +61,8 @@ def test_tbe_eltwise_fusion_1(tag):
def after(x): def after(x):
fusion = Fusion_relu_relu(x) fusion = Fusion_relu_relu(x)
res = Cast(fusion) res = Cast(fusion)
tuple = make_tuple(res) output = make_tuple(res)
return tuple return output
return fns[tag] return fns[tag]
...@@ -86,8 +86,8 @@ def test_tbe_eltwise_fusion_2(tag): ...@@ -86,8 +86,8 @@ def test_tbe_eltwise_fusion_2(tag):
def after(x, y): def after(x, y):
fusion = Fusion_biasadd(x, y) fusion = Fusion_biasadd(x, y)
res = Cast(fusion) res = Cast(fusion)
tuple = make_tuple(res) output = make_tuple(res)
return tuple return output
return fns[tag] return fns[tag]
...@@ -111,8 +111,8 @@ def test_tbe_reduce_eltwise_fusion(tag): ...@@ -111,8 +111,8 @@ def test_tbe_reduce_eltwise_fusion(tag):
def after(x): def after(x):
fusion = Fusion_biasaddgrad(x) fusion = Fusion_biasaddgrad(x)
res = Cast(fusion) res = Cast(fusion)
tuple = make_tuple(res) output = make_tuple(res)
return tuple return output
return fns[tag] return fns[tag]
...@@ -131,8 +131,8 @@ def test_conv_singlein_fusion(tag): ...@@ -131,8 +131,8 @@ def test_conv_singlein_fusion(tag):
def after(x, y): def after(x, y):
fusion = Fusion(x, y) fusion = Fusion(x, y)
res = Cast(fusion) res = Cast(fusion)
tuple = make_tuple(res) output = make_tuple(res)
return tuple return output
return fns[tag] return fns[tag]
...@@ -151,7 +151,7 @@ def test_tbe_matmul_eltwise_fusion(tag): ...@@ -151,7 +151,7 @@ def test_tbe_matmul_eltwise_fusion(tag):
def after(x, y): def after(x, y):
fusion = Fusion_matmul_relu(x, y) fusion = Fusion_matmul_relu(x, y)
res = Cast(fusion) res = Cast(fusion)
tuple = make_tuple(res) output = make_tuple(res)
return tuple return output
return fns[tag] return fns[tag]
...@@ -40,17 +40,17 @@ def test_clip_by_norm_no_div_square_sum_fusion(tag): ...@@ -40,17 +40,17 @@ def test_clip_by_norm_no_div_square_sum_fusion(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(input, constant_select, constant_greater, constant_maximum): def before(x, constant_select, constant_greater, constant_maximum):
greater_output = greater(input, constant_greater) greater_output = greater(x, constant_greater)
res = select(greater_output, input, constant_select) res = select(greater_output, x, constant_select)
res = sqrt(res) res = sqrt(res)
res = select(greater_output, res, input) res = select(greater_output, res, x)
res = maximum(res, constant_maximum) res = maximum(res, constant_maximum)
return res return res
@fns @fns
def after(input, constant_select, constant_greater, constant_maximum): def after(x, constant_select, constant_greater, constant_maximum):
res = clip_by_norm_no_div_square_sum(input, constant_select, constant_greater, constant_maximum) res = clip_by_norm_no_div_square_sum(x, constant_select, constant_greater, constant_maximum)
return make_tuple(res) return make_tuple(res)
return fns[tag] return fns[tag]
...@@ -38,6 +38,7 @@ depth = Tensor(2, mstype.int32) ...@@ -38,6 +38,7 @@ depth = Tensor(2, mstype.int32)
shape = (2, 4, 2, 2) shape = (2, 4, 2, 2)
dropout_gen_mask = P.DropoutGenMask() dropout_gen_mask = P.DropoutGenMask()
class FnDict: class FnDict:
def __init__(self): def __init__(self):
self.fnDict = {} self.fnDict = {}
...@@ -114,7 +115,7 @@ def test_convert_strided_slice_grad_input_to_attr(tag): ...@@ -114,7 +115,7 @@ def test_convert_strided_slice_grad_input_to_attr(tag):
@fns @fns
def before(x): def before(x):
return stridedslicegrad(x, (16, 128, 1024), (0, 0 , 0), (16, 1, 1024), (1, 1,1)) return stridedslicegrad(x, (16, 128, 1024), (0, 0, 0), (16, 1, 1024), (1, 1, 1))
@fns @fns
def after(x): def after(x):
......
...@@ -12,12 +12,10 @@ ...@@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import Primitive from mindspore.ops import Primitive
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import numpy as np
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
concat = P.Concat() concat = P.Concat()
...@@ -51,7 +49,7 @@ def test_convert_tuple_input_to_dynamic_input(tag): ...@@ -51,7 +49,7 @@ def test_convert_tuple_input_to_dynamic_input(tag):
def after(x): def after(x):
res = concat(t1, t2) res = concat(t1, t2)
res = add(x, res) res = add(x, res)
res = make_tuple(res); res = make_tuple(res)
return res return res
return fns[tag] return fns[tag]
...@@ -14,16 +14,13 @@ ...@@ -14,16 +14,13 @@
# ============================================================================ # ============================================================================
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import Primitive from mindspore.ops import Primitive
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
import numpy as np
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
tuple_get_item = Primitive("tuple_getitem"); tuple_get_item = Primitive("tuple_getitem")
LSTM = P.LSTM(input_size=10,hidden_size=2,num_layers=1,has_bias=True,bidirectional=False,dropout=0.0) LSTM = P.LSTM(input_size=10, hidden_size=2, num_layers=1, has_bias=True, bidirectional=False, dropout=0.0)
add = P.TensorAdd() add = P.TensorAdd()
class FnDict: class FnDict:
def __init__(self): def __init__(self):
self.fnDict = {} self.fnDict = {}
...@@ -48,7 +45,7 @@ def test_convert_tuple_output_to_maketuple(tag): ...@@ -48,7 +45,7 @@ def test_convert_tuple_output_to_maketuple(tag):
res = LSTM(x, h, c, w) res = LSTM(x, h, c, w)
res = make_tuple( res = make_tuple(
make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1), tuple_get_item(res, 2), tuple_get_item(res, 3), make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1), tuple_get_item(res, 2), tuple_get_item(res, 3),
tuple_get_item(res, 4))); tuple_get_item(res, 4)))
return res return res
return fns[tag] return fns[tag]
...@@ -40,8 +40,8 @@ def test_eliminate_5to4_4to5(tag): ...@@ -40,8 +40,8 @@ def test_eliminate_5to4_4to5(tag):
@fns @fns
def before(x, y): def before(x, y):
sum = add(x, y) sum_add = add(x, y)
res = sub(sum, y) res = sub(sum_add, y)
output = make_tuple(res) output = make_tuple(res)
return output return output
...@@ -50,8 +50,8 @@ def test_eliminate_5to4_4to5(tag): ...@@ -50,8 +50,8 @@ def test_eliminate_5to4_4to5(tag):
new_x_sum = transdata(x) new_x_sum = transdata(x)
new_y_sum = transdata(y) new_y_sum = transdata(y)
new_y_sum2 = transdata(y) new_y_sum2 = transdata(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
sum_5to4 = transdata(sum) sum_5to4 = transdata(sum_add)
sum_4to5 = transdata(sum_5to4) sum_4to5 = transdata(sum_5to4)
res = sub(sum_4to5, new_y_sum2) res = sub(sum_4to5, new_y_sum2)
output = transdata(res) output = transdata(res)
...@@ -64,8 +64,8 @@ def test_eliminate_5to4_4to5(tag): ...@@ -64,8 +64,8 @@ def test_eliminate_5to4_4to5(tag):
new_x_sum = transdata(x) new_x_sum = transdata(x)
new_y_sum = transdata(y) new_y_sum = transdata(y)
new_y_diff = transdata(y) new_y_diff = transdata(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
res = sub(sum, new_y_diff) res = sub(sum_add, new_y_diff)
output = transdata(res) output = transdata(res)
new_output = make_tuple(output) new_output = make_tuple(output)
ret = make_tuple(new_output) ret = make_tuple(new_output)
...@@ -79,8 +79,8 @@ def test_eliminate_cast(tag): ...@@ -79,8 +79,8 @@ def test_eliminate_cast(tag):
@fns @fns
def before(x, y): def before(x, y):
sum = add(x, y) sum_add = add(x, y)
res = sub(sum, y) res = sub(sum_add, y)
output = make_tuple(res) output = make_tuple(res)
return output return output
...@@ -89,8 +89,8 @@ def test_eliminate_cast(tag): ...@@ -89,8 +89,8 @@ def test_eliminate_cast(tag):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
new_y_sum2 = cast(y) new_y_sum2 = cast(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
sum_cast1 = cast(sum) sum_cast1 = cast(sum_add)
sum_cast2 = cast(sum_cast1) sum_cast2 = cast(sum_cast1)
res = sub(sum_cast2, new_y_sum2) res = sub(sum_cast2, new_y_sum2)
output = cast(res) output = cast(res)
...@@ -103,8 +103,8 @@ def test_eliminate_cast(tag): ...@@ -103,8 +103,8 @@ def test_eliminate_cast(tag):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
new_y_diff = cast(y) new_y_diff = cast(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
res = sub(sum, new_y_diff) res = sub(sum_add, new_y_diff)
output = cast(res) output = cast(res)
new_output = make_tuple(output) new_output = make_tuple(output)
ret = make_tuple(new_output) ret = make_tuple(new_output)
...@@ -118,8 +118,8 @@ def test_eliminate_5to4_depend_4to5(tag): ...@@ -118,8 +118,8 @@ def test_eliminate_5to4_depend_4to5(tag):
@fns @fns
def before(x, y): def before(x, y):
sum = add(x, y) sum_add = add(x, y)
sum_depend = depend(sum, x) sum_depend = depend(sum_add, x)
res = sub(sum_depend, y) res = sub(sum_depend, y)
output = make_tuple(res) output = make_tuple(res)
return output return output
...@@ -128,8 +128,8 @@ def test_eliminate_5to4_depend_4to5(tag): ...@@ -128,8 +128,8 @@ def test_eliminate_5to4_depend_4to5(tag):
def after1(x, y): def after1(x, y):
new_x_sum = transdata(x) new_x_sum = transdata(x)
new_y_sum = transdata(y) new_y_sum = transdata(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
sum_trans = transdata(sum) sum_trans = transdata(sum_add)
depend_between_trans = depend(sum_trans, x) depend_between_trans = depend(sum_trans, x)
depend_trans = transdata(depend_between_trans) depend_trans = transdata(depend_between_trans)
new_y_diff = transdata(y) new_y_diff = transdata(y)
...@@ -143,8 +143,8 @@ def test_eliminate_5to4_depend_4to5(tag): ...@@ -143,8 +143,8 @@ def test_eliminate_5to4_depend_4to5(tag):
def after2(x, y): def after2(x, y):
new_x_sum = transdata(x) new_x_sum = transdata(x)
new_y_sum = transdata(y) new_y_sum = transdata(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
depend_op = depend(sum, x) depend_op = depend(sum_add, x)
new_y_diff = transdata(y) new_y_diff = transdata(y)
res = sub(depend_op, new_y_diff) res = sub(depend_op, new_y_diff)
output = transdata(res) output = transdata(res)
...@@ -160,8 +160,8 @@ def test_eliminate_cast_depend_cast(tag): ...@@ -160,8 +160,8 @@ def test_eliminate_cast_depend_cast(tag):
@fns @fns
def before(x, y): def before(x, y):
sum = add(x, y) sum_add = add(x, y)
sum_depend = depend(sum, x) sum_depend = depend(sum_add, x)
sum_depend2 = depend(sum_depend, x) sum_depend2 = depend(sum_depend, x)
sum_depend3 = depend(sum_depend2, x) sum_depend3 = depend(sum_depend2, x)
res = sub(sum_depend3, y) res = sub(sum_depend3, y)
...@@ -172,8 +172,8 @@ def test_eliminate_cast_depend_cast(tag): ...@@ -172,8 +172,8 @@ def test_eliminate_cast_depend_cast(tag):
def after1(x, y): def after1(x, y):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
sum_cast = cast(sum) sum_cast = cast(sum_add)
depend_between_cast = depend(sum_cast, x) depend_between_cast = depend(sum_cast, x)
depend_between_cast2 = depend(depend_between_cast, x) depend_between_cast2 = depend(depend_between_cast, x)
depend_between_cast3 = depend(depend_between_cast2, x) depend_between_cast3 = depend(depend_between_cast2, x)
...@@ -189,8 +189,8 @@ def test_eliminate_cast_depend_cast(tag): ...@@ -189,8 +189,8 @@ def test_eliminate_cast_depend_cast(tag):
def after2(x, y): def after2(x, y):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
depend_op = depend(sum, x) depend_op = depend(sum_add, x)
depend_op2 = depend(depend_op, x) depend_op2 = depend(depend_op, x)
depend_op3 = depend(depend_op2, x) depend_op3 = depend(depend_op2, x)
new_y_diff = cast(y) new_y_diff = cast(y)
...@@ -201,4 +201,3 @@ def test_eliminate_cast_depend_cast(tag): ...@@ -201,4 +201,3 @@ def test_eliminate_cast_depend_cast(tag):
return ret return ret
return fns[tag] return fns[tag]
...@@ -40,14 +40,14 @@ def test_getnext_memcpy_elimination(tag): ...@@ -40,14 +40,14 @@ def test_getnext_memcpy_elimination(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(x): def before():
res = get_next() res = get_next()
res = memcpy_async_attr(res) res = memcpy_async_attr(res)
res = cast(res) res = cast(res)
return res return res
@fns @fns
def after(x): def after():
res = get_next() res = get_next()
res = cast(res) res = cast(res)
return res return res
...@@ -59,14 +59,14 @@ def test_getnext_memcpy_elimination_no_attr(tag): ...@@ -59,14 +59,14 @@ def test_getnext_memcpy_elimination_no_attr(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(x): def before():
res = get_next() res = get_next()
res = memcpy_async(res) res = memcpy_async(res)
res = cast(res) res = cast(res)
return res return res
@fns @fns
def after(x): def after():
res = get_next() res = get_next()
res = memcpy_async(res) res = memcpy_async(res)
res = cast(res) res = cast(res)
...@@ -79,7 +79,7 @@ def test_getnext_memcpy_elimination_memcpy_multi_users(tag): ...@@ -79,7 +79,7 @@ def test_getnext_memcpy_elimination_memcpy_multi_users(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(x): def before():
res = get_next() res = get_next()
memcpy_out = memcpy_async_attr(res) memcpy_out = memcpy_async_attr(res)
res = cast(memcpy_out) res = cast(memcpy_out)
...@@ -87,7 +87,7 @@ def test_getnext_memcpy_elimination_memcpy_multi_users(tag): ...@@ -87,7 +87,7 @@ def test_getnext_memcpy_elimination_memcpy_multi_users(tag):
return res return res
@fns @fns
def after(x): def after():
res = get_next() res = get_next()
memcpy_out = memcpy_async_attr(res) memcpy_out = memcpy_async_attr(res)
res = cast(memcpy_out) res = cast(memcpy_out)
...@@ -101,14 +101,14 @@ def test_getnext_memcpy_elimination_next_multi_inputs(tag): ...@@ -101,14 +101,14 @@ def test_getnext_memcpy_elimination_next_multi_inputs(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(x): def before():
res = get_next() res = get_next()
memcpy_out = memcpy_async_attr(res) memcpy_out = memcpy_async_attr(res)
res = add(memcpy_out, res) res = add(memcpy_out, res)
return res return res
@fns @fns
def after(x): def after():
res = get_next() res = get_next()
memcpy_out = memcpy_async_attr(res) memcpy_out = memcpy_async_attr(res)
res = add(memcpy_out, res) res = add(memcpy_out, res)
......
...@@ -127,14 +127,14 @@ def test_eliminate_depend_input2(tag): ...@@ -127,14 +127,14 @@ def test_eliminate_depend_input2(tag):
def before(x, y, z): def before(x, y, z):
new_z = four2five(z) new_z = four2five(z)
depend_intput = depend(y, new_z) depend_intput = depend(y, new_z)
sum = add(x, depend_intput) sum_add = add(x, depend_intput)
return sum return sum_add
@fns @fns
def after(x, y, z): def after(x, y, z):
depend_intput = depend(y, z) depend_intput = depend(y, z)
sum = add(x, depend_intput) sum_add = add(x, depend_intput)
return sum return sum_add
return fns[tag] return fns[tag]
...@@ -144,8 +144,8 @@ def test_opt_match(tag): ...@@ -144,8 +144,8 @@ def test_opt_match(tag):
@fns @fns
def graph1(x, y): def graph1(x, y):
sum = add(x, y) sum_add = add(x, y)
output = make_tuple(sum) output = make_tuple(sum_add)
return output return output
@fns @fns
...@@ -178,4 +178,3 @@ def test_func_graph_cse(tag): ...@@ -178,4 +178,3 @@ def test_func_graph_cse(tag):
return d return d
return fns[tag] return fns[tag]
...@@ -66,17 +66,17 @@ def test_lamb_next_mv_rule(tag): ...@@ -66,17 +66,17 @@ def test_lamb_next_mv_rule(tag):
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6, lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub, constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_mul4_x, constant_add2_y)
constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1), outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3)) tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
output = tuple_getitem(outputs, 0) output = tuple_getitem(outputs, 0)
return make_tuple(output) return make_tuple(output)
@fns @fns
def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y):
mul0 = Mul(constant_mul0_x, input4) mul0 = Mul(constant_mul0_x, input4)
mul1 = Mul(constant_mul1_sub, input3) mul1 = Mul(constant_mul1_sub, input3)
add0 = Add(mul0, mul1) add0 = Add(mul0, mul1)
...@@ -98,8 +98,9 @@ def test_lamb_next_mv_rule(tag): ...@@ -98,8 +98,9 @@ def test_lamb_next_mv_rule(tag):
return output return output
@fns @fns
def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y):
mul0 = Mul(constant_mul0_x, input4) mul0 = Mul(constant_mul0_x, input4)
mul1 = Mul(constant_mul1_sub, input3) mul1 = Mul(constant_mul1_sub, input3)
add0 = Add(mul0, mul1) add0 = Add(mul0, mul1)
...@@ -121,8 +122,9 @@ def test_lamb_next_mv_rule(tag): ...@@ -121,8 +122,9 @@ def test_lamb_next_mv_rule(tag):
return output return output
@fns @fns
def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y):
mul0 = Mul(constant_mul0_x, input4) mul0 = Mul(constant_mul0_x, input4)
mul1 = Mul(constant_mul1_sub, input3) mul1 = Mul(constant_mul1_sub, input3)
add0 = Add(mul0, mul1) add0 = Add(mul0, mul1)
...@@ -144,8 +146,9 @@ def test_lamb_next_mv_rule(tag): ...@@ -144,8 +146,9 @@ def test_lamb_next_mv_rule(tag):
return output return output
@fns @fns
def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y):
mul0 = Mul(constant_mul0_x, input4) mul0 = Mul(constant_mul0_x, input4)
mul1 = Mul(constant_mul1_sub, input3) mul1 = Mul(constant_mul1_sub, input3)
add0 = Add(mul0, mul1) add0 = Add(mul0, mul1)
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import Primitive from mindspore.ops import Primitive
......
...@@ -82,8 +82,8 @@ def test_eliminate_cast_op(tag): ...@@ -82,8 +82,8 @@ def test_eliminate_cast_op(tag):
@fns @fns
def before(x, y): def before(x, y):
sum = addn((x, y)) sum_add = addn((x, y))
sum_depend = depend(sum, addn((x, y))) sum_depend = depend(sum_add, addn((x, y)))
diff = sub(x, y) diff = sub(x, y)
res = mul(sum_depend, diff) res = mul(sum_depend, diff)
return res return res
...@@ -92,8 +92,8 @@ def test_eliminate_cast_op(tag): ...@@ -92,8 +92,8 @@ def test_eliminate_cast_op(tag):
def after1(x, y): def after1(x, y):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
sum = addn(new_x_sum, new_y_sum) sum_add = addn(new_x_sum, new_y_sum)
sum_cast = cast(sum) sum_cast = cast(sum_add)
new_x_depend = cast(x) new_x_depend = cast(x)
new_y_depend = cast(y) new_y_depend = cast(y)
sum_depend = addn(new_x_depend, new_y_depend) sum_depend = addn(new_x_depend, new_y_depend)
...@@ -114,12 +114,12 @@ def test_eliminate_cast_op(tag): ...@@ -114,12 +114,12 @@ def test_eliminate_cast_op(tag):
def after2(x, y): def after2(x, y):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
sum = addn(new_x_sum, new_y_sum) sum_add = addn(new_x_sum, new_y_sum)
new_x_depend = cast(x) new_x_depend = cast(x)
new_y_depend = cast(y) new_y_depend = cast(y)
sum_depend = addn(new_x_depend, new_y_depend) sum_depend = addn(new_x_depend, new_y_depend)
sum_depend_cast = cast(sum_depend) sum_depend_cast = cast(sum_depend)
depend_between_cast = depend(sum, sum_depend_cast) depend_between_cast = depend(sum_add, sum_depend_cast)
new_x_diff = cast(x) new_x_diff = cast(x)
new_y_diff = cast(y) new_y_diff = cast(y)
diff = sub(new_x_diff, new_y_diff) diff = sub(new_x_diff, new_y_diff)
...@@ -156,8 +156,8 @@ def test_eliminate_cast_new(tag): ...@@ -156,8 +156,8 @@ def test_eliminate_cast_new(tag):
@fns @fns
def before(x, y): def before(x, y):
sum = add(x, y) sum_add = add(x, y)
res = sub(sum, y) res = sub(sum_add, y)
output = make_tuple(res) output = make_tuple(res)
return output return output
...@@ -166,8 +166,8 @@ def test_eliminate_cast_new(tag): ...@@ -166,8 +166,8 @@ def test_eliminate_cast_new(tag):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
new_y_sum2 = cast(y) new_y_sum2 = cast(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
sum_5to4 = cast(sum) sum_5to4 = cast(sum_add)
sum_4to5 = cast(sum_5to4) sum_4to5 = cast(sum_5to4)
res = sub(sum_4to5, new_y_sum2) res = sub(sum_4to5, new_y_sum2)
output = cast(res) output = cast(res)
...@@ -179,11 +179,10 @@ def test_eliminate_cast_new(tag): ...@@ -179,11 +179,10 @@ def test_eliminate_cast_new(tag):
new_x_sum = cast(x) new_x_sum = cast(x)
new_y_sum = cast(y) new_y_sum = cast(y)
new_y_diff = cast(y) new_y_diff = cast(y)
sum = add(new_x_sum, new_y_sum) sum_add = add(new_x_sum, new_y_sum)
res = sub(sum, new_y_diff) res = sub(sum_add, new_y_diff)
output = cast(res) output = cast(res)
new_output = make_tuple(output) new_output = make_tuple(output)
return new_output return new_output
return fns[tag] return fns[tag]
...@@ -39,14 +39,14 @@ def test_optimize_dependence(tag): ...@@ -39,14 +39,14 @@ def test_optimize_dependence(tag):
def before(x, y, z): def before(x, y, z):
new_z = TransData(z) new_z = TransData(z)
depend_intput = depend(y, new_z) depend_intput = depend(y, new_z)
sum = add(x, depend_intput) sum_add = add(x, depend_intput)
return sum return sum_add
@fns @fns
def after(x, y, z): def after(x, y, z):
depend_intput = depend(y, z) depend_intput = depend(y, z)
sum = add(x, depend_intput) sum_add = add(x, depend_intput)
return sum return sum_add
return fns[tag] return fns[tag]
...@@ -58,14 +58,14 @@ def test_optimize_dependence_with_make_tuple(tag): ...@@ -58,14 +58,14 @@ def test_optimize_dependence_with_make_tuple(tag):
def before(x, y, a, b): def before(x, y, a, b):
z = make_tuple(TransData(a), TransData(b)) z = make_tuple(TransData(a), TransData(b))
depend_intput = depend(y, z) depend_intput = depend(y, z)
sum = add(x, depend_intput) sum_add = add(x, depend_intput)
return sum return sum_add
@fns @fns
def after(x, y, a, b): def after(x, y, a, b):
z = make_tuple(a, b) z = make_tuple(a, b)
depend_intput = depend(y, z) depend_intput = depend(y, z)
sum = add(x, depend_intput) sum_add = add(x, depend_intput)
return sum return sum_add
return fns[tag] return fns[tag]
...@@ -34,8 +34,8 @@ def test_topk_split(tag): ...@@ -34,8 +34,8 @@ def test_topk_split(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(input): def before(x):
topk = TopK(input, 2) topk = TopK(x, 2)
output = tuple_getitem(topk, 0) output = tuple_getitem(topk, 0)
return output return output
......
...@@ -25,6 +25,7 @@ transdata = Primitive("TransData") ...@@ -25,6 +25,7 @@ transdata = Primitive("TransData")
transpose = Primitive("Transpose") transpose = Primitive("Transpose")
Transpose = P.Transpose() Transpose = P.Transpose()
class FnDict: class FnDict:
def __init__(self): def __init__(self):
self.fnDict = {} self.fnDict = {}
...@@ -40,30 +41,32 @@ def test_transdata_split_fraz_nchw(tag): ...@@ -40,30 +41,32 @@ def test_transdata_split_fraz_nchw(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(input): def before(x):
res = Transpose(input, (1, 0, 2, 3)) res = Transpose(x, (1, 0, 2, 3))
return res return res
@fns @fns
def after(input): def after(x):
res = transpose(input) res = transpose(x)
output = transdata(res) output = transdata(res)
output = transpose(output) output = transpose(output)
res = make_tuple(output) res = make_tuple(output)
return res return res
return fns[tag] return fns[tag]
def test_transdata_split_nchw_fraz(tag): def test_transdata_split_nchw_fraz(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(input): def before(x):
res = Transpose(input, (1, 0, 2, 3)) res = Transpose(x, (1, 0, 2, 3))
return res return res
@fns @fns
def after(input): def after(x):
res = transpose(input) res = transpose(x)
output = transdata(res) output = transdata(res)
output = transpose(output) output = transpose(output)
res = make_tuple(output) res = make_tuple(output)
......
...@@ -35,14 +35,14 @@ def test_transpose_reshape_fusion(tag): ...@@ -35,14 +35,14 @@ def test_transpose_reshape_fusion(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(input): def before(x):
transpose = Transpose(input, (1, 0, 2, 3)) transpose = Transpose(x, (1, 0, 2, 3))
reshape = Reshape(transpose, (2, 4, 8, 16)) reshape = Reshape(transpose, (2, 4, 8, 16))
return reshape return reshape
@fns @fns
def after(input): def after(x):
confusion = ConfusionTransposeD(input) confusion = ConfusionTransposeD(x)
res = make_tuple(confusion) res = make_tuple(confusion)
return res return res
......
...@@ -37,13 +37,13 @@ def test_transpose_transdata_fusion(tag): ...@@ -37,13 +37,13 @@ def test_transpose_transdata_fusion(tag):
fns = FnDict() fns = FnDict()
@fns @fns
def before(input): def before(x):
res = Transpose(input, (1, 0, 2, 3)) res = Transpose(x, (1, 0, 2, 3))
return res return res
@fns @fns
def after(input): def after(x):
output = transdata(input) output = transdata(x)
res = make_tuple(output) res = make_tuple(output)
return res return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册