提交 c838fa31 编写于 作者: M minqiyang

Port dist_transpiler to Python3.5

Resume prelu_op_test in python2
上级 90b5be85
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import six
from op_test import OpTest from op_test import OpTest
...@@ -62,17 +63,20 @@ class PReluTest(OpTest): ...@@ -62,17 +63,20 @@ class PReluTest(OpTest):
# TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues # TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues
# class TestCase1(PReluTest): if six.PY2:
# def initTestCase(self):
# self.attrs = {'mode': "all"}
# class TestCase2(PReluTest): class TestCase1(PReluTest):
# def initTestCase(self): def initTestCase(self):
# self.attrs = {'mode': "channel"} self.attrs = {'mode': "all"}
class TestCase2(PReluTest):
def initTestCase(self):
self.attrs = {'mode': "channel"}
class TestCase3(PReluTest):
def initTestCase(self):
self.attrs = {'mode': "element"}
# class TestCase3(PReluTest):
# def initTestCase(self):
# self.attrs = {'mode': "element"}
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -153,7 +153,7 @@ def block_to_code(block, block_idx): ...@@ -153,7 +153,7 @@ def block_to_code(block, block_idx):
indent += 1 indent += 1
# sort all vars # sort all vars
all_vars = sorted(block.vars.iteritems(), key=lambda x: x[0]) all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0])
for var in all_vars: for var in all_vars:
print("{}{}".format(get_indent_space(indent), variable_to_code(var[1]))) print("{}{}".format(get_indent_space(indent), variable_to_code(var[1])))
......
...@@ -300,7 +300,7 @@ class DistributeTranspiler(object): ...@@ -300,7 +300,7 @@ class DistributeTranspiler(object):
input_deps = grad_name_to_send_dummy_out.values() input_deps = grad_name_to_send_dummy_out.values()
program.global_block().append_op( program.global_block().append_op(
type="send_barrier", type="send_barrier",
inputs={"X": input_deps}, inputs={"X": list(input_deps)},
outputs={"Out": send_barrier_out}, outputs={"Out": send_barrier_out},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
...@@ -455,7 +455,7 @@ class DistributeTranspiler(object): ...@@ -455,7 +455,7 @@ class DistributeTranspiler(object):
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
# NOTE: if enable memory optimization, origin vars maybe removed. # NOTE: if enable memory optimization, origin vars maybe removed.
if startup_program.global_block().vars.has_key(varname): if varname in startup_program.global_block().vars:
orig_param = startup_program.global_block().vars[varname] orig_param = startup_program.global_block().vars[varname]
else: else:
origin_param_var = self.origin_program.global_block().vars[ origin_param_var = self.origin_program.global_block().vars[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册