Unverified commit 2607dbca, authored by Wenyu, committed by GitHub

recompute flag (#6628)

Parent 29356b07
@@ -7,6 +7,7 @@ weights: output/cascade_rcnn_vit_large_hrfpn_cae_1x_coco/model_final
 depth: &depth 24
 dim: &dim 1024
+use_fused_allreduce_gradients: &use_checkpoint True
 VisionTransformer:
   img_size: [800, 1344]
@@ -15,6 +16,7 @@ VisionTransformer:
   num_heads: 16
   drop_path_rate: 0.25
   out_indices: [7, 11, 15, 23]
+  use_checkpoint: *use_checkpoint
   pretrained: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_large_cae_pretrained.pdparams
 HRFPN:
......
@@ -596,7 +596,7 @@ class VisionTransformer(nn.Layer):
         feats = []
         for idx, blk in enumerate(self.blocks):
-            if self.use_checkpoint:
+            if self.use_checkpoint and self.training:
                 x = paddle.distributed.fleet.utils.recompute(
                     blk, x, rel_pos_bias, **{"preserve_rng_state": True})
             else:
......
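
Note on the change to the block loop: recompute (activation checkpointing) drops intermediate activations in the forward pass and recomputes them during backward, so it only pays off when a backward pass follows. Gating it on `self.training` presumably keeps evaluation on the plain `blk(...)` path. The sketch below illustrates that gating without Paddle. `fake_recompute`, `TinyBackbone`, and the `calls` log are hypothetical stand-ins introduced for illustration only, not part of PaddleDetection or the Paddle API.

```python
# Framework-free sketch of the gating in this commit.
# `fake_recompute` is a hypothetical stand-in for
# paddle.distributed.fleet.utils.recompute; it only records that it was used.
calls = []


def fake_recompute(fn, *args):
    calls.append("recompute")
    return fn(*args)


class TinyBackbone:
    """Hypothetical miniature of the block loop, not the real VisionTransformer."""

    def __init__(self, blocks, use_checkpoint=False):
        self.blocks = blocks
        self.use_checkpoint = use_checkpoint
        self.training = True  # plays the role of the layer's training flag

    def forward(self, x):
        for blk in self.blocks:
            # Same condition as the patched line: recompute only while training.
            if self.use_checkpoint and self.training:
                x = fake_recompute(blk, x)
            else:
                calls.append("direct")
                x = blk(x)
        return x


model = TinyBackbone([lambda v: v + 1, lambda v: v * 2], use_checkpoint=True)
train_out = model.forward(1)  # training mode: both blocks go through recompute

model.training = False
eval_out = model.forward(1)   # eval mode: both blocks are called directly
```

Both modes produce the same output; only the execution path differs, which is the behavior the added `and self.training` condition is meant to give.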