diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index b1b8d1c586c98a327a8e5b4890ced00022155e6b..cadda49c399a6d65079cacedfea61f4fd580a69a 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -200,9 +200,11 @@ TEST(GraphTest, WriteAfterWrite) { ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); control_dep2 = n->inputs[1]; ASSERT_EQ(n->inputs.size(), 2); - ASSERT_EQ(control_dep1, control_dep2); } } + ASSERT_NE(control_dep1, nullptr); + ASSERT_NE(control_dep2, nullptr); + ASSERT_EQ(control_dep1, control_dep2); } } // namespace framework } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_desc_clone.py b/python/paddle/fluid/tests/unittests/test_desc_clone.py index 88d44e453c7976f5e0fbda2c0871dfabd4bb30aa..fa6b67956259f33b109758c5939ab5729482695a 100644 --- a/python/paddle/fluid/tests/unittests/test_desc_clone.py +++ b/python/paddle/fluid/tests/unittests/test_desc_clone.py @@ -27,6 +27,7 @@ import unittest from multiprocessing import Process import os import signal +import six import collections SEED = 1 @@ -55,7 +56,8 @@ def cnn_model(data): # TODO(dzhwinter) : refine the initializer and random seed settting SIZE = 10 input_shape = conv_pool_2.shape - param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE] + param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1) + ] + [SIZE] scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5 predict = fluid.layers.fc( @@ -108,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): def operator_equal(a, b): - for k, v in a.__dict__.iteritems(): + for k, v in six.iteritems(a.__dict__): if isinstance(v, fluid.framework.Program) or \ isinstance(v, fluid.framework.Block): continue @@ -118,8 +120,8 @@ def operator_equal(a, b): raise ValueError("In operator_equal not equal:{0}\n".format(k)) elif isinstance(v, collections.OrderedDict): - v0 = sorted(v.iteritems(), key=lambda x: x[0]) - v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) + v0 = sorted(six.iteritems(v), key=lambda x: x[0]) + v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0]) if v0 != v1: raise ValueError("In operator_equal not equal:{0}\n".format(k)) @@ -131,7 +133,7 @@ def operator_equal(a, b): def block_equal(a, b): - for k, v in a.__dict__.iteritems(): + for k, v in six.iteritems(a.__dict__): if isinstance(v, core.ProgramDesc) or isinstance( v, fluid.framework.Program) or isinstance(v, core.BlockDesc): continue @@ -143,8 +145,8 @@ def block_equal(a, b): assert (len(a.ops) == len(b.ops)) elif isinstance(v, collections.OrderedDict): - v0 = sorted(v.iteritems(), key=lambda x: x[0]) - v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) + v0 = sorted(six.iteritems(v), key=lambda x: x[0]) + v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0]) if v0 != v1: raise ValueError("In block_equal not equal:{0}\n".format(k)) @@ -156,7 +158,7 @@ def block_equal(a, b): def program_equal(a, b): - for k, v in a.__dict__.iteritems(): + for k, v in six.iteritems(a.__dict__): if isinstance(v, core.ProgramDesc): continue diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 112a6839a27aad61f708f1d51b68fde03e716134..64c2f32b9a16c9d1cb67bd84c59ba2bd1799fdc5 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -369,7 +369,7 @@ class DistributeTranspiler(object): # FIXME(gongwb): delete not need ops. # note that: some parameter is not trainable and those ops can't be deleted. - for varname, splited_var in self.param_var_mapping.iteritems(): + for varname, splited_var in six.iteritems(self.param_var_mapping): # Get the eplist of recv vars eps = [] for var in splited_var: @@ -406,7 +406,7 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - for varname, splited_var in self.param_var_mapping.iteritems(): + for varname, splited_var in six.iteritems(self.param_var_mapping): #add concat ops to merge splited parameters received from parameter servers. if len(splited_var) <= 1: continue