未验证 提交 76b63c25 编写于 作者: Y Yancey 提交者: GitHub

move transpiler files into transpiler folder (#10415)

上级 55e714e0
...@@ -40,16 +40,14 @@ import backward ...@@ -40,16 +40,14 @@ import backward
import regularizer import regularizer
import average import average
import metrics import metrics
import transpiler
from param_attr import ParamAttr, WeightNormParamAttr from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace
from distribute_transpiler import DistributeTranspiler from transpiler import DistributeTranspiler, SimpleDistributeTranspiler, InferenceTranspiler, memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
from concurrency import (Go, make_channel, channel_send, channel_recv, from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close, Select) channel_close, Select)
from inference_transpiler import InferenceTranspiler
import clip import clip
from memory_optimization_transpiler import memory_optimize, release_memory
import profiler import profiler
import unique_name import unique_name
import recordio_writer import recordio_writer
...@@ -58,7 +56,7 @@ from parallel_executor import ParallelExecutor ...@@ -58,7 +56,7 @@ from parallel_executor import ParallelExecutor
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
trainer.__all__ + inferencer.__all__ + [ trainer.__all__ + inferencer.__all__ + transpiler.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
...@@ -76,11 +74,6 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\ ...@@ -76,11 +74,6 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
'WeightNormParamAttr', 'WeightNormParamAttr',
'DataFeeder', 'DataFeeder',
'clip', 'clip',
'SimpleDistributeTranspiler',
'DistributeTranspiler',
'InferenceTranspiler',
'memory_optimize',
'release_memory',
'profiler', 'profiler',
'unique_name', 'unique_name',
'recordio_writer', 'recordio_writer',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Package facade: re-export the transpiler entry points from their
# implementation modules so callers can import them from one place.
from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
# Names exported by ``from ... import *``; the parent package also splices
# this list into its own ``__all__``.
__all__ = [
    "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
    "memory_optimize", "release_memory"
]
...@@ -17,9 +17,8 @@ from __future__ import print_function ...@@ -17,9 +17,8 @@ from __future__ import print_function
import math import math
import distributed_splitter as splitter import distributed_splitter as splitter
import framework from .. import core
from framework import Program, default_main_program, Variable, Parameter from ..framework import Program, default_main_program, Variable, Parameter
from . import core
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
...@@ -135,6 +134,16 @@ def split_dense_variable(var_list, ...@@ -135,6 +134,16 @@ def split_dense_variable(var_list,
return blocks return blocks
def delete_ops(block, ops):
    """Remove a consecutive run of operators from ``block`` in place.

    Args:
        block: object exposing ``ops`` (sequence), ``remove_op(idx)`` and
            ``program.sync_with_cpp()`` — presumably a framework Block;
            TODO(review) confirm against callers.
        ops: operators to delete. They must already be members of
            ``block.ops`` and contiguous there; the inclusive span from
            ``ops[0]`` through ``ops[-1]`` is removed.

    Raises:
        ValueError: if ``ops[0]`` or ``ops[-1]`` is not found in ``block``.
            (The original code caught and immediately re-raised the same
            exception; that wrapper added nothing and is dropped.)
    """
    if not ops:
        # Nothing to delete; skip the index lookups and the C++ re-sync.
        return
    op_list = list(block.ops)  # materialize once, not once per index()
    start = op_list.index(ops[0])
    end = op_list.index(ops[-1])
    # Each removal at ``start`` shifts the remaining ops left, so removing
    # ``end - start + 1`` times deletes the whole [start, end] span.
    for _ in range(end - start + 1):
        block.remove_op(start)
    block.program.sync_with_cpp()
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def transpile(self,
trainer_id, trainer_id,
...@@ -317,7 +326,7 @@ class DistributeTranspiler: ...@@ -317,7 +326,7 @@ class DistributeTranspiler:
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone. # FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.origin_program.__str__() self.origin_program.__str__()
return self.origin_program return self.origin_program
...@@ -601,7 +610,7 @@ class DistributeTranspiler: ...@@ -601,7 +610,7 @@ class DistributeTranspiler:
attrs={"axis": 0}) attrs={"axis": 0})
# delete lookup_table_op # delete lookup_table_op
self.delete_ops(program.global_block(), [op]) delete_ops(program.global_block(), [op])
# break for loop # break for loop
break break
...@@ -1164,12 +1173,3 @@ class DistributeTranspiler: ...@@ -1164,12 +1173,3 @@ class DistributeTranspiler:
in_name.startswith("beta2_pow_acc"): in_name.startswith("beta2_pow_acc"):
return True return True
return False return False
def delete_ops(self, block, ops):
    """Remove a consecutive run of operators from ``block`` in place.

    ``self`` is unused; kept only so existing ``self.delete_ops(...)``
    call sites continue to work.

    Args:
        block: object exposing ``ops`` (sequence), ``remove_op(idx)`` and
            ``program.sync_with_cpp()`` — presumably a framework Block;
            TODO(review) confirm against callers.
        ops: operators to delete. They must already be members of
            ``block.ops`` and contiguous there; the inclusive span from
            ``ops[0]`` through ``ops[-1]`` is removed.

    Raises:
        ValueError: if ``ops[0]`` or ``ops[-1]`` is not found in ``block``.
            (The original ``except ...: raise e`` wrapper re-raised the
            same exception unchanged and is dropped.)
    """
    if not ops:
        # Nothing to delete; skip the index lookups and the C++ re-sync.
        return
    op_list = list(block.ops)  # materialize once, not once per index()
    start = op_list.index(ops[0])
    end = op_list.index(ops[-1])
    # Each removal at ``start`` shifts the remaining ops left, so removing
    # ``end - start + 1`` times deletes the whole [start, end] span.
    for _ in range(end - start + 1):
        block.remove_op(start)
    block.program.sync_with_cpp()
...@@ -12,10 +12,8 @@ ...@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import framework from ..framework import Program, default_main_program, Parameter, Variable
from framework import Program, default_main_program, Parameter, Variable from ..layer_helper import LayerHelper
import optimizer
from layer_helper import LayerHelper
def hash_name_to_server(params_grads, pserver_endpoints): def hash_name_to_server(params_grads, pserver_endpoints):
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from framework import Program from .. import core
from executor import global_scope from ..framework import Program
from . import core from ..executor import global_scope
class InferenceTranspiler: class InferenceTranspiler:
......
...@@ -13,11 +13,9 @@ ...@@ -13,11 +13,9 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
import framework from .. import core
from framework import Program, default_main_program, Parameter, Variable from ..framework import Program, default_main_program, Parameter, Variable
import backward from ..backward import _rename_arg_
from backward import _rename_arg_
from . import core
dtype_to_size = { dtype_to_size = {
core.VarDesc.VarType.FP16: 2, core.VarDesc.VarType.FP16: 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册