提交 c30d2ac0 编写于 作者: M mir-of

fix data loader transpose logic

上级 0ae082d3
......@@ -17,7 +17,7 @@ def conv2d_layer(
weight_initializer=flow.random_uniform_initializer(),
bias_initializer=flow.constant_initializer(),
):
weight_shape = (filters, input.static_shape[1], kernel_size, kernel_size)
weight_shape = (filters, input.shape[1], kernel_size, kernel_size)
weight = flow.get_variable(
name + "-weight",
shape=weight_shape,
......
......@@ -60,8 +60,9 @@ def TrainNet():
print("Loading synthetic data.")
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](
images, need_transpose=not args.use_new_dataloader)
images, need_transpose=False if (args.use_new_dataloader and args.train_data_dir) else True)
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
labels, logits, name="softmax_loss")
loss = flow.math.reduce_mean(loss)
......@@ -85,7 +86,7 @@ def InferenceNet():
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](
images, need_transpose=not args.use_new_dataloader)
images, need_transpose=False if (args.use_new_dataloader and args.train_data_dir) else True)
predictions = flow.nn.softmax(logits)
outputs = {"predictions": predictions, "labels": labels}
return outputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册