提交 f6da3bd6 编写于 作者: W weishengyu

move files

上级 86346fda
...@@ -2,7 +2,6 @@ __pycache__/ ...@@ -2,7 +2,6 @@ __pycache__/
*.pyc *.pyc
*.sw* *.sw*
*/workerlog* */workerlog*
dataset/
checkpoints/ checkpoints/
output/ output/
pretrained/ pretrained/
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from . import optimizer from . import optimizer
from .modeling import * from .arch import *
from .optimizer import * from .optimizer import *
from .data import * from .data import *
from .utils import * from .utils import *
...@@ -12,9 +12,8 @@ ...@@ -12,9 +12,8 @@
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
from . import architectures from . import backbone
from . import loss
from .architectures import * from .backbone import *
from .loss import * from .loss import *
from .utils import * from .utils import *
from abc import ABC
from paddle import nn
import re
class Identity(nn.Layer):
def __init__(self):
super(Identity, self).__init__()
def forward(self, inputs):
return inputs
class TheseusLayer(nn.Layer, ABC):
def __init__(self, *args, return_patterns=None, stop_layer=None, **kwargs):
super(TheseusLayer, self).__init__()
self.res_dict = None
self.register_forward_post_hook(self._disconnect_res_dict_hook)
if return_patterns is not None or stop_layer is not None:
self._update_sub(return_patterns, stop_layer)
def forward(self, *input, res_dict=None, **kwargs):
if res_dict is not None:
self.res_dict = res_dict
def _update_sub(self, return_layers, stop_layer):
after_stop = False
for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name()
if stop_layer is not None and layer_name == stop_layer:
after_stop = True
if after_stop:
self._sub_layers[layer_i] = Identity()
for return_pattern in return_layers:
if return_layers is not None and re.match(return_pattern, layer_name):
self._sub_layers[layer_i].register_forward_post_hook(self._save_sub_res_hook)
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[layer.full_name()] = output
def _disconnect_res_dict_hook(self, input, output):
self.res_dict = None
def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name()
if re.match(layer_name_pattern, layer_name):
self._sub_layers[layer_i] = replace_function(self._sub_layers[layer_i])
if recursive and isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i].replace_sub(layer_name_pattern, replace_function, recursive)
'''
example of replace function:
def replace_conv(origin_conv: nn.Conv2D):
new_conv = nn.Conv2D(
in_channels=origin_conv._in_channels,
out_channels=origin_conv._out_channels,
kernel_size=origin_conv._kernel_size,
stride=2
)
return new_conv
'''
\ No newline at end of file
...@@ -16,7 +16,7 @@ import six ...@@ -16,7 +16,7 @@ import six
import types import types
from difflib import SequenceMatcher from difflib import SequenceMatcher
from . import architectures from . import backbone
def get_architectures(): def get_architectures():
...@@ -24,15 +24,15 @@ def get_architectures(): ...@@ -24,15 +24,15 @@ def get_architectures():
get all of model architectures get all of model architectures
""" """
names = [] names = []
for k, v in architectures.__dict__.items(): for k, v in backbone.__dict__.items():
if isinstance(v, (types.FunctionType, six.class_types)): if isinstance(v, (types.FunctionType, six.class_types)):
names.append(k) names.append(k)
return names return names
def get_blacklist_model_in_static_mode(): def get_blacklist_model_in_static_mode():
from ppcls.modeling.architectures import distilled_vision_transformer from ppcls.arch.backbone import distilled_vision_transformer
from ppcls.modeling.architectures import vision_transformer from ppcls.arch.backbone import vision_transformer
blacklist = distilled_vision_transformer.__all__ + vision_transformer.__all__ blacklist = distilled_vision_transformer.__all__ + vision_transformer.__all__
return blacklist return blacklist
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册