diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index b2813da83d9e4c525e66bb1f79b28769627eaec2..6c26183818a9d6996e3d3ce2af74ba36f4711eca 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -260,12 +260,6 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }
 
-  static bool SupportGPU(const std::string& op_type) {
-    OperatorWithKernel::OpKernelKey key;
-    key.place_ = platform::GPUPlace();
-    return OperatorWithKernel::AllOpKernels().at(op_type).count(key) != 0;
-  }
-
   static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
     PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 03fabff79b637299f8e133aab29ccb0e145379cf..c324fa6702de1eabab3f75cbf4e6568c99b60470 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -88,6 +88,8 @@ class OperatorBase {
 
   virtual bool IsNetOp() const { return false; }
 
+  virtual bool SupportGPU() const { return false; }
+
   /// rename inputs outputs name
   void Rename(const std::string& old_name, const std::string& new_name);
 
@@ -308,7 +310,7 @@ class OperatorWithKernel : public OperatorBase {
   using OpKernelMap =
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
 
-  void InferShape(const Scope& scope) const {
+  void InferShape(const Scope& scope) const override {
     InferShape(InferShapeContext(this, scope));
   }
 
@@ -324,6 +326,12 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }
 
+  bool SupportGPU() const override {
+    OperatorWithKernel::OpKernelKey key;
+    key.place_ = platform::GPUPlace();
+    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
+  }
+
  protected:
   virtual void InferShape(const InferShapeContext& ctx) const = 0;
 };
diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index 011391bc2df4206dde15030cd78d6ea329530de4..e17d0874a938bc615638e78dd4a1a3cc2a9f0878 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -57,6 +57,26 @@ void ExposeOperator(ClassType &m) {
            [](const typename ClassType::type &op) -> std::vector<std::string> {
              return op.outputs_;
            })
+      .def("inputs",
+           [](const typename ClassType::type &op) -> std::vector<std::string> {
+             return op.inputs_;
+           })
+      .def("support_gpu", &ClassType::type::SupportGPU)
+      .def("temp_outputs",
+           [](const typename ClassType::type &op) -> std::vector<std::string> {
+             auto iter = op.attrs_.find("temporary_index");
+             std::vector<std::string> ret;
+             if (iter == op.attrs_.end()) {
+               return ret;
+             } else {
+               auto tmp_idx = boost::get<std::vector<int>>(iter->second);
+               ret.reserve(tmp_idx.size());
+               for (auto &index : tmp_idx) {
+                 ret.push_back(op.outputs_.at(index));
+               }
+               return ret;
+             }
+           })
       .def("__str__", &ClassType::type::DebugString);
 }
 
@@ -202,8 +222,6 @@ All parameter, weight, gradient are variables in Paddle.
                             return OpRegistry::CreateOp(desc);
                           });
 
-  operator_base.def_static("support_gpu", &OpRegistry::SupportGPU);
-
   operator_base.def("backward",
                     [](const OperatorBase &forwardOp,
                        const std::unordered_set<std::string> &no_grad_vars) {
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index 942b919079bf06caeb6d185efb31d9d28d193008..ecf63f6494b0a0a0f2dba1f883389e959e8fbe78 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -70,7 +70,8 @@
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
             ops::OnehotCrossEntropyOpMaker);
 REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
                        ops::OnehotCrossEntropyOpKernel<float>);
-
+REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
+                     ops::OnehotCrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(
     onehot_cross_entropy_grad, ops::OnehotCrossEntropyGradientOpKernel<float>);
diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h
index bb2d02b56f48ac4b2f3b1ca742ae6d6141d3454e..b6d269b9cdc18968b047bffdb5a3799235c5640e 100644
--- a/paddle/operators/net_op.h
+++ b/paddle/operators/net_op.h
@@ -65,6 +65,15 @@ class NetOp : public framework::OperatorBase {
     }
   }
 
+  bool SupportGPU() const override {
+    for (auto& op : ops_) {
+      if (!op->SupportGPU()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * @brief Add an operator by ptr
    */
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 0328bea7f7f18562fcb6b5a19cd2a2c70f10a532..10659caa882fd3d4060f9947413a392c3b681ee8 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py)
 py_test(test_add_two_op SRCS test_add_two_op.py)
 py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
 py_test(test_softmax_op SRCS test_softmax_op.py)
+py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
 py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
 py_test(gradient_checker SRCS gradient_checker.py)
 
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index cfd29932f5b46920815819c5a75d62a0138e21a2..b73c4869d14a62a951d8e45dafb14b7523355519 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -1,16 +1,31 @@
+import unittest
+
+import numpy
 import paddle.v2.framework.core as core
 from paddle.v2.framework.op import Operator
-import numpy
-import unittest
 
 __all__ = ['get_numeric_gradient']
 
 
+def create_op(op_type):
+    kwargs = dict()
+    for in_name in Operator.get_op_input_names(op_type):
+        kwargs[in_name] = in_name
+    for out_name in Operator.get_op_output_names(op_type):
+        kwargs[out_name] = out_name
+
+    return Operator(op_type, **kwargs)
+
+
+def grad_var_name(var_name):
+    return var_name + "@GRAD"
+
+
 def get_numeric_gradient(op,
                          input_values,
                          output_name,
                          input_to_check,
-                         delta=1e-2,
+                         delta=0.005,
                          local_scope=None):
     """
     Get Numeric Gradient for an operator's input.
@@ -76,6 +91,113 @@ def get_numeric_gradient(op,
     return gradient_flat.reshape(tensor_to_check.get_dims())
 
 
+class GradientChecker(unittest.TestCase):
+    def __is_close(self, numeric_grads, scope, max_relative_error):
+        for name in numeric_grads:
+            op_grad = numpy.array(
+                scope.find_var(grad_var_name(name)).get_tensor())
+            is_close = numpy.allclose(
+                numeric_grads[name], op_grad, rtol=max_relative_error, atol=100)
+            if not is_close:
+                return False
+        return True
+
+    def check_grad(self,
+                   forward_op,
+                   input_vars,
+                   inputs_to_check,
+                   output_name,
+                   no_grad_set=None,
+                   only_cpu=False,
+                   max_relative_error=0.005):
+        """
+        :param forward_op: used to create the backward_op
+        :param input_vars: numpy values of the input variables. The following
+            computation will use these values.
+        :param inputs_to_check: input var names whose gradients should be checked.
+        :param output_name: the output name with respect to which gradients are computed.
+        :param max_relative_error: The relative tolerance parameter.
+        :param no_grad_set: used when creating the backward ops
+        :param only_cpu: only compute and check gradients on the CPU kernel.
+        :return:
+        """
+        if no_grad_set is None:
+            no_grad_set = set()
+
+        tmp_outs = forward_op.temp_outputs()
+        no_tmp_out = filter(lambda name: name not in tmp_outs,
+                            forward_op.outputs())
+        if len(no_tmp_out) != 1:
+            raise ValueError("non temp out_names should be 1")
+
+        in_names = forward_op.inputs()
+        for no_grad in no_grad_set:
+            if no_grad not in in_names:
+                raise ValueError("no_grad should be in in_names")
+
+        backward_op = core.Operator.backward(forward_op, no_grad_set)
+
+        places = [core.CPUPlace()]
+        if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu():
+            places.append(core.GPUPlace(0))
+
+        numeric_grad = dict()
+        # get numeric gradient
+        for check_name in inputs_to_check:
+            numeric_grad[check_name] = \
+                get_numeric_gradient(forward_op, input_vars, output_name, check_name)
+
+        # get operator gradient according to different device
+        for place in places:
+            scope = core.Scope()
+            ctx = core.DeviceContext.create(place)
+
+            # create input var and set value
+            for name, value in input_vars.iteritems():
+                if name not in in_names:
+                    raise ValueError(name + " not in op.inputs_")
+                var = scope.new_var(name).get_tensor()
+                var.set_dims(value.shape)
+                var.set(value, place)
+
+            # create output var
+            for out_name in forward_op.outputs():
+                scope.new_var(out_name).get_tensor()
+
+            # infer the shape of the output var and compute/set its value
+            forward_op.infer_shape(scope)
+            forward_op.run(scope, ctx)
+
+            # create output grad var,
+            # set its shape to that of the output var,
+            # and set the value of this grad to ones
+            for name in forward_op.outputs():
+                out_tensor = scope.find_var(name).get_tensor()
+                grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
+                grad_tensor.set_dims(out_tensor.shape())
+                data = 1.0 * numpy.ones(out_tensor.shape())
+                grad_tensor.set(data, place)
+
+            # create input grad var
+            for name in backward_op.outputs():
+                scope.new_var(name).get_tensor()
+
+            # infer the shape of the input gradient var and compute/set its
+            # value with the backward op
+            backward_op.infer_shape(scope)
+            backward_op.run(scope, ctx)
+
+            if isinstance(place, core.CPUPlace):
+                msg = "CPU kernel gradient is not close to numeric gradient"
+            else:
+                if isinstance(place, core.GPUPlace):
+                    msg = "GPU kernel gradient is not close to numeric gradient"
+                else:
+                    raise ValueError("unknown place " + str(type(place)))
+            self.assertTrue(
+                self.__is_close(numeric_grad, scope, max_relative_error), msg)
+
+
 if __name__ == '__main__':
 
     class GetNumericGradientTest(unittest.TestCase):
@@ -87,4 +209,28 @@ if __name__ == '__main__':
             arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
             self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)
 
+        def test_softmax_op(self):
+            def stable_softmax(x):
+                """Compute the softmax of vector x in a numerically stable way."""
+                shiftx = x - numpy.max(x)
+                exps = numpy.exp(shiftx)
+                return exps / numpy.sum(exps)
+
+            def label_softmax_grad(Y, dY):
+                dX = Y * 0.0
+                for i in range(Y.shape[0]):
+                    d = numpy.dot(Y[i, :], dY[i, :])
+                    dX[i, :] = Y[i, :] * (dY[i, :] - d)
+                return dX
+
+            softmax_op = Operator("softmax", X="X", Y="Y")
+
+            X = numpy.random.random((2, 2)).astype("float32")
+            Y = numpy.apply_along_axis(stable_softmax, 1, X)
+            dY = numpy.ones(Y.shape)
+            dX = label_softmax_grad(Y, dY)
+
+            arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
+            numpy.testing.assert_almost_equal(arr, dX, decimal=2)
+
     unittest.main()
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index da6bed0fcd690d5a7f53f44d0181c75f12e5d074..dd65e0f2dc23d3f657ff16c55fb297dae210b2d7 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -1,6 +1,5 @@
-import paddle.v2.framework.core as core
-import unittest
 import numpy
+import paddle.v2.framework.core as core
 from paddle.v2.framework.op import Operator
 
 
@@ -24,7 +23,7 @@ class OpTestMeta(type):
             scope = core.Scope()
             kwargs = dict()
             places = [core.CPUPlace()]
-            if core.is_compile_gpu() and core.Operator.support_gpu(self.type):
+            if core.is_compile_gpu():
                 places.append(core.GPUPlace(0))
 
             for place in places:
@@ -53,6 +52,8 @@ class OpTestMeta(type):
                         kwargs[attr_name] = self.attrs[attr_name]
 
                 op = Operator(self.type, **kwargs)
+                if isinstance(place, core.GPUPlace) and not op.support_gpu():
+                    return
 
                 op.infer_shape(scope)
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index b26e25d58b59bd1cb16e9ba2a1cccd27799b15f2..4815192e255c6e0429db3f50918a76a773b30131 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,9 +1,10 @@
 import unittest
 import numpy
 from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
 
 
-class TestSGD(unittest.TestCase):
+class TestCrossEntropy(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
@@ -20,7 +21,18 @@ class TestSGD(unittest.TestCase):
         self.outputs = {'Y': numpy.array(Y).astype("float32")}
 
 
-# TODO(superjom) add gradient check
+class CrossEntropyGradOpTest(GradientChecker):
+    def test_softmax_grad(self):
+        op = create_op("onehot_cross_entropy")
+        batch_size = 100
+        class_num = 10
+        inputs = {
+            "X": numpy.random.uniform(
+                0.1, 1.0, [batch_size, class_num]).astype("float32"),
+            "label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
+        }
+        self.check_grad(op, inputs, set("X"), "Y")
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py
index d20e085b8e43488480edf07b6cd4edcd861883f3..e670d93653e07d35e5019c9daac45c214eddf367 100644
--- a/python/paddle/v2/framework/tests/test_softmax_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_op.py
@@ -1,9 +1,8 @@
 import unittest
 
 import numpy as np
-import paddle.v2.framework.core as core
-from paddle.v2.framework.op import Operator
+from gradient_checker import GradientChecker, create_op
 from op_test_util import OpTestMeta
 
 
@@ -25,62 +24,11 @@ class TestSoftmaxOp(unittest.TestCase):
         }
 
 
-class TestSoftmaxGradOp(unittest.TestCase):
-    def test_softmax_grad(self):
-        op = Operator('softmax', X="X", Y="Y")
-        backward_op = core.Operator.backward(op, set())
-        self.assertEqual(backward_op.type(), "softmax_grad")
-        expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).'''
-        self.assertEqual(expected, str(backward_op))
-
-        batch_size = 3
-        class_num = 5
-        # Initialize X and add 1e-2 for numerical stability
-        Y = np.random.rand(batch_size, class_num).astype(np.float32)
-        Y = Y + 1e-2
-        dY = np.random.rand(batch_size, class_num).astype(np.float32)
-
-        # Reference implementation of cross entropy with soft labels
-        def label_softmax_grad(Y, dY):
-            dX = Y * 0.0
-            for i in range(batch_size):
-                d = np.dot(Y[i, :], dY[i, :])
-                dX[i, :] = Y[i, :] * (dY[i, :] - d)
-            return dX
-
-        expected = label_softmax_grad(Y, dY)
-
-        scope = core.Scope()
-        places = []
-        places.append(core.CPUPlace())
-        if core.is_compile_gpu():
-            places.append(core.GPUPlace(0))
-
-        for place in places:
-            y = scope.new_var("Y")
-            y_tensor = y.get_tensor()
-            y_tensor.set_dims([batch_size, class_num])
-            y_tensor.alloc_float(place)
-            y_tensor.set(Y, place)
-
-            dy = scope.new_var("Y@GRAD")
-            dy_tensor = dy.get_tensor()
-            dy_tensor.set_dims([batch_size, class_num])
-            dy_tensor.alloc_float(place)
-            dy_tensor.set(dY, place)
-
-            x = scope.new_var("X")
-            dx = scope.new_var("X@GRAD")
-
-            tensor = scope.find_var("X@GRAD").get_tensor()
-            backward_op.infer_shape(scope)
-            self.assertEqual([batch_size, class_num], tensor.shape())
-
-            ctx = core.DeviceContext.create(place)
-            backward_op.run(scope, ctx)
-            actual = np.array(tensor)
-
-            np.testing.assert_almost_equal(actual, expected, decimal=3)
+class SoftmaxGradOpTest(GradientChecker):
+    def test_softmax(self):
+        op = create_op("softmax")
+        inputs = {"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")}
+        self.check_grad(op, inputs, set("X"), "Y")
 
 
 if __name__ == '__main__':
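
Note on the checking scheme above (not part of the patch): get_numeric_gradient estimates d(output)/d(input) by perturbing each input element by delta, and GradientChecker.check_grad compares that estimate against the gradient produced by the backward op via numpy.allclose. A minimal, self-contained NumPy sketch of the same finite-difference idea follows; the helper name numeric_gradient and the quadratic example are illustrative only and are not part of the Paddle API.

import numpy


def numeric_gradient(f, x, delta=0.005):
    # Estimate df/dx element-wise with central differences:
    # (f(x + delta*e_i) - f(x - delta*e_i)) / (2*delta).
    grad = numpy.zeros_like(x)
    flat_x = x.reshape(-1)   # view: writes go through to x
    flat_g = grad.reshape(-1)
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + delta
        f_pos = f(x)
        flat_x[i] = orig - delta
        f_neg = f(x)
        flat_x[i] = orig  # restore the original value
        flat_g[i] = (f_pos - f_neg) / (2.0 * delta)
    return grad


# Compare the estimate with the analytic gradient of sum(x**2), which is 2*x,
# using the same kind of numpy.allclose comparison that check_grad performs.
x = numpy.random.uniform(0.1, 1.0, (3, 4))
est = numeric_gradient(lambda v: float((v ** 2).sum()), x)
assert numpy.allclose(est, 2 * x, rtol=0.005, atol=1e-6)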