提交 92e2d21e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5088 Fix MASS and FasterRcnn CI Problem.

Merge pull request !5088 from linqingke/mass
...@@ -27,7 +27,7 @@ FasterRcnn proposed that convolution feature maps based on region detectors (suc ...@@ -27,7 +27,7 @@ FasterRcnn proposed that convolution feature maps based on region detectors (suc
[Paper](https://arxiv.org/abs/1506.01497): Ren S , He K , Girshick R , et al. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2015, 39(6). [Paper](https://arxiv.org/abs/1506.01497): Ren S , He K , Girshick R , et al. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2015, 39(6).
#Model Architecture # Model Architecture
FasterRcnn is a two-stage target detection network,This network uses a region proposal network (RPN), which can share the convolution features of the whole image with the detection network, so that the calculation of region proposal is almost cost free. The whole network further combines RPN and FastRcnn into a network by sharing the convolution features. FasterRcnn is a two-stage target detection network,This network uses a region proposal network (RPN), which can share the convolution features of the whole image with the detection network, so that the calculation of region proposal is almost cost free. The whole network further combines RPN and FastRcnn into a network by sharing the convolution features.
...@@ -42,7 +42,7 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>) ...@@ -42,7 +42,7 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
- Data format:image and json files - Data format:image and json files
- Note:Data will be processed in dataset.py - Note:Data will be processed in dataset.py
#Environment Requirements # Environment Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en). - Install [MindSpore](https://www.mindspore.cn/install/en).
...@@ -87,6 +87,8 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>) ...@@ -87,6 +87,8 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
After installing MindSpore via the official website, you can start training and evaluation as follows: After installing MindSpore via the official website, you can start training and evaluation as follows:
Note: 1.the first run will generate the mindeocrd file, which will take a long time. 2. pretrained model is a resnet50 checkpoint that trained over ImageNet2012. 3. VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
``` ```
# standalone training # standalone training
sh run_standalone_train_ascend.sh [PRETRAINED_MODEL] sh run_standalone_train_ascend.sh [PRETRAINED_MODEL]
......
...@@ -97,7 +97,7 @@ class Rcnn(nn.Cell): ...@@ -97,7 +97,7 @@ class Rcnn(nn.Cell):
self.relu = P.ReLU() self.relu = P.ReLU()
self.logicaland = P.LogicalAnd() self.logicaland = P.LogicalAnd()
self.loss_cls = P.SoftmaxCrossEntropyWithLogits() self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(sigma=1.0) self.loss_bbox = P.SmoothL1Loss(beta=1.0)
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.onehot = P.OneHot() self.onehot = P.OneHot()
self.greater = P.Greater() self.greater = P.Greater()
......
...@@ -137,7 +137,7 @@ class RPN(nn.Cell): ...@@ -137,7 +137,7 @@ class RPN(nn.Cell):
self.CheckValid = P.CheckValid() self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum() self.sum_loss = P.ReduceSum()
self.loss_cls = P.SigmoidCrossEntropyWithLogits() self.loss_cls = P.SigmoidCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(sigma=1.0/9.0) self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
self.squeeze = P.Squeeze() self.squeeze = P.Squeeze()
self.cast = P.Cast() self.cast = P.Cast()
self.tile = P.Tile() self.tile = P.Tile()
......
...@@ -151,7 +151,7 @@ def _build_training_pipeline(config: TransformerConfig, ...@@ -151,7 +151,7 @@ def _build_training_pipeline(config: TransformerConfig,
if dataset is None: if dataset is None:
raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.") raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")
update_steps = dataset.get_repeat_count() * dataset.get_dataset_size() update_steps = config.epochs * dataset.get_dataset_size()
if config.lr_scheduler == "isr": if config.lr_scheduler == "isr":
lr = Tensor(square_root_schedule(lr=config.lr, lr = Tensor(square_root_schedule(lr=config.lr,
update_num=update_steps, update_num=update_steps,
...@@ -331,7 +331,8 @@ if __name__ == '__main__': ...@@ -331,7 +331,8 @@ if __name__ == '__main__':
mode=context.GRAPH_MODE, mode=context.GRAPH_MODE,
device_target=args.platform, device_target=args.platform,
reserve_class_name_in_scope=False, reserve_class_name_in_scope=False,
device_id=device_id) device_id=device_id,
max_call_depth=2000)
_rank_size = os.getenv('RANK_SIZE') _rank_size = os.getenv('RANK_SIZE')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册