未验证 提交 bde51fde 编写于 作者: M moneypi 提交者: GitHub

fix InvalidArgumentError while training slimface (#1077)

上级 3eee512e
...@@ -334,7 +334,7 @@ class SlimFaceNet(): ...@@ -334,7 +334,7 @@ class SlimFaceNet():
else: else:
pass pass
one_hot = fluid.one_hot(input=label, depth=out_dim) one_hot = fluid.layers.one_hot(input=label, depth=out_dim)
output = fluid.layers.elementwise_mul( output = fluid.layers.elementwise_mul(
one_hot, phi) + fluid.layers.elementwise_mul( one_hot, phi) + fluid.layers.elementwise_mul(
(1.0 - one_hot), cosine) (1.0 - one_hot), cosine)
...@@ -367,7 +367,8 @@ def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None): ...@@ -367,7 +367,8 @@ def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
x = fluid.data(name='x', shape=[-1, 3, 112, 112], dtype='float32') x = fluid.data(name='x', shape=[-1, 3, 112, 112], dtype='float32')
print(x.shape) print(x.shape)
model = SlimFaceNet(10000, [1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3]) model = SlimFaceNet(10000, arch=[1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3])
y = model.net(x) y = model.net(x)
...@@ -29,6 +29,7 @@ from lfw_eval import parse_filelist, evaluation_10_fold ...@@ -29,6 +29,7 @@ from lfw_eval import parse_filelist, evaluation_10_fold
from paddleslim import models from paddleslim import models
from paddleslim.quant import quant_post_static from paddleslim.quant import quant_post_static
paddle.enable_static()
def now(): def now():
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册