提交 31f00c62 编写于 作者: G gx_wind

fix gradient bugs

上级 e7045ad6
...@@ -49,7 +49,8 @@ class PaddleModel(Model): ...@@ -49,7 +49,8 @@ class PaddleModel(Model):
loss = self._program.block(0).var(self._cost_name) loss = self._program.block(0).var(self._cost_name)
param_grads = fluid.backward.append_backward( param_grads = fluid.backward.append_backward(
loss, parameter_list=[self._input_name]) loss, parameter_list=[self._input_name])
self._gradient = dict(param_grads)[self._input_name] self._gradient = filter(lambda p: p[0].name == self._input_name,
param_grads)[0][1]
def predict(self, image_batch): def predict(self, image_batch):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册