diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e5d0f3c3724262a60a463ef3beadd9906d3ebaf6
--- /dev/null
+++ b/paddle/operators/fc_op.cc
@@ -0,0 +1,197 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+
+class FCOp : public NetOp {
+ public:
+  FCOp(const std::string &type, const framework::VariableNameMap &inputs,
+       const framework::VariableNameMap &outputs,
+       const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    PADDLE_ENFORCE(!Inputs("X").empty(),
+                   "Inputs(X) of FCOp should not be null.");
+    PADDLE_ENFORCE(!Inputs("W").empty(),
+                   "Inputs(W) of FCOp should not be null.");
+    PADDLE_ENFORCE(!Outputs("MulOut").empty(),
+                   "Outputs(MulOut) of FCOp should not be null.");
+    PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
+                      "Output(Out) of FCOp should not be null.");
+
+    auto x = Inputs("X");
+    auto w = Inputs("W");
+    auto mul_out = Outputs("MulOut");
+    PADDLE_ENFORCE_EQ(
+        x.size(), w.size(),
+        "The size of inputs X(%d) should be the same as that of weights "
+        "W(%d).",
+        x.size(), w.size());
+    PADDLE_ENFORCE_EQ(mul_out.size(), x.size(),
+                      "The size of intermediate mul_out(%d) should be the "
+                      "same as that of inputs X(%d).",
+                      mul_out.size(), x.size());
+
+    size_t n = x.size();
+    PADDLE_ENFORCE_GE(n, static_cast<size_t>(1),
+                      "The size of inputs X(%d) should be no less than 1.", n);
+
+    auto x_num_col_dims = Attr<std::vector<int>>("xNumColDims");
+
+    // Set all values or set no values (use the default value)
+    if (!x_num_col_dims.empty()) {
+      PADDLE_ENFORCE_EQ(x_num_col_dims.size(), n,
+                        "The size of attribute xNumColDims(%d) should be the "
+                        "same as that of inputs X(%d).",
+                        x_num_col_dims.size(), n);
+    } else {
+      x_num_col_dims.resize(n);
+      for (size_t i = 0; i < n; i++) {
+        x_num_col_dims[i] = 1;
+      }
+    }
+
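+    // FCOp is expanded into a sub-network of primitive ops, appended in
+    // order: mul (one per input pair X[i], W[i]) -> sum (only when n > 1)
+    // -> rowwise_add (only when a bias B is given) -> activation.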
+    // mul_out[i] = X[i] * W[i]
+    for (size_t i = 0; i < n; i++) {
+      framework::AttributeMap mul_attr;
+      mul_attr["x_num_col_dims"] = static_cast<int>(x_num_col_dims[i]);
+      mul_attr["y_num_col_dims"] = static_cast<int>(1);
+      AppendOp(
+          framework::OpRegistry::CreateOp("mul", {{"X", {x[i]}}, {"Y", {w[i]}}},
+                                          {{"Out", {mul_out[i]}}}, mul_attr));
+    }
+
+    // sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
+    auto sum_out = mul_out[0];
+    if (n > 1) {
+      PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
+                        "Output(SumOut) of FCOp should not be null when the "
+                        "size of Inputs(X) > 1.");
+
+      sum_out = Output("SumOut");
+      AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
+                                               {{"Out", {sum_out}}}, {}));
+    } else {
+      if (Output("SumOut") != framework::kEmptyVarName) {
+        this->Rename(Output("SumOut"), framework::kEmptyVarName);
+      }
+    }
+
+    // add_out = sum_out + b
+    auto b = Input("B");
+    auto add_out = sum_out;
+    if (b != framework::kEmptyVarName) {
+      PADDLE_ENFORCE_NE(
+          Output("AddOut"), framework::kEmptyVarName,
+          "Output(AddOut) of FCOp should not be null when Input(B) is set.");
+
+      add_out = Output("AddOut");
+      AppendOp(framework::OpRegistry::CreateOp(
+          "rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
+          {{"Out", {add_out}}}, {}));
+    } else {
+      if (Output("AddOut") != framework::kEmptyVarName) {
+        this->Rename(Output("AddOut"), framework::kEmptyVarName);
+      }
+    }
+
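+    // The activation attribute selects the final sub-op ("identity",
+    // "sigmoid" or "softmax"), which maps add_out to the output Out.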
+    auto activation = Attr<std::string>("activation");
+    AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}},
+                                             {{"Y", {Output("Out")}}}, {}));
+    CompleteAddOp(false);
+  }
+};
+
+class FCOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FCOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(A vector of Tensors) each input Tensor can be of arbitrary "
+             "dimension, and will be reshaped to a 2-D matrix of size "
+             "(minibatch, number_of_input_features) according to attribute "
+             "xNumColDims.")
+        .AsDuplicable();
+    AddInput("W",
+             "(A vector of Tensors) the weights of FC operator, a "
+             "vector of 2-D matrices of size "
+             "(number_of_input_features, number_of_neurons).")
+        .AsDuplicable();
+    AddInput("B",
+             "(Tensor) the bias of FC operator, a 1-D vector of size "
+             "number_of_neurons.");
+
+    AddOutput("Out",
+              "(Tensor) the activated output matrix of FC operator, a 2-D "
+              "matrix of size (minibatch, number_of_neurons).");
+    AddOutput("MulOut",
+              "(A vector of Tensors) the intermediate outputs of FC operator, "
+              "each Tensor saving the product of X_i * W_i.")
+        .AsIntermediate()
+        .AsDuplicable();
+    AddOutput(
+        "SumOut",
+        "(Tensor) the intermediate output of FC operator, "
+        "saving the sum of the products of X and W, that is sum{X_i * W_i}.")
+        .AsIntermediate();
+    AddOutput("AddOut",
+              "(Tensor) the non-activated output of FC operator, "
+              "saving sum{X_i * W_i} + B.")
+        .AsIntermediate();
+    AddAttr<std::string>(
+        "activation",
+        "(string, default identity) the activation type of FC operator.")
+        .SetDefault("identity")
+        .InEnum({"identity", "sigmoid", "softmax"});
+    AddAttr<std::vector<int>>(
+        "xNumColDims",
+        "(std::vector<int>) The input Tensors of FC operator can have "
+        "more than 2 dimensions. In that case, each input Tensor `X_i` will be "
+        "reshaped to a 2-D matrix. The matrix's first dimension "
+        "(the length of column) will be the product of `X_i`'s first "
+        "`xNumColDims_i` dimensions, that is "
+        "`X_i.dims[0] x ... x X_i.dims[xNumColDims_i - 1]`. "
+        "The matrix's second dimension (the length of row) will be the product "
+        "of `X_i`'s last `rank - xNumColDims_i` dimensions, that is "
+        "`X_i.dims[xNumColDims_i] x ... x X_i.dims[rank - 1]`.")
+        .SetDefault(std::vector<int>{});
+
+    AddComment(R"DOC(
+Fully Connected Operator, known as Fully Connected Layer or Inner Product Layer
+in Convolutional Neural Networks. Neurons in a fully connected layer have
+full connections to all activations in the previous layer.
+It computes inner products between the input and a set of learned weights
+using matrix multiplication, optionally followed by a bias offset.
+
+Equation:
+  Out = Act(sum_n{X_i * W_i} + B)
+
+where X_i is a Tensor that will be reshaped to a 2-D matrix of size (M x K),
+usually M is the minibatch size and K is the number of input features.
+W_i is a 2-D matrix of size (K x N), where N means the number of neurons
+in the fully connected layer. B is a 1-D vector of size N.
+Thus, the output Out is a 2-D matrix of size (M x N).
+Activation type can be set to `identity` (default), `sigmoid` or `softmax`.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker);
diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc
index b67ca5f6f8d516224e18a5eed497f2bfc680259c..2cc632205e63abbe412b09af4b894420ac512ec5 100644
--- a/paddle/operators/identity_op.cc
+++ b/paddle/operators/identity_op.cc
@@ -27,7 +27,7 @@ class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of identity operator.");
-    AddOutput("Out", "The output tensor of identity operator.");
+    AddOutput("Y", "The output tensor of identity operator.");
     AddComment(R"DOC(
 The identity operator is an alias of the scale operator with the
 attribute scale fixed to 1.0.
@@ -44,12 +44,13 @@ class IdentityOp : public NetOp {
       : NetOp(type, inputs, outputs, attrs) {
     PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
                       "Input(X) of IdentityOp should not be null.");
-    PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
-                      "Output(Out) of IdentityOp should not be null.");
+    PADDLE_ENFORCE_NE(Output("Y"), framework::kEmptyVarName,
+                      "Output(Y) of IdentityOp should not be null.");
 
     AppendOp(framework::OpRegistry::CreateOp(
-        "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}},
+        "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Y")}}},
         {{"scale", static_cast<AttrType>(1)}}));
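+    // Finalize the NetOp once all sub-ops are appended; the `false` argument
+    // skips re-deducing the net's inputs/outputs, which are already declared.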
+    CompleteAddOp(false);
   }
 };
 
diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc
index ecf8a6f7795314e2475bb9546b55b8f354b96366..a97bbecdca1779df330d1053cf359bb658aa75c2 100644
--- a/paddle/operators/minus_op.cc
+++ b/paddle/operators/minus_op.cc
@@ -71,7 +71,7 @@ class MinusGradOp : public NetOp {
 
     // x_grad = out_grad
     AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}},
-                                             {{"Out", {x_grad}}}, {}));
+                                             {{"Y", {x_grad}}}, {}));
 
     framework::AttributeMap scale_attr;
     scale_attr["scale"] = static_cast<AttrType>(-1);
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 00030050700bfb2cee224124d090b0027d456ba0..4f05406c7f74113d8fb10aa6914166e553858338 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -1,5 +1,5 @@
 if(WITH_PYTHON)
-cc_library(paddle_pybind SHARED
+  cc_library(paddle_pybind SHARED
     SRCS pybind.cc
     DEPS pybind python backward
     ${GLOB_OP_LIB})
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index a0533efacdcc0386c0c3ab4691dc74a43435b4e4..75b689982a5b797dd8b5b9ee868b2b3676278f4e 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -28,10 +28,10 @@ def create_op(scope, op_type, inputs, outputs, attrs):
         if out_name in outputs:
             kwargs[out_name] = []
             if out_dup:
-                sub_in = outputs[out_name]
-                for sub_in_name, _ in sub_in:
-                    var = scope.new_var(sub_in_name)
-                    kwargs[out_name].append(sub_in_name)
+                sub_out = outputs[out_name]
+                for sub_out_name, _ in sub_out:
+                    var = scope.new_var(sub_out_name)
+                    kwargs[out_name].append(sub_out_name)
             else:
                 var = scope.new_var(out_name)
                 kwargs[out_name].append(out_name)
@@ -39,6 +39,7 @@ def create_op(scope, op_type, inputs, outputs, attrs):
     for attr_name in Operator.get_op_attr_names(op_type):
         if attr_name in attrs:
             kwargs[attr_name] = attrs[attr_name]
+
     return Operator(op_type, **kwargs)
@@ -179,8 +180,9 @@ class OpTest(unittest.TestCase):
     def check_output_with_place(self, place):
         self.scope = core.Scope()
         op_inputs = self.inputs if hasattr(self, "inputs") else dict()
+        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()
-        self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
+        self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
                             op_attrs)
         if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
             return
@@ -192,21 +194,23 @@ class OpTest(unittest.TestCase):
         for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
             if out_dup:
                 sub_out = self.outputs[out_name]
-                for sub_out_name in sub_out:
+                for sub_out_name, sub_out_array in sub_out:
                     actual = np.array(
                         self.scope.find_var(sub_out_name).get_tensor())
-                    expect = sub_out[sub_out_name]
+                    expect = sub_out_array
                     self.assertTrue(
                         np.allclose(
                             actual, expect, atol=1e-05),
-                        "output name: " + out_name + "has diff")
+                        "output name: " + out_name + " has diff")
             else:
-                actual = np.array(self.scope.find_var(out_name).get_tensor())
-                expect = self.outputs[out_name]
-                self.assertTrue(
-                    np.allclose(
-                        actual, expect, atol=1e-05),
-                    "output name: " + out_name + "has diff")
+                var = self.scope.find_var(out_name)
+                if var is not None:
+                    actual = np.array(var.get_tensor())
+                    expect = self.outputs[out_name]
+                    self.assertTrue(
+                        np.allclose(
+                            actual, expect, atol=1e-05),
+                        "output name: " + out_name + " has diff")
 
     def check_output(self):
         places = [core.CPUPlace()]
@@ -241,8 +245,9 @@ class OpTest(unittest.TestCase):
                    max_relative_error=0.005):
         self.scope = core.Scope()
         op_inputs = self.inputs if hasattr(self, "inputs") else dict()
+        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
         op_attrs = self.attrs if hasattr(self, "attrs") else dict()
-        self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
+        self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
                             op_attrs)
         if no_grad_set is None:
             no_grad_set = set()
diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f56fe5049c66aa5fce40ce815105e7871ebc3b2
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_fc_op.py
@@ -0,0 +1,62 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestFCOp1(OpTest):
+    def setUp(self):
+        x0 = np.random.random((16, 32)).astype("float32")
+        w0 = np.random.random((32, 10)).astype("float32")
+
+        mul_out0 = np.dot(x0, w0)
+        identity_out = mul_out0
+
+        self.op_type = "fc"
+        self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)]}
+        self.outputs = {"MulOut": [("MulOut0", mul_out0)], "Out": identity_out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X0", "W0"], "Out", max_relative_error=0.01)
+
+
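+# TestFCOp2 covers the multi-input path: two inputs of different ranks are
+# flattened to 2-D according to xNumColDims, then a bias and a sigmoid
+# activation are applied, so SumOut and AddOut are exercised as well.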
+class TestFCOp2(OpTest):
+    def setUp(self):
+        x0 = np.random.random((16, 4, 8)).astype("float32")
+        x1 = np.random.random((4, 4, 32)).astype("float32")
+        w0 = np.random.random((32, 10)).astype("float32")
+        w1 = np.random.random((32, 10)).astype("float32")
+        b = np.random.random(10).astype("float32")
+
+        mul_out0 = np.dot(x0.reshape(16, 4 * 8), w0)
+        mul_out1 = np.dot(x1.reshape(4 * 4, 32), w1)
+        sum_out = mul_out0 + mul_out1
+        add_out = np.add(sum_out, b)
+        sigmoid_out = 1 / (1 + np.exp(-add_out))
+
+        self.op_type = "fc"
+        self.inputs = {
+            "X": [("X0", x0), ("X1", x1)],
+            "W": [("W0", w0), ("W1", w1)],
+            "B": b
+        }
+        self.attrs = {"xNumColDims": [1, 2], "activation": "sigmoid"}
+        self.outputs = {
+            "MulOut": [("MulOut0", mul_out0), ("MulOut1", mul_out1)],
+            "SumOut": sum_out,
+            "AddOut": add_out,
+            "Out": sigmoid_out
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ["X0", "X1", "W0", "W1", "B"], "Out", max_relative_error=0.01)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_identity_op.py b/python/paddle/v2/framework/tests/test_identity_op.py
index 2e95e7c786e3ff99a04b28218ec5b5decf531360..26cec1fcc3ad003281c9c41571d475b55bd30026 100644
--- a/python/paddle/v2/framework/tests/test_identity_op.py
+++ b/python/paddle/v2/framework/tests/test_identity_op.py
@@ -7,13 +7,13 @@ class TestIdentityOp(OpTest):
     def setUp(self):
         self.op_type = "identity"
         self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
-        self.outputs = {'Out': self.inputs['X']}
+        self.outputs = {'Y': self.inputs['X']}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Y')
 
 
 if __name__ == "__main__":