diff --git a/PaddleRec/tdm/tdm_demo/train_network.py b/PaddleRec/tdm/tdm_demo/train_network.py index 44c9efa11ab76f3d08ae9b39cbd4a2a9481e0729..ac32d88f61dc480f75ed3f53b43253853e983aeb 100644 --- a/PaddleRec/tdm/tdm_demo/train_network.py +++ b/PaddleRec/tdm/tdm_demo/train_network.py @@ -199,8 +199,12 @@ class DnnLayerClassifierNet(object): def _expand_layer(self, input_layer, node, layer_idx): input_layer_unsequeeze = fluid.layers.unsqueeze( input=input_layer, axes=[1]) - input_layer_expand = fluid.layers.expand( - input_layer_unsequeeze, expand_times=[1, node[layer_idx].shape[1], 1]) + if self.is_test: + input_layer_expand = fluid.layers.expand( + input_layer_unsequeeze, expand_times=[1, node.shape[1], 1]) + else: + input_layer_expand = fluid.layers.expand( + input_layer_unsequeeze, expand_times=[1, node[layer_idx].shape[1], 1]) return input_layer_expand def classifier_layer(self, input, node):