提交 1047f81c 编写于 作者: D dengkaipeng

remove reader/metrics in attention_lstm.py

上级 dd6126ed
......@@ -31,7 +31,7 @@ class AttentionLSTM(ModelBase):
self.feature_num = self.cfg.MODEL.feature_num
self.feature_names = self.cfg.MODEL.feature_names
self.feature_dims = self.cfg.MODEL.feature_dims
self.class_num = self.cfg.MODEL.class_num
self.num_classes = self.cfg.MODEL.num_classes
self.embedding_size = self.cfg.MODEL.embedding_size
self.lstm_size = self.cfg.MODEL.lstm_size
......@@ -59,7 +59,7 @@ class AttentionLSTM(ModelBase):
shapes = []
for dim in self.feature_dims:
shapes.append([-1, dim])
shapes.append([-1, self.class_num]) # label
shapes.append([-1, self.num_classes]) # label
self.py_reader = fluid.layers.py_reader(
capacity=1024,
shapes=shapes,
......@@ -81,7 +81,7 @@ class AttentionLSTM(ModelBase):
self.label_input = None
else:
self.label_input = fluid.layers.data(
shape=[self.class_num], dtype='float32', name='label')
shape=[self.num_classes], dtype='float32', name='label')
def build_model(self):
att_outs = []
......@@ -108,7 +108,7 @@ class AttentionLSTM(ModelBase):
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
self.logit = fluid.layers.fc(input=fc2, size=self.class_num, act=None, \
self.logit = fluid.layers.fc(input=fc2, size=self.num_classes, act=None, \
bias_attr=ParamAttr(regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
......@@ -149,20 +149,3 @@ class AttentionLSTM(ModelBase):
def weights_info(self):
return (None, None)
def create_dataset_args(self):
dataset_args = {}
dataset_args['num_classes'] = self.class_num
dataset_args['list'] = self.get_config_from_sec(self.mode, 'filelist')
if self.use_gpu and self.py_reader:
dataset_args['batch_size'] = int(self.batch_size / self.gpu_num)
else:
dataset_args['batch_size'] = self.batch_size
return dataset_args
def create_metrics_args(self):
metrics_args = {}
metrics_args['num_classes'] = self.class_num
metrics_args['topk'] = 20
return metrics_args
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册