Unverified commit dfb8e269, authored by cuicheng01 and committed by GitHub

Merge pull request #990 from cuicheng01/develop

Update vision_transformer.py
@@ -243,7 +243,7 @@ class VisionTransformer(nn.Layer):
                  drop_path_rate=0.,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
-                 **args):
+                 **kwargs):
         super().__init__()
         self.class_num = class_num
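On its own the `**args` → `**kwargs` rename is cosmetic, but the constructor has to accept a catch-all because the factory functions below forward their extra keyword arguments into it. A minimal sketch of that forwarding, using a stub class and a hypothetical `ViT_demo` factory (parameter values are illustrative, not the repo's defaults):

```python
import paddle.nn as nn

class VisionTransformer(nn.Layer):
    # Stub standing in for the real class; the real defaults differ.
    def __init__(self, patch_size=16, embed_dim=384, class_num=1000, **kwargs):
        super().__init__()
        self.class_num = class_num

def ViT_demo(pretrained=False, use_ssld=False, **kwargs):
    # Anything the caller passes beyond the named parameters lands in kwargs
    # and is handed straight to VisionTransformer.__init__.
    return VisionTransformer(patch_size=16, embed_dim=384, **kwargs)

net = ViT_demo(class_num=10)  # class_num travels through the factory into __init__
```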
@@ -331,9 +331,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
     )
-def ViT_small_patch16_224(pretrained,
-                          model,
-                          model_url,
+def ViT_small_patch16_224(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -352,9 +350,7 @@ def ViT_small_patch16_224(pretrained,
     return model
-def ViT_base_patch16_224(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch16_224(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -374,9 +370,7 @@ def ViT_base_patch16_224(pretrained,
     return model
-def ViT_base_patch16_384(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch16_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -397,9 +391,7 @@ def ViT_base_patch16_384(pretrained,
     return model
-def ViT_base_patch32_384(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch32_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -420,9 +412,7 @@ def ViT_base_patch32_384(pretrained,
     return model
-def ViT_large_patch16_224(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch16_224(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -442,9 +432,7 @@ def ViT_large_patch16_224(pretrained,
     return model
-def ViT_large_patch16_384(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch16_384(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -465,9 +453,7 @@ def ViT_large_patch16_384(pretrained,
     return model
-def ViT_large_patch32_384(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch32_384(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -488,9 +474,7 @@ def ViT_large_patch32_384(pretrained,
     return model
-def ViT_huge_patch16_224(pretrained,
-                         model,
-                         model_url,
+def ViT_huge_patch16_224(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -508,9 +492,7 @@ def ViT_huge_patch16_224(pretrained,
     return model
-def ViT_huge_patch32_384(pretrained,
-                         model,
-                         model_url,
+def ViT_huge_patch32_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
...
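All nine `ViT_*` factories above receive the same signature change: `pretrained` gains a `False` default, and the `model`/`model_url` parameters are dropped from the public signature (the weight URL is presumably supplied to `_load_pretrained` from a module-level constant inside each factory rather than by the caller). A hedged sketch of the resulting call pattern; the import path is an assumption about where `vision_transformer.py` sits in the package:

```python
# Assumed import path; adjust to the module's actual location in the repo.
from ppcls.arch.backbone.model_zoo.vision_transformer import ViT_base_patch16_224

model = ViT_base_patch16_224()                  # random init; pretrained defaults to False
model = ViT_base_patch16_224(pretrained=True)   # fetch and load the released weights
model = ViT_base_patch16_224(class_num=10)      # extra kwargs reach VisionTransformer
```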
@@ -574,6 +574,8 @@ class Trainer(object):
             if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                 batch_tensor = paddle.to_tensor(batch_data)
                 out = self.model(batch_tensor)
+                if isinstance(out, list):
+                    out = out[0]
                 result = postprocess_func(out, image_file_list)
                 print(result)
                 batch_data.clear()
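The two added lines guard against models whose `forward` returns a list of tensors rather than a single tensor: postprocessing expects one output, so only the first element is kept. A self-contained sketch of the check (the `unwrap` helper name is made up for illustration):

```python
import paddle

def unwrap(out):
    # Same check as the two lines added above: a list output is reduced
    # to its first element before postprocessing.
    if isinstance(out, list):
        out = out[0]
    return out

logits = unwrap(paddle.rand([4, 1000]))                          # tensor: passed through
logits = unwrap([paddle.rand([4, 1000]), paddle.rand([4, 64])])  # list: first element kept
```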
...
@@ -63,6 +63,8 @@ class ExportModel(nn.Layer):
     def forward(self, x):
         x = self.base_model(x)
+        if isinstance(x, list):
+            x = x[0]
         if self.infer_model_name is not None:
             x = x[self.infer_model_name]
         if self.infer_output_key is not None:
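This is the same guard as in `Trainer`, placed before the name/key lookups: if `base_model` returned a list, the subscript `x[self.infer_model_name]` below would index a list with a string and raise a `TypeError`, so the list is unwrapped first. A small illustration with made-up output values:

```python
import paddle

# Made-up shape and key, only to show why the unwrap must run first.
x = [{"logits": paddle.rand([1, 1000])}]   # list-wrapped output from a base model
if isinstance(x, list):
    x = x[0]                               # x is now the inner dict
x = x["logits"]                            # string indexing is safe again
```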
@@ -76,7 +78,6 @@ if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
     log_file = os.path.join(config['Global']['output_dir'],
                             config["Arch"]["name"], "export.log")
     init_logger(name='root', log_file=log_file)
@@ -86,7 +87,6 @@ if __name__ == "__main__":
     assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
     device = paddle.set_device(config["Global"]["device"])
     model = ExportModel(config["Arch"])
     if config["Global"]["pretrained_model"] is not None:
         load_dygraph_pretrain(model.base_model,
                               config["Global"]["pretrained_model"])
...