From f314330c238388dd25a8be27b138a81bee2d8bfb Mon Sep 17 00:00:00 2001
From: qijun
Date: Tue, 5 Sep 2017 19:32:47 +0800
Subject: [PATCH] refactor operator python test and add sum operator

---
 paddle/framework/operator.h                   |  25 +-
 paddle/operators/sum_op.cc                    |  69 +++++
 paddle/operators/sum_op.cu                    |  18 ++
 paddle/operators/sum_op.h                     |  65 +++++
 paddle/pybind/pybind.cc                       |   9 +
 python/paddle/v2/framework/op.py              |  10 +-
 .../paddle/v2/framework/tests/CMakeLists.txt  |   1 +
 python/paddle/v2/framework/tests/op_test.py   | 271 ++++++++++++++++++
 .../framework/tests/test_cross_entropy_op.py  |  30 +-
 .../v2/framework/tests/test_sigmoid_op.py     |  28 +-
 .../paddle/v2/framework/tests/test_sum_op.py  |  26 ++
 11 files changed, 513 insertions(+), 39 deletions(-)
 create mode 100644 paddle/operators/sum_op.cc
 create mode 100644 paddle/operators/sum_op.cu
 create mode 100644 paddle/operators/sum_op.h
 create mode 100644 python/paddle/v2/framework/tests/op_test.py
 create mode 100644 python/paddle/v2/framework/tests/test_sum_op.py

diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index da92220b04e..be302669cdf 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,6 +94,27 @@ class OperatorBase {
 
   const VariableNameMap& Inputs() const { return inputs_; }
   const VariableNameMap& Outputs() const { return outputs_; }
+
+  const std::vector<std::string> InputsNames() const {
+    std::vector<std::string> result;
+    for (auto& kv : inputs_) {
+      for (auto& name : kv.second) {
+        result.push_back(name);
+      }
+    }
+    return result;
+  }
+
+  const std::vector<std::string> OutputsNames() const {
+    std::vector<std::string> result;
+    for (auto& kv : outputs_) {
+      for (auto& name : kv.second) {
+        result.push_back(name);
+      }
+    }
+    return result;
+  }
+
   //! Get a input with argument's name described in `op_proto`
   std::string Input(const std::string& name) const;
   //! Get a input which has multiple variables.
@@ -311,9 +332,9 @@ class InferShapeContext {
   }
 
   template <typename T>
-  std::vector<const T*> MultiOutput(const std::string& name) const {
+  std::vector<T*> MultiOutput(const std::string& name) const {
     auto names = op_.Outputs(name);
-    std::vector<const T*> res;
+    std::vector<T*> res;
     res.reserve(names.size());
     std::transform(names.begin(), names.end(), std::back_inserter(res),
                    [&](const std::string& sub_name) {
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 00000000000..cf650787c62
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,69 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto *out = ctx.Output<framework::Tensor>("Out");
+    int N = ins.size();
+
+    PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+
+    auto dim_zero = ins[0]->dims();
+    out->Resize(dim_zero);
+  }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
+    AddComment(R"DOC(
+    Sum the input tensors.
+    )DOC");
+  }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto outputs = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    auto dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    for (auto output : outputs) {
+      output->Resize(dims);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+                       ops::SumGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 00000000000..a465cf3659b
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+                       ops::SumGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 00000000000..0b1e9ebaa38
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class SumKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto ins = context.MultiInput<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    auto place = context.GetEigenDevice<Place>();
+    auto result = EigenVector<T>::Flatten(*out);
+
+    int N = ins.size();
+    auto in = EigenVector<T>::Flatten(*(ins[0]));
+    result.device(place) = in;
+    for (int i = 1; i < N; i++) {
+      auto in = EigenVector<T>::Flatten(*(ins[i]));
+      result.device(place) = result + in;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SumGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
+    for (auto out : outs) {
+      out->mutable_data<T>(context.GetPlace());
+    }
+
+    auto place = context.GetEigenDevice<Place>();
+    auto in = EigenVector<T>::Flatten(*input);
+    for (auto out : outs) {
+      auto result = EigenVector<T>::Flatten(*out);
+      result.device(place) = in;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 6896422617b..a678bc49408 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -48,6 +48,7 @@ USE_NO_KERNEL_OP(identity);
 USE_OP(minus);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
+USE_OP(sum);
 
 namespace paddle {
 namespace framework {
@@ -213,7 +214,15 @@ All parameter, weight, gradient are variables in Paddle.
                -> std::map<std::string, std::vector<std::string>> {
                  return op.Outputs();
                })
+      .def("outputs_names",
+           [](const OperatorBase &op) -> std::vector<std::string> {
+             return op.OutputsNames();
+           })
       .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
+      .def("inputs_names",
+           [](const OperatorBase &op) -> std::vector<std::string> {
+             return op.InputsNames();
+           })
       .def("__str__", &OperatorBase::DebugString)
       .def("no_intermediate_outputs",
            [](const OperatorBase &op) { return op.OutputVars(false); })
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index e7e932f6fe4..431cf4b260a 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -133,8 +133,8 @@ def create_op_creation_method(op_proto):
     return OpInfo(
         method=__impl__,
         name=op_proto.type,
-        inputs=[var.name for var in op_proto.inputs],
-        outputs=[var.name for var in op_proto.outputs],
+        inputs=[(var.name, var.duplicable) for var in op_proto.inputs],
+        outputs=[(var.name, var.duplicable) for var in op_proto.outputs],
         attrs=[attr.name for attr in op_proto.attrs])
 
 
@@ -168,9 +168,15 @@ class OperatorFactory(object):
         return self.op_methods.get(t)
 
     def get_op_input_names(self, type):
+        return map(lambda x: x[0], self.get_op_info(type).inputs)
+
+    def get_op_inputs(self, type):
         return self.get_op_info(type).inputs
 
     def get_op_output_names(self, type):
+        return map(lambda x: x[0], self.get_op_info(type).outputs)
+
+    def get_op_outputs(self, type):
         return self.get_op_info(type).outputs
 
     def get_op_attr_names(self, type):
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 661ebd89648..a4ce2630394 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -31,4 +31,5 @@ py_test(test_sgd_op SRCS test_sgd_op.py)
 py_test(test_gradient_checker SRCS test_gradient_checker.py)
 py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
+py_test(test_sum_op SRCS test_sum_op.py)
 py_test(mnist SRCS mnist.py)
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
new file mode 100644
index 00000000000..9b2f10fdf95
--- /dev/null
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -0,0 +1,271 @@
+import unittest
+import numpy as np
+import itertools
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
+
+
+def grad_var_name(var_name):
+    return var_name + "@GRAD"
+
+
+def remove_grad_var_name(var_name):
+    return var_name[0:-5]
+
+
+def create_op(scope, op_type, inputs, outputs, attrs=None):
+    kwargs = dict()
+
+    for ins in Operator.get_op_inputs(op_type):
+        in_name = ins[0]
+        in_dup = ins[1]
+        if in_name in inputs:
+            kwargs[in_name] = []
+            if in_dup:
+                sub_in = inputs[in_name]
+                for sub_in_name in sub_in:
+                    var = scope.new_var(sub_in_name)
+                    tensor = var.get_tensor()
+                    kwargs[in_name].append(sub_in_name)
+            else:
+                var = scope.new_var(in_name)
+                tensor = var.get_tensor()
+                kwargs[in_name].append(in_name)
+
+    for outs in Operator.get_op_outputs(op_type):
+        out_name = outs[0]
+        out_dup = outs[1]
+        if out_name in outputs:
+            kwargs[out_name] = []
+            if out_dup:
+                sub_out = outputs[out_name]
+                for sub_out_name in sub_out:
+                    var = scope.new_var(sub_out_name)
+                    tensor = var.get_tensor()
+                    kwargs[out_name].append(sub_out_name)
+            else:
+                var = scope.new_var(out_name)
+                tensor = var.get_tensor()
+                kwargs[out_name].append(out_name)
+
+    # for attr_name in Operator.get_op_attr_names(op_type):
+    #     kwargs[attr_name] = attrs[attr_name]
+    return Operator(op_type, **kwargs)
+
+
+def set_input(scope, op, inputs, place):
+    for ins in Operator.get_op_inputs(op.type()):
+        in_name = ins[0]
+        in_dup = ins[1]
+        if in_name in inputs:
+            if in_dup:
+                sub_in = inputs[in_name]
+                for sub_in_name in sub_in:
+                    var = scope.find_var(sub_in_name)
+                    tensor = var.get_tensor()
+                    arr = sub_in[sub_in_name]
+                    tensor.set_dims(arr.shape)
+                    tensor.set(arr, place)
+            else:
+                var = scope.find_var(in_name)
+                tensor = var.get_tensor()
+                arr = inputs[in_name]
+                tensor.set_dims(arr.shape)
+                tensor.set(arr, place)
+
+
+def set_output_grad(scope, op, outputs, place):
+    for outs in Operator.get_op_outputs(op.type()):
+        out_name = outs[0]
+        out_dup = outs[1]
+        if out_name in outputs:
+            if out_dup:
+                sub_out = outputs[out_name]
+                for sub_out_name in sub_out:
+                    out_tensor = scope.find_var(sub_out_name).get_tensor()
+                    grad_tensor = scope.new_var(grad_var_name(
+                        sub_out_name)).get_tensor()
+                    grad_tensor.set_dims(out_tensor.shape())
+                    data = np.ones(out_tensor.shape(), dtype=np.float32)
+                    grad_tensor.set(data, place)
+            else:
+                out_tensor = scope.find_var(out_name).get_tensor()
+                grad_tensor = scope.new_var(grad_var_name(
+                    out_name)).get_tensor()
+                grad_tensor.set_dims(out_tensor.shape())
+                data = np.ones(out_tensor.shape(), dtype=np.float32)
+                grad_tensor.set(data, place)
+
+
+def get_numeric_gradient(scope,
+                         op,
+                         inputs,
+                         input_to_check,
+                         output_name,
+                         delta=0.005,
+                         in_place=False):
+
+    set_input(scope, op, inputs, core.CPUPlace())
+    op.infer_shape(scope)
+
+    tensor_to_check = scope.find_var(input_to_check).get_tensor()
+
+    def product(dim):
+        return reduce(lambda a, b: a * b, dim, 1)
+
+    ctx = core.DeviceContext.create(core.CPUPlace())
+
+    def get_output():
+        op.run(scope, ctx)
+        return np.array(scope.find_var(output_name).get_tensor()).sum()
+
+    tensor_to_check = scope.find_var(input_to_check).get_tensor()
+    tensor_size = product(tensor_to_check.get_dims())
+    gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32')
+    # we only compute the gradient of one element each time.
+    # we use a for loop to compute the gradient of every element.
+    for i in xrange(tensor_size):
+        if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        # get one input element through its index i.
+        origin = tensor_to_check.get_float_element(i)
+        # add delta to it, run op and then get the sum of the result tensor.
+        x_pos = origin + delta
+        tensor_to_check.set_float_element(i, x_pos)
+        y_pos = get_output()
+
+        if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        x_neg = origin - delta
+        tensor_to_check.set_float_element(i, x_neg)
+        y_neg = get_output()
+
+        tensor_to_check.set_float_element(i, origin)
+        gradient_flat[i] = (y_pos - y_neg) / delta / 2
+
+    return gradient_flat.reshape(tensor_to_check.get_dims())
+
+
+def get_backward_op(scope, op, no_grad_set):
+    backward_op = core.Operator.backward(op, no_grad_set)
+    for input in backward_op.inputs_names():
+        var = scope.new_var(input)
+        var.get_tensor()
+    for output in backward_op.outputs_names():
+        var = scope.new_var(output)
+        var.get_tensor()
+    return backward_op
+
+
+def get_gradient(scope, op, inputs, outputs, grad_name, place,
+                 no_grad_set=None):
+    ctx = core.DeviceContext.create(place)
+
+    set_input(scope, op, inputs, place)
+
+    op.infer_shape(scope)
+    op.run(scope, ctx)
+
+    if no_grad_set is None:
+        no_grad_set = set()
+
+    backward_op = get_backward_op(scope, op, no_grad_set)
+    set_output_grad(scope, op, outputs, place)
+
+    backward_op.infer_shape(scope)
+    backward_op.run(scope, ctx)
+
+    out = np.array(scope.find_var(grad_name).get_tensor())
+    return out
+
+
+class OpTest(unittest.TestCase):
+    def check_output(self, place):
+        self.scope = core.Scope()
+        self.op = create_op(self.scope, self.op_type, self.inputs,
+                            self.outputs)
+        if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
+            return
+        set_input(self.scope, self.op, self.inputs, place)
+        self.op.infer_shape(self.scope)
+        ctx = core.DeviceContext.create(place)
+        self.op.run(self.scope, ctx)
+
+        for outs in Operator.get_op_outputs(self.op.type()):
+            out_name = outs[0]
+            out_dup = outs[1]
+            if out_dup:
+                sub_out = self.outputs[out_name]
+                for sub_out_name in sub_out:
+                    actual = np.array(
+                        self.scope.find_var(sub_out_name).get_tensor())
+                    expect = sub_out[sub_out_name]
+                    self.assertTrue(
+                        np.allclose(
+                            actual, expect, atol=1e-05),
+                        "output name: " + out_name + " has diff")
+            else:
+                actual = np.array(self.scope.find_var(out_name).get_tensor())
+                expect = self.outputs[out_name]
+                self.assertTrue(
+                    np.allclose(
+                        actual, expect, atol=1e-05),
+                    "output name: " + out_name + " has diff")
+
+    def __assert_is_close(self, numeric_grads, analytic_grads, names,
+                          max_relative_error, msg_prefix):
+
+        for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
+            abs_a = np.abs(a)
+            abs_a[abs_a < 1e-3] = 1
+
+            diff_mat = np.abs(a - b) / abs_a
+            max_diff = np.max(diff_mat)
+
+            def err_msg():
+                offset = np.argmax(diff_mat > max_relative_error)
+                return "%s Variable %s max gradient diff %f over limit %f, the first " \
+                       "error element is %d" % (
+                           msg_prefix, name, max_diff, max_relative_error, offset)
+
+            self.assertLessEqual(max_diff, max_relative_error, err_msg())
+
+    def check_grad(self,
+                   inputs_to_check,
+                   output_name,
+                   no_grad_set=None,
+                   in_place=False,
+                   max_relative_error=0.005):
+        self.scope = core.Scope()
+        self.op = create_op(self.scope, self.op_type, self.inputs,
+                            self.outputs)
+        if no_grad_set is None:
+            no_grad_set = set()
+
+        numeric_grads = [
+            get_numeric_gradient(
+                self.scope,
+                self.op,
+                self.inputs,
+                input_to_check,
+                output_name,
+                in_place=in_place) for input_to_check in inputs_to_check
+        ]
+        grad_names = [
+            grad_var_name(input_to_check) for input_to_check in inputs_to_check
+        ]
+
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu() and self.op.support_gpu():
+            places.append(core.GPUPlace(0))
+
+        for place in places:
+            analytic_grads = [
+                get_gradient(self.scope, self.op, self.inputs, self.outputs,
+                             grad_name, place, no_grad_set)
+                for grad_name in grad_names
+            ]
+
+            self.__assert_is_close(numeric_grads, analytic_grads, grad_names,
+                                   max_relative_error,
+                                   "Gradient Check On %s" % str(place))
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index d4277f2a42c..20e0de3520b 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,36 +1,28 @@
 import unittest
 import numpy
-from op_test_util import OpTestMeta
-from gradient_checker import GradientChecker, create_op
+from op_test import OpTest
+import paddle.v2.framework.core as core
 
 
-class TestCrossEntropy(unittest.TestCase):
-    __metaclass__ = OpTestMeta
-
+class TestCrossEntropy(OpTest):
     def setUp(self):
-        self.type = "onehot_cross_entropy"
-        batch_size = 30
-        class_num = 10
+        self.op_type = "onehot_cross_entropy"
+        batch_size = 4
+        class_num = 4
         X = numpy.random.random((batch_size, class_num)).astype("float32")
-        label = 5 * numpy.ones(batch_size).astype("int32")
+        label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
         self.inputs = {'X': X, 'label': label}
         Y = []
         for i in range(0, batch_size):
            Y.append(-numpy.log(X[i][label[i]]))
         self.outputs = {'Y': numpy.array(Y).astype("float32")}
 
+    def test_check_output(self):
+        self.check_output(core.CPUPlace())
+        self.check_output(core.GPUPlace(0))
 
-class CrossEntropyGradOpTest(GradientChecker):
     def test_check_grad(self):
-        op = create_op("onehot_cross_entropy")
-        batch_size = 30
-        class_num = 10
-        inputs = {
-            "X": numpy.random.uniform(
-                0.1, 1.0, [batch_size, class_num]).astype("float32"),
-            "label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
-        }
-        self.check_grad(op, inputs, set("X"), "Y")
+        self.check_grad(["X"], "Y")
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 273c2e5ab1a..ff0823508fc 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -1,27 +1,23 @@
 import unittest
 import numpy as np
-from op_test_util import OpTestMeta
-from gradient_checker import GradientChecker, create_op
+from op_test import OpTest
+import paddle.v2.framework.core as core
 
 
-class TestSigmoidOp(unittest.TestCase):
-    __metaclass__ = OpTestMeta
-
+class TestSigmoid(OpTest):
     def setUp(self):
-        self.type = "sigmoid"
-        self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
+        self.op_type = "sigmoid"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
         self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
 
+    def test_check_output(self):
+        self.check_output(core.CPUPlace())
+        self.check_output(core.GPUPlace(0))
 
-class TestSigmoidGradOp(GradientChecker):
-    def test_grad(self):
-        op = create_op("sigmoid")
-        inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
-        # compare gpu and cpu results for backward op.
-        # this test will be skiped if only compiling CPU version.
-        self.compare_grad(op, inputs)
-        # check gradients
-        self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", max_relative_error=0.007)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/v2/framework/tests/test_sum_op.py b/python/paddle/v2/framework/tests/test_sum_op.py
new file mode 100644
index 00000000000..2a7b65ef527
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sum_op.py
@@ -0,0 +1,26 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.v2.framework.core as core
+
+
+class TestSumOp(OpTest):
+    def setUp(self):
+        self.op_type = "sum"
+        x0 = np.random.random((3, 4)).astype('float32')
+        x1 = np.random.random((3, 4)).astype('float32')
+        x2 = np.random.random((3, 4)).astype('float32')
+        self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}}
+        y = x0 + x1 + x2
+        self.outputs = {'Out': y}
+
+    def test_check_output(self):
+        self.check_output(core.CPUPlace())
+        self.check_output(core.GPUPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad(["x0"], "Out")
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab