提交 7a7843c7 编写于 作者: C chengmo

add evaluate only choice

上级 9b961565
...@@ -55,7 +55,12 @@ class SingleTrainer(TranspileTrainer): ...@@ -55,7 +55,12 @@ class SingleTrainer(TranspileTrainer):
if metrics: if metrics:
self.fetch_vars = metrics.values() self.fetch_vars = metrics.values()
self.fetch_alias = metrics.keys() self.fetch_alias = metrics.keys()
context['status'] = 'startup_pass' evaluate_only = envs.get_global_env(
'evaluate_only', False, namespace='evaluate')
if evaluate_only:
context['status'] = 'infer_pass'
else:
context['status'] = 'startup_pass'
def startup(self, context): def startup(self, context):
self._exe.run(fluid.default_startup_program()) self._exe.run(fluid.default_startup_program())
......
...@@ -229,8 +229,17 @@ class TranspileTrainer(Trainer): ...@@ -229,8 +229,17 @@ class TranspileTrainer(Trainer):
metrics_format = ", ".join(metrics_format) metrics_format = ", ".join(metrics_format)
self._exe.run(startup_program) self._exe.run(startup_program)
for (epoch, model_dir) in self.increment_models: model_list = self.increment_models
print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
evaluate_only = envs.get_global_env(
'evaluate_only', False, namespace='evaluate')
if evaluate_only:
model_list = [(0, envs.get_global_env(
'evaluate_model_path', "", namespace='evaluate'))]
for (epoch, model_dir) in model_list:
print("Begin to infer No.{} model, model_dir: {}".format(
epoch, model_dir))
program = infer_program.clone() program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program) fluid.io.load_persistables(self._exe, model_dir, program)
reader.start() reader.start()
......
...@@ -12,7 +12,11 @@ ...@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
evaluate: evaluate:
workspace: "paddlerec.models.recall.word2vec" workspace: "paddlerec.models.recall.word2vec"
evaluate_only: False
evaluate_model_path: ""
reader: reader:
batch_size: 50 batch_size: 50
class: "{workspace}/w2v_evaluate_reader.py" class: "{workspace}/w2v_evaluate_reader.py"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册