提交 8f849768 编写于 作者: C chengmo

fix list concat bug

上级 0432b71b
......@@ -144,10 +144,12 @@ class TdmTrainNet(object):
self.need_trace, self.need_detail)
tdm_fc_re = fluid.layers.reshape(tdm_fc, [-1, 2])
sample_label = fluid.layers.concat(sample_label, axis=1)
labels_reshape = fluid.layers.reshape(sample_label, [-1, 1])
cost, softmax_prob = fluid.layers.softmax_with_cross_entropy(
logits=tdm_fc_re, label=labels_reshape, return_softmax=True)
sample_mask = fluid.layers.concat(sample_mask, axis=1)
mask_reshape = fluid.layers.reshape(sample_mask, [-1, 1])
mask_index = fluid.layers.where(mask_reshape != 0)
mask_cost = fluid.layers.gather_nd(cost, mask_index)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册