From bde51fdef5371673360bd9e0091a37128c5fbb7a Mon Sep 17 00:00:00 2001 From: moneypi <1483586698@qq.com> Date: Thu, 5 May 2022 18:31:48 +0800 Subject: [PATCH] fix InvalidArgumentError while training slimface (#1077) --- demo/models/slimfacenet.py | 5 +++-- demo/slimfacenet/train_eval.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/demo/models/slimfacenet.py b/demo/models/slimfacenet.py index 33a4deab..cb9353c7 100644 --- a/demo/models/slimfacenet.py +++ b/demo/models/slimfacenet.py @@ -334,7 +334,7 @@ class SlimFaceNet(): else: pass - one_hot = fluid.one_hot(input=label, depth=out_dim) + one_hot = fluid.layers.one_hot(input=label, depth=out_dim) output = fluid.layers.elementwise_mul( one_hot, phi) + fluid.layers.elementwise_mul( (1.0 - one_hot), cosine) @@ -367,7 +367,8 @@ def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None): if __name__ == "__main__": + paddle.enable_static() x = fluid.data(name='x', shape=[-1, 3, 112, 112], dtype='float32') print(x.shape) - model = SlimFaceNet(10000, [1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3]) + model = SlimFaceNet(10000, arch=[1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3]) y = model.net(x) diff --git a/demo/slimfacenet/train_eval.py b/demo/slimfacenet/train_eval.py index 77b366b9..82166305 100644 --- a/demo/slimfacenet/train_eval.py +++ b/demo/slimfacenet/train_eval.py @@ -29,6 +29,7 @@ from lfw_eval import parse_filelist, evaluation_10_fold from paddleslim import models from paddleslim.quant import quant_post_static +paddle.enable_static() def now(): return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) -- GitLab