diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 790cfc4746b1d34da413fa3c29a266f962c6dde6..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type, CheckAllInputOutputSet(); } +std::vector OperatorBase::InputVars() const { + std::vector ret_val; + for (auto& o : outputs_) { + ret_val.reserve(ret_val.size() + o.second.size()); + ret_val.insert(ret_val.end(), o.second.begin(), o.second.end()); + } + return ret_val; +} + std::vector OperatorBase::OutputVars(bool has_intermediate) const { std::vector ret_val; if (has_intermediate) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 9a98d4d3be0d1cb875d614b263f1e4365ede4113..4600b06009bcef7d0774d25b816aac4733f30795 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -94,11 +94,14 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } + //! Get a input with argument's name described in `op_proto` std::string Input(const std::string& name) const; //! Get a input which has multiple variables. const std::vector& Inputs(const std::string& name) const; + std::vector InputVars() const; + //! Get a output with argument's name described in `op_proto` std::string Output(const std::string& name) const; //! Get an output which has multiple variables. @@ -311,9 +314,9 @@ class InferShapeContext { } template - std::vector MultiOutput(const std::string& name) const { + std::vector MultiOutput(const std::string& name) const { auto names = op_.Outputs(name); - std::vector res; + std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5805826ee8a555ca6dfc1ca81feaadffea9e1012 --- /dev/null +++ b/paddle/operators/sum_op.cc @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sum_op.h" +#include + +namespace paddle { +namespace operators { +using framework::Tensor; + +class SumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto ins = ctx.MultiInput("X"); + auto *out = ctx.Output("Out"); + int N = ins.size(); + + auto in_dim = ins[0]->dims(); + + PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1."); + for (int i = 1; i < N; i++) { + auto dim = ins[i]->dims(); + PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape"); + } + out->Resize(in_dim); + } +}; + +class SumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "the input tensors of sum operator.").AsDuplicable(); + AddOutput("Out", "the output tensor of sum operator."); + AddComment(R"DOC( + Sum the input tensors. + )DOC"); + } +}; + +class SumGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto outputs = ctx.MultiOutput(framework::GradVarName("X")); + auto dims = ctx.Input(framework::GradVarName("Out"))->dims(); + for (auto output : outputs) { + output->Resize(dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp); +REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel); +REGISTER_OP_CPU_KERNEL(sum_grad, + ops::SumGradKernel); diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..a465cf3659ba7c51338abadfc62962fb6755a39d --- /dev/null +++ b/paddle/operators/sum_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/sum_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel); +REGISTER_OP_GPU_KERNEL(sum_grad, + ops::SumGradKernel); diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0b1e9ebaa38d455fb5e3ce8c1a39cbbcdad9a940 --- /dev/null +++ b/paddle/operators/sum_op.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class SumKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto ins = context.MultiInput("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + auto place = context.GetEigenDevice(); + auto result = EigenVector::Flatten(*out); + + int N = ins.size(); + auto in = EigenVector::Flatten(*(ins[0])); + result.device(place) = in; + for (int i = 1; i < N; i++) { + auto in = EigenVector::Flatten(*(ins[i])); + result.device(place) = result + in; + } + } +}; + +template +class SumGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input(framework::GradVarName("Out")); + auto outs = context.MultiOutput(framework::GradVarName("X")); + for (auto out : outs) { + out->mutable_data(context.GetPlace()); + } + + auto place = context.GetEigenDevice(); + auto in = EigenVector::Flatten(*input); + for (auto out : outs) { + auto result = EigenVector::Flatten(*out); + result.device(place) = in; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 53985933ed143c7faba6f7e2a704445697c1f58e..db701a2a309b794d8a4238f7176e3986236ed087 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -51,6 +51,7 @@ USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); USE_OP(top_k); USE_OP(squared_l2_distance); +USE_OP(sum); namespace paddle { namespace framework { @@ -216,7 +217,10 @@ All parameter, weight, gradient are variables in Paddle. -> std::map> { return op.Outputs(); }) + .def("output_vars", + [](const OperatorBase &op) { return op.OutputVars(true); }) .def("inputs", [](const OperatorBase &op) { return op.Inputs(); }) + .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); }) .def("__str__", &OperatorBase::DebugString) .def("no_intermediate_outputs", [](const OperatorBase &op) { return op.OutputVars(false); }) diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index c1585bcffcceb75292853018179066c9f614261e..4e91924a50cf6401d4002510e940ddc84edbe61a 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -142,8 +142,8 @@ def create_op_creation_method(op_proto): return OpInfo( method=__impl__, name=op_proto.type, - inputs=[var.name for var in op_proto.inputs], - outputs=[var.name for var in op_proto.outputs], + inputs=[(var.name, var.duplicable) for var in op_proto.inputs], + outputs=[(var.name, var.duplicable) for var in op_proto.outputs], attrs=[attr.name for attr in op_proto.attrs]) @@ -180,9 +180,15 @@ class OperatorFactory(object): return self.op_methods.get(t) def get_op_input_names(self, type): + return map(lambda x: x[0], self.get_op_info(type).inputs) + + def get_op_inputs(self, type): return self.get_op_info(type).inputs def get_op_output_names(self, type): + return map(lambda x: x[0], self.get_op_info(type).outputs) + + def get_op_outputs(self, type): return self.get_op_info(type).outputs def get_op_attr_names(self, type): diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ef910f939be0b9d3cb5e6d49e69a00daa191b1c6..2117fdf0d58520a008d2bd01d56d96dd248be025 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -33,5 +33,6 @@ py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) py_test(test_lookup_table SRCS test_lookup_table.py) py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py) +py_test(test_sum_op SRCS test_sum_op.py) py_test(mnist SRCS mnist.py) py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py) diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6a5dca4c4ddc1399d80e491e4072f24707c01e --- /dev/null +++ b/python/paddle/v2/framework/tests/op_test.py @@ -0,0 +1,275 @@ +import unittest +import numpy as np +import itertools +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +def grad_var_name(var_name): + return var_name + "@GRAD" + + +def create_op(scope, op_type, inputs, outputs, attrs=None): + kwargs = dict() + + for in_name, in_dup in Operator.get_op_inputs(op_type): + if in_name in inputs: + kwargs[in_name] = [] + if in_dup: + sub_in = inputs[in_name] + for sub_in_name in sub_in: + var = scope.new_var(sub_in_name) + kwargs[in_name].append(sub_in_name) + else: + var = scope.new_var(in_name) + kwargs[in_name].append(in_name) + + for out_name, out_dup in Operator.get_op_outputs(op_type): + if out_name in outputs: + kwargs[out_name] = [] + if out_dup: + sub_in = outputs[out_name] + for sun_in_name in sub_in: + var = scope.new_var(sun_in_name) + kwargs[out_name].append(sun_in_name) + else: + var = scope.new_var(out_name) + kwargs[out_name].append(out_name) + + for attr_name in Operator.get_op_attr_names(op_type): + kwargs[attr_name] = attrs[attr_name] + return Operator(op_type, **kwargs) + + +def set_input(scope, op, inputs, place): + for in_name, in_dup in Operator.get_op_inputs(op.type()): + if in_name in inputs: + if in_dup: + sub_in = inputs[in_name] + for sub_in_name in sub_in: + var = scope.find_var(sub_in_name) + tensor = var.get_tensor() + arr = sub_in[sub_in_name] + tensor.set_dims(arr.shape) + tensor.set(arr, place) + else: + var = scope.find_var(in_name) + tensor = var.get_tensor() + arr = inputs[in_name] + tensor.set_dims(arr.shape) + tensor.set(arr, place) + + +def set_output_grad(scope, op, outputs, place): + for out_name, out_dup in Operator.get_op_outputs(op.type()): + if out_name in outputs: + if out_dup: + sub_out = outputs[out_name] + for sub_out_name in sub_out: + out_tensor = scope.find_var(sub_out_name).get_tensor() + grad_tensor = scope.new_var(grad_var_name( + sub_out_name)).get_tensor() + grad_tensor.set_dims(out_tensor.shape()) + data = np.ones(out_tensor.shape(), dtype=np.float32) + grad_tensor.set(data, place) + else: + out_tensor = scope.find_var(out_name).get_tensor() + grad_tensor = scope.new_var(grad_var_name(out_name)).get_tensor( + ) + grad_tensor.set_dims(out_tensor.shape()) + data = np.ones(out_tensor.shape(), dtype=np.float32) + grad_tensor.set(data, place) + + +def get_numeric_gradient(scope, + op, + inputs, + input_to_check, + output_name, + delta=0.005, + in_place=False): + + set_input(scope, op, inputs, core.CPUPlace()) + op.infer_shape(scope) + + tensor_to_check = scope.find_var(input_to_check).get_tensor() + + def product(dim): + return reduce(lambda a, b: a * b, dim, 1) + + ctx = core.DeviceContext.create(core.CPUPlace()) + + def get_output(): + op.run(scope, ctx) + return np.array(scope.find_var(output_name).get_tensor()).sum() + + tensor_to_check = scope.find_var(input_to_check).get_tensor() + tensor_size = product(tensor_to_check.get_dims()) + gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32') + # we only compute gradient of one element each time. + # we use a for loop to compute the gradient of every element. + for i in xrange(tensor_size): + if in_place: + set_input(op, inputs, core.CPUPlace()) + + # get one input element throw it's index i. + origin = tensor_to_check.get_float_element(i) + # add delta to it, run op and then get the sum of the result tensor. + x_pos = origin + delta + tensor_to_check.set_float_element(i, x_pos) + y_pos = get_output() + + if in_place: + set_input(op, inputs, core.CPUPlace()) + + x_neg = origin - delta + tensor_to_check.set_float_element(i, x_neg) + y_neg = get_output() + + tensor_to_check.set_float_element(i, origin) + gradient_flat[i] = (y_pos - y_neg) / delta / 2 + + return gradient_flat.reshape(tensor_to_check.get_dims()) + + +def get_backward_op(scope, op, no_grad_set): + backward_op = core.Operator.backward(op, no_grad_set) + for input in backward_op.input_vars(): + var = scope.new_var(input) + var.get_tensor() + for output in backward_op.output_vars(): + var = scope.new_var(output) + var.get_tensor() + return backward_op + + +def get_gradient(scope, op, inputs, outputs, grad_name, place, + no_grad_set=None): + ctx = core.DeviceContext.create(place) + + set_input(scope, op, inputs, place) + + op.infer_shape(scope) + op.run(scope, ctx) + + if no_grad_set is None: + no_grad_set = set() + + backward_op = get_backward_op(scope, op, no_grad_set) + set_output_grad(scope, op, outputs, place) + + backward_op.infer_shape(scope) + backward_op.run(scope, ctx) + + out = np.array(scope.find_var(grad_name).get_tensor()) + return out + + +class OpTest(unittest.TestCase): + def check_output_with_place(self, place): + self.scope = core.Scope() + self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs) + if isinstance(place, core.GPUPlace) and not self.op.support_gpu(): + return + set_input(self.scope, self.op, self.inputs, place) + self.op.infer_shape(self.scope) + ctx = core.DeviceContext.create(place) + self.op.run(self.scope, ctx) + + for out_name, out_dup in Operator.get_op_outputs(self.op.type()): + if out_dup: + sub_out = self.outputs[out_name] + for sub_out_name in sub_out: + actual = np.array( + self.scope.find_var(sub_out_name).get_tensor()) + expect = sub_out[sub_out_name] + self.assertTrue( + np.allclose( + actual, expect, atol=1e-05), + "output name: " + out_name + "has diff") + else: + actual = np.array(self.scope.find_var(out_name).get_tensor()) + expect = self.outputs[out_name] + self.assertTrue( + np.allclose( + actual, expect, atol=1e-05), + "output name: " + out_name + "has diff") + + def check_output(self): + places = [core.CPUPlace()] + if core.is_compile_gpu(): + places.append(core.GPUPlace(0)) + for place in places: + self.check_output_with_place(place) + + def __assert_is_close(self, numeric_grads, analytic_grads, names, + max_relative_error, msg_prefix): + + for a, b, name in itertools.izip(numeric_grads, analytic_grads, names): + abs_a = np.abs(a) + abs_a[abs_a < 1e-3] = 1 + + diff_mat = np.abs(a - b) / abs_a + max_diff = np.max(diff_mat) + + def err_msg(): + offset = np.argmax(diff_mat > max_relative_error) + return "%s Variable %s max gradient diff %f over limit %f, the first " \ + "error element is %d" % ( + msg_prefix, name, max_diff, max_relative_error, offset) + + self.assertLessEqual(max_diff, max_relative_error, err_msg()) + + def check_grad(self, + inputs_to_check, + output_name, + no_grad_set=None, + in_place=False, + max_relative_error=0.005): + self.scope = core.Scope() + self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs) + if no_grad_set is None: + no_grad_set = set() + + numeric_grads = [ + get_numeric_gradient( + self.scope, + self.op, + self.inputs, + input_to_check, + output_name, + in_place=in_place) for input_to_check in inputs_to_check + ] + grad_names = [ + grad_var_name(input_to_check) for input_to_check in inputs_to_check + ] + + cpu_place = core.CPUPlace() + cpu_analytic_grads = [ + get_gradient(self.scope, self.op, self.inputs, self.outputs, + grad_name, cpu_place, no_grad_set) + for grad_name in grad_names + ] + + self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names, + max_relative_error, + "Gradient Check On %s" % str(cpu_place)) + + if core.is_compile_gpu() and self.op.support_gpu(): + gpu_place = core.GPUPlace(0) + gpu_analytic_grads = [ + get_gradient(self.scope, self.op, self.inputs, self.outputs, + grad_name, gpu_place, no_grad_set) + for grad_name in grad_names + ] + + self.__assert_is_close(numeric_grads, gpu_analytic_grads, + grad_names, max_relative_error, + "Gradient Check On %s" % str(gpu_place)) + + for c_grad, g_grad, name in itertools.izip( + cpu_analytic_grads, gpu_analytic_grads, grad_names): + self.assertTrue( + np.allclose( + c_grad, g_grad, atol=1e-4), + "output name: " + name + " has diff") diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index d4277f2a42ce2e66e37405ccd3b2ee444d403d1a..fb6a440e23c26d1766bdf1fc5f24217afe1150f8 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -1,36 +1,27 @@ import unittest import numpy -from op_test_util import OpTestMeta -from gradient_checker import GradientChecker, create_op +from op_test import OpTest -class TestCrossEntropy(unittest.TestCase): - __metaclass__ = OpTestMeta - +class TestCrossEntropy(OpTest): def setUp(self): - self.type = "onehot_cross_entropy" + self.op_type = "onehot_cross_entropy" batch_size = 30 class_num = 10 - X = numpy.random.random((batch_size, class_num)).astype("float32") - label = 5 * numpy.ones(batch_size).astype("int32") + X = numpy.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label = (class_num / 2) * numpy.ones(batch_size).astype("int32") self.inputs = {'X': X, 'label': label} Y = [] for i in range(0, batch_size): Y.append(-numpy.log(X[i][label[i]])) self.outputs = {'Y': numpy.array(Y).astype("float32")} + def test_check_output(self): + self.check_output() -class CrossEntropyGradOpTest(GradientChecker): def test_check_grad(self): - op = create_op("onehot_cross_entropy") - batch_size = 30 - class_num = 10 - inputs = { - "X": numpy.random.uniform( - 0.1, 1.0, [batch_size, class_num]).astype("float32"), - "label": (class_num / 2) * numpy.ones(batch_size).astype("int32") - } - self.check_grad(op, inputs, set("X"), "Y") + self.check_grad(["X"], "Y") if __name__ == "__main__": diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 273c2e5ab1a84d12621fe9568c4cf22073b6aed4..2316e49eff7bb1cdb53acb3889a6ef05060b59f3 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -1,27 +1,21 @@ import unittest import numpy as np -from op_test_util import OpTestMeta -from gradient_checker import GradientChecker, create_op +from op_test import OpTest -class TestSigmoidOp(unittest.TestCase): - __metaclass__ = OpTestMeta - +class TestSigmoid(OpTest): def setUp(self): - self.type = "sigmoid" - self.inputs = {'X': np.random.random((15, 31)).astype("float32")} + self.op_type = "sigmoid" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} + def test_check_output(self): + self.check_output() -class TestSigmoidGradOp(GradientChecker): - def test_grad(self): - op = create_op("sigmoid") - inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")} - # compare gpu and cpu results for backward op. - # this test will be skiped if only compiling CPU version. - self.compare_grad(op, inputs) - # check gradients - self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007) + def test_check_grad(self): + self.check_grad(["X"], "Y", max_relative_error=0.007) if __name__ == '__main__': diff --git a/python/paddle/v2/framework/tests/test_sum_op.py b/python/paddle/v2/framework/tests/test_sum_op.py new file mode 100644 index 0000000000000000000000000000000000000000..66417d70e81186465e6f59a17fb62255afeddea5 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_sum_op.py @@ -0,0 +1,24 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSumOp(OpTest): + def setUp(self): + self.op_type = "sum" + x0 = np.random.random((3, 4)).astype('float32') + x1 = np.random.random((3, 4)).astype('float32') + x2 = np.random.random((3, 4)).astype('float32') + self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}} + y = x0 + x1 + x2 + self.outputs = {'Out': y} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["x0"], "Out") + + +if __name__ == '__main__': + unittest.main()