提交 5d0e4a9c 编写于 作者: C chengmo

fix expand layer

上级 4806f9b4
......@@ -199,6 +199,10 @@ class DnnLayerClassifierNet(object):
def _expand_layer(self, input_layer, node, layer_idx):
input_layer_unsequeeze = fluid.layers.unsqueeze(
input=input_layer, axes=[1])
if self.is_test:
input_layer_expand = fluid.layers.expand(
input_layer_unsequeeze, expand_times=[1, node.shape[1], 1])
else:
input_layer_expand = fluid.layers.expand(
input_layer_unsequeeze, expand_times=[1, node[layer_idx].shape[1], 1])
return input_layer_expand
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册