提交 989e8358 编写于 作者: L Liu Yiqun

Reuse the output of mul when there is only one input in FCOp.

上级 fe2ab2ee
......@@ -66,22 +66,25 @@ class FCOp : public NetOp {
}
// sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
auto sum_out = mul_out[0];
if (n > 1) {
AppendOp(framework::OpRegistry::CreateOp(
"sum", {{"X", {mul_out}}}, {{"Out", {Output("SumOut")}}}, {}));
sum_out = Output("SumOut");
AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
{{"Out", {sum_out}}}, {}));
} else {
AppendOp(framework::OpRegistry::CreateOp(
"identity", {{"X", {mul_out[0]}}}, {{"Y", {Output("SumOut")}}}, {}));
if (Output("SumOut") != framework::kEmptyVarName) {
this->Rename(Output("SumOut"), framework::kEmptyVarName);
}
}
// add_out = sum_out + b
auto b = Input("B");
std::string add_out = "SumOut";
auto add_out = sum_out;
if (b != framework::kEmptyVarName) {
add_out = "AddOut";
add_out = Output("AddOut");
AppendOp(framework::OpRegistry::CreateOp(
"rowwise_add", {{"X", {Output("SumOut")}}, {"b", {Input("B")}}},
{{"Out", {Output(add_out)}}}, {}));
"rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
{{"Out", {add_out}}}, {}));
} else {
if (Output("AddOut") != framework::kEmptyVarName) {
this->Rename(Output("AddOut"), framework::kEmptyVarName);
......@@ -89,8 +92,8 @@ class FCOp : public NetOp {
}
auto activation = Attr<std::string>("activation");
AppendOp(framework::OpRegistry::CreateOp(
activation, {{"X", {Output(add_out)}}}, {{"Y", {Output("Out")}}}, {}));
AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}},
{{"Y", {Output("Out")}}}, {}));
CompleteAddOp(false);
}
};
......
......@@ -193,12 +193,14 @@ class OpTest(unittest.TestCase):
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
else:
actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
var = self.scope.find_var(out_name)
if var is not None:
actual = np.array(var.get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
def check_output(self):
places = [core.CPUPlace()]
......
......@@ -7,27 +7,19 @@ class TestFCOp1(OpTest):
def setUp(self):
x0 = np.random.random((16, 32)).astype("float32")
w0 = np.random.random((32, 10)).astype("float32")
b = np.random.random(10).astype("float32")
mul_out0 = np.dot(x0, w0)
sum_out = mul_out0
add_out = sum_out + b
identity_out = add_out
identity_out = mul_out0
self.op_type = "fc"
self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)], "B": b}
self.outputs = {
"MulOut": [("MulOut0", mul_out0)],
"SumOut": sum_out,
"AddOut": add_out,
"Out": identity_out
}
self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)]}
self.outputs = {"MulOut": [("MulOut0", mul_out0)], "Out": identity_out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X0", "W0", "B"], "Out", max_relative_error=0.01)
self.check_grad(["X0", "W0"], "Out", max_relative_error=0.01)
class TestFCOp2(OpTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册