未验证 提交 72965226 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #12818 from velconia/fix_python3_CI_job

Fix python3 CI job
......@@ -200,9 +200,11 @@ TEST(GraphTest, WriteAfterWrite) {
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2);
ASSERT_EQ(control_dep1, control_dep2);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
}
} // namespace framework
} // namespace paddle
......@@ -27,6 +27,7 @@ import unittest
from multiprocessing import Process
import os
import signal
import six
import collections
SEED = 1
......@@ -55,7 +56,8 @@ def cnn_model(data):
# TODO(dzhwinter) : refine the initializer and random seed settting
SIZE = 10
input_shape = conv_pool_2.shape
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1)
] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
predict = fluid.layers.fc(
......@@ -108,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
def operator_equal(a, b):
for k, v in a.__dict__.iteritems():
for k, v in six.iteritems(a.__dict__):
if isinstance(v, fluid.framework.Program) or \
isinstance(v, fluid.framework.Block):
continue
......@@ -118,8 +120,8 @@ def operator_equal(a, b):
raise ValueError("In operator_equal not equal:{0}\n".format(k))
elif isinstance(v, collections.OrderedDict):
v0 = sorted(v.iteritems(), key=lambda x: x[0])
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0])
v0 = sorted(six.iteritems(v), key=lambda x: x[0])
v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
if v0 != v1:
raise ValueError("In operator_equal not equal:{0}\n".format(k))
......@@ -131,7 +133,7 @@ def operator_equal(a, b):
def block_equal(a, b):
for k, v in a.__dict__.iteritems():
for k, v in six.iteritems(a.__dict__):
if isinstance(v, core.ProgramDesc) or isinstance(
v, fluid.framework.Program) or isinstance(v, core.BlockDesc):
continue
......@@ -143,8 +145,8 @@ def block_equal(a, b):
assert (len(a.ops) == len(b.ops))
elif isinstance(v, collections.OrderedDict):
v0 = sorted(v.iteritems(), key=lambda x: x[0])
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0])
v0 = sorted(six.iteritems(v), key=lambda x: x[0])
v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
if v0 != v1:
raise ValueError("In block_equal not equal:{0}\n".format(k))
......@@ -156,7 +158,7 @@ def block_equal(a, b):
def program_equal(a, b):
for k, v in a.__dict__.iteritems():
for k, v in six.iteritems(a.__dict__):
if isinstance(v, core.ProgramDesc):
continue
......
......@@ -369,7 +369,7 @@ class DistributeTranspiler(object):
# FIXME(gongwb): delete not need ops.
# note that: some parameter is not trainable and those ops can't be deleted.
for varname, splited_var in self.param_var_mapping.iteritems():
for varname, splited_var in six.iteritems(self.param_var_mapping):
# Get the eplist of recv vars
eps = []
for var in splited_var:
......@@ -406,7 +406,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in self.param_var_mapping.iteritems():
for varname, splited_var in six.iteritems(self.param_var_mapping):
#add concat ops to merge splited parameters received from parameter servers.
if len(splited_var) <= 1:
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册