提交 5e0a35f7 编写于 作者: G guosheng

Add Layer.__init__ in Input.__init__

上级 358f7852
...@@ -55,6 +55,7 @@ def extract_args(func): ...@@ -55,6 +55,7 @@ def extract_args(func):
class Input(fluid.dygraph.Layer): class Input(fluid.dygraph.Layer):
def __init__(self, shape=None, dtype=None, name=None): def __init__(self, shape=None, dtype=None, name=None):
super(Input, self).__init__()
self.shape = shape self.shape = shape
self.dtype = dtype self.dtype = dtype
self.name = name self.name = name
...@@ -429,7 +430,7 @@ class DynamicGraphAdapter(object): ...@@ -429,7 +430,7 @@ class DynamicGraphAdapter(object):
inputs = to_list(inputs) inputs = to_list(inputs)
if labels is not None: if labels is not None:
labels = to_list(labels) labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs]) outputs = self.model.forward(* [to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs, labels) losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses) final_loss = fluid.layers.sum(losses)
final_loss.backward() final_loss.backward()
...@@ -444,7 +445,7 @@ class DynamicGraphAdapter(object): ...@@ -444,7 +445,7 @@ class DynamicGraphAdapter(object):
inputs = to_list(inputs) inputs = to_list(inputs)
if labels is not None: if labels is not None:
labels = to_list(labels) labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs]) outputs = self.model.forward(* [to_variable(x) for x in inputs])
if self.model._loss_function: if self.model._loss_function:
losses = self.model._loss_function(outputs, labels) losses = self.model._loss_function(outputs, labels)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册