提交 7a7843c7 编写于 作者: C chengmo

add evaluate only choice

上级 9b961565
......@@ -55,7 +55,12 @@ class SingleTrainer(TranspileTrainer):
if metrics:
self.fetch_vars = metrics.values()
self.fetch_alias = metrics.keys()
context['status'] = 'startup_pass'
evaluate_only = envs.get_global_env(
'evaluate_only', False, namespace='evaluate')
if evaluate_only:
context['status'] = 'infer_pass'
else:
context['status'] = 'startup_pass'
def startup(self, context):
self._exe.run(fluid.default_startup_program())
......
......@@ -229,8 +229,17 @@ class TranspileTrainer(Trainer):
metrics_format = ", ".join(metrics_format)
self._exe.run(startup_program)
for (epoch, model_dir) in self.increment_models:
print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
model_list = self.increment_models
evaluate_only = envs.get_global_env(
'evaluate_only', False, namespace='evaluate')
if evaluate_only:
model_list = [(0, envs.get_global_env(
'evaluate_model_path', "", namespace='evaluate'))]
for (epoch, model_dir) in model_list:
print("Begin to infer No.{} model, model_dir: {}".format(
epoch, model_dir))
program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program)
reader.start()
......
......@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
workspace: "paddlerec.models.recall.word2vec"
workspace: "paddlerec.models.recall.word2vec"
evaluate_only: False
evaluate_model_path: ""
reader:
batch_size: 50
class: "{workspace}/w2v_evaluate_reader.py"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册