未验证 提交 dfb8e269 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #990 from cuicheng01/develop

Update vision_transformer.py
......@@ -243,7 +243,7 @@ class VisionTransformer(nn.Layer):
drop_path_rate=0.,
norm_layer='nn.LayerNorm',
epsilon=1e-5,
**args):
**kwargs):
super().__init__()
self.class_num = class_num
......@@ -331,9 +331,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
)
def ViT_small_patch16_224(pretrained,
model,
model_url,
def ViT_small_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -352,9 +350,7 @@ def ViT_small_patch16_224(pretrained,
return model
def ViT_base_patch16_224(pretrained,
model,
model_url,
def ViT_base_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -374,9 +370,7 @@ def ViT_base_patch16_224(pretrained,
return model
def ViT_base_patch16_384(pretrained,
model,
model_url,
def ViT_base_patch16_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -397,9 +391,7 @@ def ViT_base_patch16_384(pretrained,
return model
def ViT_base_patch32_384(pretrained,
model,
model_url,
def ViT_base_patch32_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -420,9 +412,7 @@ def ViT_base_patch32_384(pretrained,
return model
def ViT_large_patch16_224(pretrained,
model,
model_url,
def ViT_large_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -442,9 +432,7 @@ def ViT_large_patch16_224(pretrained,
return model
def ViT_large_patch16_384(pretrained,
model,
model_url,
def ViT_large_patch16_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -465,9 +453,7 @@ def ViT_large_patch16_384(pretrained,
return model
def ViT_large_patch32_384(pretrained,
model,
model_url,
def ViT_large_patch32_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -488,9 +474,7 @@ def ViT_large_patch32_384(pretrained,
return model
def ViT_huge_patch16_224(pretrained,
model,
model_url,
def ViT_huge_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -508,9 +492,7 @@ def ViT_huge_patch16_224(pretrained,
return model
def ViT_huge_patch32_384(pretrained,
model,
model_url,
def ViT_huge_patch32_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......
......@@ -574,6 +574,8 @@ class Trainer(object):
if len(batch_data) >= batch_size or idx == len(image_list) - 1:
batch_tensor = paddle.to_tensor(batch_data)
out = self.model(batch_tensor)
if isinstance(out, list):
out = out[0]
result = postprocess_func(out, image_file_list)
print(result)
batch_data.clear()
......
......@@ -63,6 +63,8 @@ class ExportModel(nn.Layer):
def forward(self, x):
x = self.base_model(x)
if isinstance(x, list):
x = x[0]
if self.infer_model_name is not None:
x = x[self.infer_model_name]
if self.infer_output_key is not None:
......@@ -76,7 +78,6 @@ if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(
args.config, overrides=args.override, show=False)
log_file = os.path.join(config['Global']['output_dir'],
config["Arch"]["name"], "export.log")
init_logger(name='root', log_file=log_file)
......@@ -86,7 +87,6 @@ if __name__ == "__main__":
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
device = paddle.set_device(config["Global"]["device"])
model = ExportModel(config["Arch"])
if config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model,
config["Global"]["pretrained_model"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册