未验证 提交 5be972e6 编写于 作者: X xiaomuchongwhs 提交者: GitHub

Merge pull request #45 from Oneflow-Inc/dev-hswang

dev-hswang
......@@ -634,5 +634,5 @@ python3 cnn_benchmark/of_cnn_train_val.py \
--model="vgg" \
```
The top1 accuracy and the top5 accuracy are 69.3359% and 89.1370%, respectively, for our oneflow model after 90 epochs of training.
The top1 accuracy and the top5 accuracy are 69.7326% and 89.3806%, respectively, for our oneflow model after 90 epochs of training.
For reference, the top1 accuracy and the top5 accuracy are 71.5% and 89.9%, respectively for the model from the tensorflow benchmarks after 90 epochs of training.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import oneflow as flow
import ofrecord_util
import config as configs
from util import Snapshot, Summary, InitNodes, Metric
......@@ -15,14 +12,10 @@ import resnet_model
import vgg_model
import alexnet_model
# Expose only GPUs 0-3 to this process through the CUDA_VISIBLE_DEVICES
# environment variable.
# NOTE(review): the device list is hard-coded; confirm it matches the machine
# and the --gpu_num_per_node argument.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
# Obtain the argument parser from the project's config module and parse the
# command line; print_args presumably echoes the parsed values (verify in config).
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)
# Total device count is nodes times GPUs per node; the global train/val batch
# sizes scale the per-device batch sizes by that count.
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device
......
......@@ -60,6 +60,14 @@ def gen_model_update_conf(args):
"decay_batches": decay_batches,
"end_learning_rate": 0.00001,
}}
# weight decay
# if args.wd > 0:
# assert args.wd < 1.0
# model_update_conf['weight_decay_conf'] = {
# "weight_decay_rate": args.wd,
# "excludes": {"pattern": ['_bn-']}
# }
pprint.pprint(model_update_conf)
return model_update_conf
......
......@@ -40,7 +40,6 @@ def conv2d_layer(
bn=True,
):
weight_shape = (filters, input.shape[1], kernel_size, kernel_size)
print("weight_shape:{}".format(weight_shape))
weight = flow.get_variable(
name + "_weight",
shape=weight_shape,
......@@ -65,7 +64,7 @@ def conv2d_layer(
output = _batch_norm(output, name + "_bn")
output = flow.nn.relu(output)
else:
output = flow.nn.relu(output)
output = flow.nn.relu(output)
else:
raise NotImplementedError
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册