未验证 提交 76b63c25 编写于 作者: Y Yancey 提交者: GitHub

move transpiler files into transpiler folder (#10415)

上级 55e714e0
......@@ -40,16 +40,14 @@ import backward
import regularizer
import average
import metrics
import transpiler
from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace
from distribute_transpiler import DistributeTranspiler
from distribute_transpiler_simple import SimpleDistributeTranspiler
from transpiler import DistributeTranspiler, SimpleDistributeTranspiler, InferenceTranspiler, memory_optimize, release_memory
from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close, Select)
from inference_transpiler import InferenceTranspiler
import clip
from memory_optimization_transpiler import memory_optimize, release_memory
import profiler
import unique_name
import recordio_writer
......@@ -58,7 +56,7 @@ from parallel_executor import ParallelExecutor
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
trainer.__all__ + inferencer.__all__ + [
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + [
'io',
'initializer',
'layers',
......@@ -76,11 +74,6 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
'WeightNormParamAttr',
'DataFeeder',
'clip',
'SimpleDistributeTranspiler',
'DistributeTranspiler',
'InferenceTranspiler',
'memory_optimize',
'release_memory',
'profiler',
'unique_name',
'recordio_writer',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
__all__ = [
"DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
"memory_optimize", "release_memory"
]
......@@ -17,9 +17,8 @@ from __future__ import print_function
import math
import distributed_splitter as splitter
import framework
from framework import Program, default_main_program, Variable, Parameter
from . import core
from .. import core
from ..framework import Program, default_main_program, Variable, Parameter
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
......@@ -135,6 +134,16 @@ def split_dense_variable(var_list,
return blocks
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
class DistributeTranspiler:
def transpile(self,
trainer_id,
......@@ -317,7 +326,7 @@ class DistributeTranspiler:
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.delete_ops(self.origin_program.global_block(), self.optimize_ops)
delete_ops(self.origin_program.global_block(), self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.origin_program.__str__()
return self.origin_program
......@@ -601,7 +610,7 @@ class DistributeTranspiler:
attrs={"axis": 0})
# delete lookup_table_op
self.delete_ops(program.global_block(), [op])
delete_ops(program.global_block(), [op])
# break for loop
break
......@@ -1164,12 +1173,3 @@ class DistributeTranspiler:
in_name.startswith("beta2_pow_acc"):
return True
return False
def delete_ops(self, block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
......@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import framework
from framework import Program, default_main_program, Parameter, Variable
import optimizer
from layer_helper import LayerHelper
from ..framework import Program, default_main_program, Parameter, Variable
from ..layer_helper import LayerHelper
def hash_name_to_server(params_grads, pserver_endpoints):
......
......@@ -13,9 +13,9 @@
# limitations under the License.
import numpy as np
from framework import Program
from executor import global_scope
from . import core
from .. import core
from ..framework import Program
from ..executor import global_scope
class InferenceTranspiler:
......
......@@ -13,11 +13,9 @@
# limitations under the License.
from collections import defaultdict
import framework
from framework import Program, default_main_program, Parameter, Variable
import backward
from backward import _rename_arg_
from . import core
from .. import core
from ..framework import Program, default_main_program, Parameter, Variable
from ..backward import _rename_arg_
dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册