From f50e36e2855a27160750aae26458f07eaaaae4d7 Mon Sep 17 00:00:00 2001
From: qijun
Date: Fri, 8 Sep 2017 14:54:29 +0800
Subject: [PATCH] follow comments

---
 paddle/framework/operator.cc                  |  9 +++
 paddle/framework/operator.h                   | 22 +-----
 paddle/operators/sum_op.cc                    | 10 ++-
 paddle/pybind/pybind.cc                       | 11 +--
 python/paddle/v2/framework/tests/op_test.py   | 78 ++++++++++---------
 .../framework/tests/test_cross_entropy_op.py  |  8 +-
 .../v2/framework/tests/test_sigmoid_op.py     |  4 +-
 .../paddle/v2/framework/tests/test_sum_op.py  |  4 +-
 8 files changed, 67 insertions(+), 79 deletions(-)

diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746b..e1e122091f7 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
   CheckAllInputOutputSet();
 }
 
+std::vector<std::string> OperatorBase::InputVars() const {
+  std::vector<std::string> ret_val;
+  for (auto& o : inputs_) {
+    ret_val.reserve(ret_val.size() + o.second.size());
+    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+  }
+  return ret_val;
+}
+
 std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
   std::vector<std::string> ret_val;
   if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index be302669cdf..d0316224e99 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -95,31 +95,13 @@ class OperatorBase {
   const VariableNameMap& Inputs() const { return inputs_; }
   const VariableNameMap& Outputs() const { return outputs_; }
 
-  const std::vector<std::string> InputsNames() const {
-    std::vector<std::string> result;
-    for (auto& kv : inputs_) {
-      for (auto& name : kv.second) {
-        result.push_back(name);
-      }
-    }
-    return result;
-  }
-
-  const std::vector<std::string> OutputsNames() const {
-    std::vector<std::string> result;
-    for (auto& kv : outputs_) {
-      for (auto& name : kv.second) {
-        result.push_back(name);
-      }
-    }
-    return result;
-  }
-
   //! Get a input with argument's name described in `op_proto`
   std::string Input(const std::string& name) const;
   //! Get a input which has multiple variables.
   const std::vector<std::string>& Inputs(const std::string& name) const;
 
+  std::vector<std::string> InputVars() const;
+
   //! Get a output with argument's name described in `op_proto`
   std::string Output(const std::string& name) const;
   //! Get an output which has multiple variables.
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
index cf650787c62..5805826ee8a 100644
--- a/paddle/operators/sum_op.cc
+++ b/paddle/operators/sum_op.cc
@@ -26,10 +26,14 @@ class SumOp : public framework::OperatorWithKernel {
     auto *out = ctx.Output<framework::Tensor>("Out");
     int N = ins.size();
 
-    PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+    auto in_dim = ins[0]->dims();
 
-    auto dim_zero = ins[0]->dims();
-    out->Resize(dim_zero);
+    PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+    for (int i = 1; i < N; i++) {
+      auto dim = ins[i]->dims();
+      PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+    }
+    out->Resize(in_dim);
   }
 };
 
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index a678bc49408..7d8e2d8fb72 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -214,15 +214,10 @@ All parameter, weight, gradient are variables in Paddle.
                -> std::map<std::string, std::vector<std::string>> {
                  return op.Outputs();
                })
-      .def("outputs_names",
-           [](const OperatorBase &op) -> std::vector<std::string> {
-             return op.OutputsNames();
-           })
+      .def("output_vars",
+           [](const OperatorBase &op) { return op.OutputVars(true); })
       .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
-      .def("inputs_names",
-           [](const OperatorBase &op) -> std::vector<std::string> {
-             return op.InputsNames();
-           })
+      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
       .def("__str__", &OperatorBase::DebugString)
       .def("no_intermediate_outputs",
            [](const OperatorBase &op) { return op.OutputVars(false); })
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 3f8e1236ff7..09ee8ce385a 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -9,54 +9,40 @@ def grad_var_name(var_name):
     return var_name + "@GRAD"
 
 
-def remove_grad_var_name(var_name):
-    return var_name[0:-5]
-
-
 def create_op(scope, op_type, inputs, outputs, attrs=None):
     kwargs = dict()
 
-    for ins in Operator.get_op_inputs(op_type):
-        in_name = ins[0]
-        in_dup = ins[1]
+    for in_name, in_dup in Operator.get_op_inputs(op_type):
         if in_name in inputs:
             kwargs[in_name] = []
             if in_dup:
                 sub_in = inputs[in_name]
                 for sub_in_name in sub_in:
                     var = scope.new_var(sub_in_name)
-                    tensor = var.get_tensor()
                     kwargs[in_name].append(sub_in_name)
             else:
                 var = scope.new_var(in_name)
-                tensor = var.get_tensor()
                 kwargs[in_name].append(in_name)
 
-    for outs in Operator.get_op_outputs(op_type):
-        out_name = outs[0]
-        out_dup = outs[1]
+    for out_name, out_dup in Operator.get_op_outputs(op_type):
         if out_name in outputs:
             kwargs[out_name] = []
             if out_dup:
                 sub_in = outputs[out_name]
                 for sun_in_name in sub_in:
                     var = scope.new_var(sun_in_name)
-                    tensor = var.get_tensor()
                     kwargs[out_name].append(sun_in_name)
             else:
                 var = scope.new_var(out_name)
-                tensor = var.get_tensor()
                 kwargs[out_name].append(out_name)
 
-    # for attr_name in Operator.get_op_attr_names(op_type):
-    #     kwargs[attr_name] = attrs[attr_name]
+    for attr_name in Operator.get_op_attr_names(op_type):
+        kwargs[attr_name] = attrs[attr_name]
 
     return Operator(op_type, **kwargs)
 
 
 def set_input(scope, op, inputs, place):
-    for ins in Operator.get_op_inputs(op.type()):
-        in_name = ins[0]
-        in_dup = ins[1]
+    for in_name, in_dup in Operator.get_op_inputs(op.type()):
         if in_name in inputs:
             if in_dup:
                 sub_in = inputs[in_name]
@@ -75,9 +61,7 @@ def set_input(scope, op, inputs, place):
 
 
 def set_output_grad(scope, op, outputs, place):
-    for outs in Operator.get_op_outputs(op.type()):
-        out_name = outs[0]
-        out_dup = outs[1]
+    for out_name, out_dup in Operator.get_op_outputs(op.type()):
         if out_name in outputs:
             if out_dup:
                 sub_out = outputs[out_name]
@@ -150,10 +134,10 @@ def get_numeric_gradient(scope,
 
 def get_backward_op(scope, op, no_grad_set):
     backward_op = core.Operator.backward(op, no_grad_set)
-    for input in backward_op.inputs_names():
+    for input in backward_op.input_vars():
         var = scope.new_var(input)
         var.get_tensor()
-    for output in backward_op.outputs_names():
+    for output in backward_op.output_vars():
         var = scope.new_var(output)
         var.get_tensor()
     return backward_op
@@ -182,7 +166,7 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
 
 
 class OpTest(unittest.TestCase):
-    def check_output(self, place):
+    def check_output_with_place(self, place):
         self.scope = core.Scope()
         self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
         if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
@@ -192,9 +176,7 @@ class OpTest(unittest.TestCase):
         ctx = core.DeviceContext.create(place)
         self.op.run(self.scope, ctx)
 
-        for outs in Operator.get_op_outputs(self.op.type()):
-            out_name = outs[0]
-            out_dup = outs[1]
+        for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
             if out_dup:
                 sub_out = self.outputs[out_name]
                 for sub_out_name in sub_out:
@@ -213,6 +195,13 @@ class OpTest(unittest.TestCase):
                     actual, expect, atol=1e-05),
                 "output name: " + out_name + "has diff")
 
+    def check_output(self):
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu():
+            places.append(core.GPUPlace(0))
+        for place in places:
+            self.check_output_with_place(place)
+
     def __assert_is_close(self, numeric_grads, analytic_grads, names,
                           max_relative_error, msg_prefix):
 
@@ -255,17 +244,32 @@ class OpTest(unittest.TestCase):
             grad_var_name(input_to_check) for input_to_check in inputs_to_check
         ]
 
-        places = [core.CPUPlace()]
-        if core.is_compile_gpu() and self.op.support_gpu():
-            places.append(core.GPUPlace(0))
+        cpu_place = core.CPUPlace()
+        cpu_analytic_grads = [
+            get_gradient(self.scope, self.op, self.inputs, self.outputs,
+                         grad_name, cpu_place, no_grad_set)
+            for grad_name in grad_names
+        ]
 
-        for place in places:
-            analytic_grads = [
+        self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
+                               max_relative_error,
+                               "Gradient Check On %s" % str(cpu_place))
+
+        if core.is_compile_gpu() and self.op.support_gpu():
+            gpu_place = core.GPUPlace(0)
+            gpu_analytic_grads = [
                 get_gradient(self.scope, self.op, self.inputs, self.outputs,
-                             grad_name, place, no_grad_set)
+                             grad_name, gpu_place, no_grad_set)
                 for grad_name in grad_names
             ]
-            self.__assert_is_close(numeric_grads, analytic_grads, grad_names,
-                                   max_relative_error,
-                                   "Gradient Check On %s" % str(place))
+            self.__assert_is_close(numeric_grads, gpu_analytic_grads,
+                                   grad_names, max_relative_error,
+                                   "Gradient Check On %s" % str(gpu_place))
+
+            for c_grad, g_grad, name in itertools.izip(
+                    cpu_analytic_grads, gpu_analytic_grads, grad_names):
+                self.assertTrue(
+                    numpy.allclose(
+                        c_grad, g_grad, atol=1e-4),
+                    "output name: " + name + " has diff")
 
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 20e0de3520b..1956df0bb4e 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,14 +1,13 @@
 import unittest
 import numpy
 from op_test import OpTest
-import paddle.v2.framework.core as core
 
 
 class TestCrossEntropy(OpTest):
     def setUp(self):
         self.op_type = "onehot_cross_entropy"
-        batch_size = 4
-        class_num = 4
+        batch_size = 30
+        class_num = 10
         X = numpy.random.random((batch_size, class_num)).astype("float32")
         label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
         self.inputs = {'X': X, 'label': label}
@@ -18,8 +17,7 @@ class TestCrossEntropy(OpTest):
         self.outputs = {'Y': numpy.array(Y).astype("float32")}
 
     def test_check_output(self):
-        self.check_output(core.CPUPlace())
-        self.check_output(core.GPUPlace(0))
+        self.check_output()
 
     def test_check_grad(self):
         self.check_grad(["X"], "Y")
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index ff0823508fc..2316e49eff7 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
 from op_test import OpTest
-import paddle.v2.framework.core as core
 
 
 class TestSigmoid(OpTest):
@@ -13,8 +12,7 @@ class TestSigmoid(OpTest):
         self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
 
     def test_check_output(self):
-        self.check_output(core.CPUPlace())
-        self.check_output(core.GPUPlace(0))
+        self.check_output()
 
     def test_check_grad(self):
         self.check_grad(["X"], "Y", max_relative_error=0.007)
diff --git a/python/paddle/v2/framework/tests/test_sum_op.py b/python/paddle/v2/framework/tests/test_sum_op.py
index 2a7b65ef527..66417d70e81 100644
--- a/python/paddle/v2/framework/tests/test_sum_op.py
+++ b/python/paddle/v2/framework/tests/test_sum_op.py
@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
 from op_test import OpTest
-import paddle.v2.framework.core as core
 
 
 class TestSumOp(OpTest):
@@ -15,8 +14,7 @@ class TestSumOp(OpTest):
         self.outputs = {'Out': y}
 
     def test_check_output(self):
-        self.check_output(core.CPUPlace())
-        self.check_output(core.GPUPlace(0))
+        self.check_output()
 
     def test_check_grad(self):
         self.check_grad(["x0"], "Out")
--
GitLab
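
Note (not part of the patch): after this change an operator test no longer
picks devices itself. It only fills self.op_type, self.inputs and
self.outputs in setUp(); check_output() iterates over CPUPlace (and
GPUPlace(0) when Paddle is compiled with GPU support), and check_grad()
additionally cross-checks CPU gradients against GPU gradients. A minimal
sketch of the resulting test pattern, modeled on the tests above — the
"mean" operator and its shapes here are illustrative assumptions only, not
something this patch touches:

    import unittest
    import numpy as np
    from op_test import OpTest


    class TestMeanOp(OpTest):
        def setUp(self):
            # Describe the op only; OpTest decides where to run it.
            self.op_type = "mean"  # illustrative op name, assumed to exist
            self.inputs = {'X': np.random.random((32, 84)).astype("float32")}
            self.outputs = {'Out': np.mean(self.inputs['X'])}

        def test_check_output(self):
            # Runs on CPU, and also on GPU when compiled in and supported.
            self.check_output()

        def test_check_grad(self):
            # Compares numeric and analytic gradients on each place.
            self.check_grad(['X'], 'Out')


    if __name__ == '__main__':
        unittest.main()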