未验证 提交 2607dbca 编写于 作者: W Wenyu 提交者: GitHub

recompute flag (#6628)

上级 29356b07
...@@ -7,6 +7,7 @@ weights: output/cascade_rcnn_vit_large_hrfpn_cae_1x_coco/model_final ...@@ -7,6 +7,7 @@ weights: output/cascade_rcnn_vit_large_hrfpn_cae_1x_coco/model_final
depth: &depth 24 depth: &depth 24
dim: &dim 1024 dim: &dim 1024
use_fused_allreduce_gradients: &use_checkpoint True
VisionTransformer: VisionTransformer:
img_size: [800, 1344] img_size: [800, 1344]
...@@ -15,6 +16,7 @@ VisionTransformer: ...@@ -15,6 +16,7 @@ VisionTransformer:
num_heads: 16 num_heads: 16
drop_path_rate: 0.25 drop_path_rate: 0.25
out_indices: [7, 11, 15, 23] out_indices: [7, 11, 15, 23]
use_checkpoint: *use_checkpoint
pretrained: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_large_cae_pretrained.pdparams pretrained: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_large_cae_pretrained.pdparams
HRFPN: HRFPN:
......
...@@ -596,7 +596,7 @@ class VisionTransformer(nn.Layer): ...@@ -596,7 +596,7 @@ class VisionTransformer(nn.Layer):
feats = [] feats = []
for idx, blk in enumerate(self.blocks): for idx, blk in enumerate(self.blocks):
if self.use_checkpoint: if self.use_checkpoint and self.training:
x = paddle.distributed.fleet.utils.recompute( x = paddle.distributed.fleet.utils.recompute(
blk, x, rel_pos_bias, **{"preserve_rng_state": True}) blk, x, rel_pos_bias, **{"preserve_rng_state": True})
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册