提交 a96cefd6 编写于 作者: R rensilin

loss_function(*output)

Change-Id: Iecc2518dc91b88125d29958ae78d5a8af3eb7eee
上级 ec85081f
......@@ -30,10 +30,10 @@ class ModelBuilder:
list<Variable>: outputs
pass
def _loss_function(outputs):
def _loss_function(*outputs):
**This function is declared in the network_desc_path file, and will be set in initialize()**
Args:
outputs: the second result of inference()
*outputs: the second result of inference()
Returns:
Variable: loss
......@@ -97,7 +97,7 @@ class ModelBuilder:
with fluid.program_guard(main_program, startup_program):
inputs, outputs = self._inference()
test_program = main_program.clone(for_test=True)
loss, labels = self._loss_function(outputs)
loss, labels = self._loss_function(*outputs)
optimizer = fluid.optimizer.SGD(learning_rate=1.0)
params_grads = optimizer.backward(loss)
......
......@@ -32,10 +32,10 @@ def inference():
ctr_output = fluid.layers.fc(net, 1, act='sigmoid', name='ctr')
return [cvm_input], [ctr_output]
def loss_function(outputs):
def loss_function(ctr_output):
"""
Args:
outputs: the second result of inference()
*outputs: the second result of inference()
Returns:
Variable: loss
......@@ -43,7 +43,6 @@ def loss_function(outputs):
list<Variable>: labels
"""
# TODO: calc loss here
ctr_output, = outputs
label = fluid.layers.data(name='label_ctr', shape=ctr_output.shape, dtype='float32')
loss = fluid.layers.square_error_cost(input=ctr_output, label=label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册