未验证 提交 bc07a010 编写于 作者: Y Yiqun Liu 提交者: GitHub

Transfer the value of stop_gradient for feeding data. (#4831)

test=develop
上级 12080a0e
...@@ -36,8 +36,12 @@ def _basic_model(data, model, args, is_train): ...@@ -36,8 +36,12 @@ def _basic_model(data, model, args, is_train):
image = data[0] image = data[0]
label = data[1] label = data[1]
if args.model == "ResNet50": if args.model == "ResNet50":
image_in = fluid.layers.transpose(image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image image_in = fluid.layers.transpose(
net_out = model.net(input=image_in, class_dim=args.class_dim, data_format=args.data_format) image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image
image_in.stop_gradient = image.stop_gradient
net_out = model.net(input=image_in,
class_dim=args.class_dim,
data_format=args.data_format)
else: else:
net_out = model.net(input=image, class_dim=args.class_dim) net_out = model.net(input=image, class_dim=args.class_dim)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
...@@ -92,8 +96,12 @@ def _mixup_model(data, model, args, is_train): ...@@ -92,8 +96,12 @@ def _mixup_model(data, model, args, is_train):
lam = data[3] lam = data[3]
if args.model == "ResNet50": if args.model == "ResNet50":
image_in = fluid.layers.transpose(image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image image_in = fluid.layers.transpose(
net_out = model.net(input=image_in, class_dim=args.class_dim, data_format=args.data_format) image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image
image_in.stop_gradient = image.stop_gradient
net_out = model.net(input=image_in,
class_dim=args.class_dim,
data_format=args.data_format)
else: else:
net_out = model.net(input=image, class_dim=args.class_dim) net_out = model.net(input=image, class_dim=args.class_dim)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册