From 989e8358b3803c5c15ae4ec0a3fa93fd7b915302 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 14 Sep 2017 06:50:08 +0000 Subject: [PATCH] Reuse the output of mul when there is only one input in FCOp. --- paddle/operators/fc_op.cc | 23 +++++++++++-------- python/paddle/v2/framework/tests/op_test.py | 14 ++++++----- .../paddle/v2/framework/tests/test_fc_op.py | 16 ++++--------- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc index 14a7fa8467..5549a836c9 100644 --- a/paddle/operators/fc_op.cc +++ b/paddle/operators/fc_op.cc @@ -66,22 +66,25 @@ class FCOp : public NetOp { } // sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1] + auto sum_out = mul_out[0]; if (n > 1) { - AppendOp(framework::OpRegistry::CreateOp( - "sum", {{"X", {mul_out}}}, {{"Out", {Output("SumOut")}}}, {})); + sum_out = Output("SumOut"); + AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}}, + {{"Out", {sum_out}}}, {})); } else { - AppendOp(framework::OpRegistry::CreateOp( - "identity", {{"X", {mul_out[0]}}}, {{"Y", {Output("SumOut")}}}, {})); + if (Output("SumOut") != framework::kEmptyVarName) { + this->Rename(Output("SumOut"), framework::kEmptyVarName); + } } // add_out = sum_out + b auto b = Input("B"); - std::string add_out = "SumOut"; + auto add_out = sum_out; if (b != framework::kEmptyVarName) { - add_out = "AddOut"; + add_out = Output("AddOut"); AppendOp(framework::OpRegistry::CreateOp( - "rowwise_add", {{"X", {Output("SumOut")}}, {"b", {Input("B")}}}, - {{"Out", {Output(add_out)}}}, {})); + "rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}}, + {{"Out", {add_out}}}, {})); } else { if (Output("AddOut") != framework::kEmptyVarName) { this->Rename(Output("AddOut"), framework::kEmptyVarName); @@ -89,8 +92,8 @@ class FCOp : public NetOp { } auto activation = Attr("activation"); - AppendOp(framework::OpRegistry::CreateOp( - activation, {{"X", {Output(add_out)}}}, {{"Y", {Output("Out")}}}, {})); + AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}}, + {{"Y", {Output("Out")}}}, {})); CompleteAddOp(false); } }; diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index c6e4c59881..41690961b5 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -193,12 +193,14 @@ class OpTest(unittest.TestCase): actual, expect, atol=1e-05), "output name: " + out_name + " has diff") else: - actual = np.array(self.scope.find_var(out_name).get_tensor()) - expect = self.outputs[out_name] - self.assertTrue( - np.allclose( - actual, expect, atol=1e-05), - "output name: " + out_name + " has diff") + var = self.scope.find_var(out_name) + if var is not None: + actual = np.array(var.get_tensor()) + expect = self.outputs[out_name] + self.assertTrue( + np.allclose( + actual, expect, atol=1e-05), + "output name: " + out_name + " has diff") def check_output(self): places = [core.CPUPlace()] diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py index ed8d869a40..9f56fe5049 100644 --- a/python/paddle/v2/framework/tests/test_fc_op.py +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -7,27 +7,19 @@ class TestFCOp1(OpTest): def setUp(self): x0 = np.random.random((16, 32)).astype("float32") w0 = np.random.random((32, 10)).astype("float32") - b = np.random.random(10).astype("float32") mul_out0 = np.dot(x0, w0) - sum_out = mul_out0 - add_out = sum_out + b - identity_out = add_out + identity_out = mul_out0 self.op_type = "fc" - self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)], "B": b} - self.outputs = { - "MulOut": [("MulOut0", mul_out0)], - "SumOut": sum_out, - "AddOut": add_out, - "Out": identity_out - } + self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)]} + self.outputs = {"MulOut": [("MulOut0", mul_out0)], "Out": identity_out} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(["X0", "W0", "B"], "Out", max_relative_error=0.01) + self.check_grad(["X0", "W0"], "Out", max_relative_error=0.01) class TestFCOp2(OpTest): -- GitLab