提交 989e8358 编写于 作者: L Liu Yiqun

Reuse the output of mul when there is only one input in FCOp.

上级 fe2ab2ee
...@@ -66,22 +66,25 @@ class FCOp : public NetOp { ...@@ -66,22 +66,25 @@ class FCOp : public NetOp {
} }
// sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1] // sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
auto sum_out = mul_out[0];
if (n > 1) { if (n > 1) {
AppendOp(framework::OpRegistry::CreateOp( sum_out = Output("SumOut");
"sum", {{"X", {mul_out}}}, {{"Out", {Output("SumOut")}}}, {})); AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
{{"Out", {sum_out}}}, {}));
} else { } else {
AppendOp(framework::OpRegistry::CreateOp( if (Output("SumOut") != framework::kEmptyVarName) {
"identity", {{"X", {mul_out[0]}}}, {{"Y", {Output("SumOut")}}}, {})); this->Rename(Output("SumOut"), framework::kEmptyVarName);
}
} }
// add_out = sum_out + b // add_out = sum_out + b
auto b = Input("B"); auto b = Input("B");
std::string add_out = "SumOut"; auto add_out = sum_out;
if (b != framework::kEmptyVarName) { if (b != framework::kEmptyVarName) {
add_out = "AddOut"; add_out = Output("AddOut");
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(
"rowwise_add", {{"X", {Output("SumOut")}}, {"b", {Input("B")}}}, "rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
{{"Out", {Output(add_out)}}}, {})); {{"Out", {add_out}}}, {}));
} else { } else {
if (Output("AddOut") != framework::kEmptyVarName) { if (Output("AddOut") != framework::kEmptyVarName) {
this->Rename(Output("AddOut"), framework::kEmptyVarName); this->Rename(Output("AddOut"), framework::kEmptyVarName);
...@@ -89,8 +92,8 @@ class FCOp : public NetOp { ...@@ -89,8 +92,8 @@ class FCOp : public NetOp {
} }
auto activation = Attr<std::string>("activation"); auto activation = Attr<std::string>("activation");
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}},
activation, {{"X", {Output(add_out)}}}, {{"Y", {Output("Out")}}}, {})); {{"Y", {Output("Out")}}}, {}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
......
...@@ -193,12 +193,14 @@ class OpTest(unittest.TestCase): ...@@ -193,12 +193,14 @@ class OpTest(unittest.TestCase):
actual, expect, atol=1e-05), actual, expect, atol=1e-05),
"output name: " + out_name + " has diff") "output name: " + out_name + " has diff")
else: else:
actual = np.array(self.scope.find_var(out_name).get_tensor()) var = self.scope.find_var(out_name)
expect = self.outputs[out_name] if var is not None:
self.assertTrue( actual = np.array(var.get_tensor())
np.allclose( expect = self.outputs[out_name]
actual, expect, atol=1e-05), self.assertTrue(
"output name: " + out_name + " has diff") np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
def check_output(self): def check_output(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
......
...@@ -7,27 +7,19 @@ class TestFCOp1(OpTest): ...@@ -7,27 +7,19 @@ class TestFCOp1(OpTest):
def setUp(self): def setUp(self):
x0 = np.random.random((16, 32)).astype("float32") x0 = np.random.random((16, 32)).astype("float32")
w0 = np.random.random((32, 10)).astype("float32") w0 = np.random.random((32, 10)).astype("float32")
b = np.random.random(10).astype("float32")
mul_out0 = np.dot(x0, w0) mul_out0 = np.dot(x0, w0)
sum_out = mul_out0 identity_out = mul_out0
add_out = sum_out + b
identity_out = add_out
self.op_type = "fc" self.op_type = "fc"
self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)], "B": b} self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)]}
self.outputs = { self.outputs = {"MulOut": [("MulOut0", mul_out0)], "Out": identity_out}
"MulOut": [("MulOut0", mul_out0)],
"SumOut": sum_out,
"AddOut": add_out,
"Out": identity_out
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X0", "W0", "B"], "Out", max_relative_error=0.01) self.check_grad(["X0", "W0"], "Out", max_relative_error=0.01)
class TestFCOp2(OpTest): class TestFCOp2(OpTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册