Commit 7c13292c authored by dangqingqing

Fix bug for multi-GPU inference.

Parent 55115ac6
@@ -35,6 +35,13 @@ class Inference(object):
             name = param.getName()
             assert isinstance(val, api.Vector)
             val.copyFromNumpyArray(parameters.get(name).flatten())
+            # The setValueUpdated() function is called by the randomize,
+            # zeroMem, and load functions in paddle/parameter/Parameter.cpp,
+            # but it is never called in inference mode. As a result the
+            # parameter is not dispatched to MultiGradientMachine for
+            # multi-GPU inference, so call it here. Ideally this should
+            # be done in a single place.
+            param.setValueUpdated()
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
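Why the extra setValueUpdated() call is needed: the multi-GPU machine only re-dispatches a parameter to its per-device workers when the parameter is flagged as updated, and in inference mode nothing sets that flag after the trained values are copied in. The toy Python sketch below illustrates that pattern; it is not PaddlePaddle code, and every class and method name in it is made up purely for illustration.

```python
class ToyParameter(object):
    """Toy parameter with a host-side value and an 'updated' flag."""

    def __init__(self, name, value):
        self.name = name
        self.value = value
        self.updated = True  # a fresh value still needs to be dispatched

    def copy_from(self, new_value):
        # Analogous to copyFromNumpyArray(): only the host-side value changes.
        self.value = new_value

    def set_value_updated(self):
        # Analogous to setValueUpdated(): marks the host copy as newer than
        # whatever the per-device workers currently hold.
        self.updated = True


class ToyMultiDeviceMachine(object):
    """Toy stand-in for a multi-GPU gradient machine."""

    def __init__(self, params, num_devices=2):
        self.params = params
        self.device_copies = [dict() for _ in range(num_devices)]

    def _dispatch(self):
        # Only parameters flagged as updated are pushed to each device.
        for p in self.params:
            if p.updated:
                for dev in self.device_copies:
                    dev[p.name] = p.value
                p.updated = False

    def forward(self, x):
        self._dispatch()
        # Each "device" scales the input by its own copy of 'w'.
        return [dev['w'] * x for dev in self.device_copies]


w = ToyParameter('w', 1.0)
machine = ToyMultiDeviceMachine([w])
print(machine.forward(3.0))  # [3.0, 3.0]

# Load trained values the way the inference path does: copy only.
w.copy_from(2.0)
print(machine.forward(3.0))  # still [3.0, 3.0] -- stale device copies

# The fix: flag the parameter so the next dispatch picks up the new value.
w.set_value_updated()
print(machine.forward(3.0))  # [6.0, 6.0]
```

In the real code, per the comment in the hunk above, the flag is normally set by the randomize, zeroMem, and load functions in paddle/parameter/Parameter.cpp, none of which the inference path reaches; hence the explicit call after copying the parameter values.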