Commit e31a469e authored by Qiao Longfei, committed by GitHub

add gradient test framework (#3226)

* init grad op checker

* can run

* add GradeChecker class

* use get_numeric_gradient

* refine code

* add softmax and cross entropy auto grad test

* use allclose to compare op_grad and numeric_grad

* add cpu and gpu compare

* add comments

* add support_gpu

* fix allclose

* fix name error and simplify code

* optimize gradient checker

* add test_cross_entropy_op

* update gradient_checker.py

* optimize code

* use random.uniform instead of random.random

* fix type bug

* optimize check_grad

* put SupportGPU into OperatorBase

* typo
Parent 6540701f
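The new framework lives in gradient_checker.py (see below): an op test subclasses GradientChecker and calls check_grad, which builds the backward op, runs it on CPU (and on GPU when the op supports it), and compares its output against a finite-difference estimate. A minimal usage sketch, assuming a registered single-input "sigmoid" op with input X and output Y (the op name is illustrative; the softmax and cross-entropy tests later in this diff follow the same pattern):

    import numpy
    from gradient_checker import GradientChecker, create_op


    class SigmoidGradOpTest(GradientChecker):
        def test_sigmoid_grad(self):
            # "sigmoid" is an assumed single-input/single-output op; any
            # registered op with input X and output Y is exercised the same way.
            op = create_op("sigmoid")
            inputs = {
                "X": numpy.random.uniform(0.1, 1.0, [11, 17]).astype("float32")
            }
            # check_grad builds the backward op, feeds a gradient of ones for Y,
            # and compares the resulting X@GRAD with a finite-difference estimate.
            self.check_grad(op, inputs, set("X"), "Y")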
@@ -260,12 +260,6 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }
 
-  static bool SupportGPU(const std::string& op_type) {
-    OperatorWithKernel::OpKernelKey key;
-    key.place_ = platform::GPUPlace();
-    return OperatorWithKernel::AllOpKernels().at(op_type).count(key) != 0;
-  }
-
   static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
     PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
...
@@ -88,6 +88,8 @@ class OperatorBase {
 
   virtual bool IsNetOp() const { return false; }
 
+  virtual bool SupportGPU() const { return false; }
+
   /// rename inputs outputs name
   void Rename(const std::string& old_name, const std::string& new_name);
 
@@ -308,7 +310,7 @@ class OperatorWithKernel : public OperatorBase {
   using OpKernelMap =
       std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
 
-  void InferShape(const Scope& scope) const {
+  void InferShape(const Scope& scope) const override {
     InferShape(InferShapeContext(this, scope));
   }
 
@@ -324,6 +326,12 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }
 
+  bool SupportGPU() const override {
+    OperatorWithKernel::OpKernelKey key;
+    key.place_ = platform::GPUPlace();
+    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
+  }
+
 protected:
   virtual void InferShape(const InferShapeContext& ctx) const = 0;
 };
...
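GPU support is now a per-operator virtual query: OperatorBase answers false by default, OperatorWithKernel answers by looking for a kernel registered under GPUPlace, and NetOp (below) asks all of its sub-ops. A rough Python model of the kernel-table check, with KernelKey, all_op_kernels, and the "GPUPlace" tag all invented for illustration:

    from collections import namedtuple

    # Stand-in for OperatorWithKernel::OpKernelKey; only the place matters here.
    KernelKey = namedtuple("KernelKey", ["place"])


    def support_gpu(op_type, all_op_kernels):
        """all_op_kernels maps op_type -> {KernelKey: kernel}, like AllOpKernels()."""
        return any(key.place == "GPUPlace"
                   for key in all_op_kernels.get(op_type, {}))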
@@ -57,6 +57,26 @@ void ExposeOperator(ClassType &m) {
            [](const typename ClassType::type &op) -> std::vector<std::string> {
              return op.outputs_;
            })
+      .def("inputs",
+           [](const typename ClassType::type &op) -> std::vector<std::string> {
+             return op.inputs_;
+           })
+      .def("support_gpu", &ClassType::type::SupportGPU)
+      .def("temp_outputs",
+           [](const typename ClassType::type &op) -> std::vector<std::string> {
+             auto iter = op.attrs_.find("temporary_index");
+             std::vector<std::string> ret;
+             if (iter == op.attrs_.end()) {
+               return ret;
+             } else {
+               auto tmp_idx = boost::get<std::vector<int>>(iter->second);
+               ret.reserve(tmp_idx.size());
+               for (auto &index : tmp_idx) {
+                 ret.push_back(op.outputs_.at(index));
+               }
+               return ret;
+             }
+           })
       .def("__str__", &ClassType::type::DebugString);
 }
 
@@ -202,8 +222,6 @@ All parameter, weight, gradient are variables in Paddle.
         return OpRegistry::CreateOp(desc);
       });
 
-  operator_base.def_static("support_gpu", &OpRegistry::SupportGPU);
-
   operator_base.def("backward",
                     [](const OperatorBase &forwardOp,
                        const std::unordered_set<std::string> &no_grad_vars) {
...
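These bindings replace the old static query core.Operator.support_gpu(op_type) (removed above) with per-instance introspection. A small sketch of what Python code can now ask an operator, using the softmax op that appears later in this commit; the commented values are what the bindings are expected to return, not captured output:

    import paddle.v2.framework.core as core
    from paddle.v2.framework.op import Operator

    op = Operator("softmax", X="X", Y="Y")
    print(op.inputs())        # ['X']
    print(op.outputs())       # ['Y']
    print(op.support_gpu())   # True only if a GPU softmax kernel was compiled in
    print(op.temp_outputs())  # outputs listed in the "temporary_index" attribute, if any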
@@ -70,7 +70,8 @@ REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
             ops::OnehotCrossEntropyOpMaker);
 REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
                        ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
-
+REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
+                     ops::OnehotCrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(
     onehot_cross_entropy_grad,
     ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>);
@@ -65,6 +65,15 @@ class NetOp : public framework::OperatorBase {
       }
     }
 
+  bool SupportGPU() const override {
+    for (auto& op : ops_) {
+      if (!op->SupportGPU()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * @brief Add an operator by ptr
    */
...
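NetOp therefore reports GPU support only when every contained operator does. The same rule as a tiny Python helper over a hypothetical list of sub-operators (standing in for NetOp::ops_):

    def net_supports_gpu(sub_ops):
        # Mirrors NetOp::SupportGPU: one CPU-only sub-op makes the whole net CPU-only.
        return all(op.support_gpu() for op in sub_ops)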
@@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py)
 py_test(test_add_two_op SRCS test_add_two_op.py)
 py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
 py_test(test_softmax_op SRCS test_softmax_op.py)
+py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
 py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
 py_test(gradient_checker SRCS gradient_checker.py)
...
+import unittest
+import numpy
 import paddle.v2.framework.core as core
 from paddle.v2.framework.op import Operator
-import numpy
-import unittest
 
 __all__ = ['get_numeric_gradient']
 
 
+def create_op(op_type):
+    kwargs = dict()
+    for in_name in Operator.get_op_input_names(op_type):
+        kwargs[in_name] = in_name
+    for out_name in Operator.get_op_output_names(op_type):
+        kwargs[out_name] = out_name
+    return Operator(op_type, **kwargs)
+
+
+def grad_var_name(var_name):
+    return var_name + "@GRAD"
+
+
 def get_numeric_gradient(op,
                          input_values,
                          output_name,
                          input_to_check,
-                         delta=1e-2,
+                         delta=0.005,
                          local_scope=None):
     """
     Get Numeric Gradient for an operator's input.
@@ -76,6 +91,113 @@ def get_numeric_gradient(op,
     return gradient_flat.reshape(tensor_to_check.get_dims())
 
 
+class GradientChecker(unittest.TestCase):
+    def __is_close(self, numeric_grads, scope, max_relative_error):
+        for name in numeric_grads:
+            op_grad = numpy.array(
+                scope.find_var(grad_var_name(name)).get_tensor())
+            is_close = numpy.allclose(
+                numeric_grads[name], op_grad, rtol=max_relative_error, atol=100)
+            if not is_close:
+                return False
+        return True
+
+    def check_grad(self,
+                   forward_op,
+                   input_vars,
+                   inputs_to_check,
+                   output_name,
+                   no_grad_set=None,
+                   only_cpu=False,
+                   max_relative_error=0.005):
+        """
+        :param forward_op: used to create backward_op
+        :param input_vars: numpy values of the input variables. The following
+            computation will use these variables.
+        :param inputs_to_check: input var names whose gradients should be checked.
+        :param output_name: name of the output used to compute the gradient.
+        :param max_relative_error: the relative tolerance parameter.
+        :param no_grad_set: used when creating the backward ops
+        :param only_cpu: only compute and check the gradient on the cpu kernel.
+        :return:
+        """
+        if no_grad_set is None:
+            no_grad_set = set()
+
+        tmp_outs = forward_op.temp_outputs()
+        no_tmp_out = filter(lambda name: name not in tmp_outs,
+                            forward_op.outputs())
+        if len(no_tmp_out) != 1:
+            raise ValueError("non temp out_names should be 1")
+
+        in_names = forward_op.inputs()
+        for no_grad in no_grad_set:
+            if no_grad not in in_names:
+                raise ValueError("no_grad should be in in_names")
+
+        backward_op = core.Operator.backward(forward_op, no_grad_set)
+
+        places = [core.CPUPlace()]
+        if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu():
+            places.append(core.GPUPlace(0))
+
+        numeric_grad = dict()
+        # get numeric gradient
+        for check_name in inputs_to_check:
+            numeric_grad[check_name] = \
+                get_numeric_gradient(forward_op, input_vars, output_name,
+                                     check_name)
+
+        # get operator gradient according to different device
+        for place in places:
+            scope = core.Scope()
+            ctx = core.DeviceContext.create(place)
+
+            # create input var and set value
+            for name, value in input_vars.iteritems():
+                if name not in in_names:
+                    raise ValueError(name + " not in op.inputs_")
+                var = scope.new_var(name).get_tensor()
+                var.set_dims(value.shape)
+                var.set(value, place)
+
+            # create output var
+            for out_name in forward_op.outputs():
+                scope.new_var(out_name).get_tensor()
+
+            # infer the shape of output var and compute/set value of output var
+            forward_op.infer_shape(scope)
+            forward_op.run(scope, ctx)
+
+            # create output grad var
+            # set shape as the output var
+            # set value of this grad to ones
+            for name in forward_op.outputs():
+                out_tensor = scope.find_var(name).get_tensor()
+                grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
+                grad_tensor.set_dims(out_tensor.shape())
+                data = 1.0 * numpy.ones(out_tensor.shape())
+                grad_tensor.set(data, place)
+
+            # create input grad var
+            for name in backward_op.outputs():
+                scope.new_var(name).get_tensor()
+
+            # infer the shape of the input gradient var and compute/set its
+            # value with the backward op
+            backward_op.infer_shape(scope)
+            backward_op.run(scope, ctx)
+
+            if isinstance(place, core.CPUPlace):
+                msg = "CPU kernel gradient is not close to numeric gradient"
+            elif isinstance(place, core.GPUPlace):
+                msg = "GPU kernel gradient is not close to numeric gradient"
+            else:
+                raise ValueError("unknown place " + str(type(place)))
+            self.assertTrue(
+                self.__is_close(numeric_grad, scope, max_relative_error), msg)
+
+
 if __name__ == '__main__':
 
     class GetNumericGradientTest(unittest.TestCase):
@@ -87,4 +209,28 @@ if __name__ == '__main__':
             arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
             self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)
 
+        def test_softmax_op(self):
+            def stable_softmax(x):
+                """Compute the softmax of vector x in a numerically stable way."""
+                shiftx = x - numpy.max(x)
+                exps = numpy.exp(shiftx)
+                return exps / numpy.sum(exps)
+
+            def label_softmax_grad(Y, dY):
+                dX = Y * 0.0
+                for i in range(Y.shape[0]):
+                    d = numpy.dot(Y[i, :], dY[i, :])
+                    dX[i, :] = Y[i, :] * (dY[i, :] - d)
+                return dX
+
+            softmax_op = Operator("softmax", X="X", Y="Y")
+
+            X = numpy.random.random((2, 2)).astype("float32")
+            Y = numpy.apply_along_axis(stable_softmax, 1, X)
+            dY = numpy.ones(Y.shape)
+            dX = label_softmax_grad(Y, dY)
+
+            arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
+            numpy.testing.assert_almost_equal(arr, dX, decimal=2)
+
     unittest.main()
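get_numeric_gradient (partly collapsed above) estimates the gradient of the output's sum with respect to each input element by perturbing that element and re-running the forward op; check_grad then feeds the backward op a gradient of ones and compares the two results with numpy.allclose. A self-contained sketch of the same finite-difference idea on a plain numpy function, not the Paddle API (the function name and the example f are made up; delta mirrors the default above):

    import numpy

    def numeric_gradient(f, x, delta=0.005):
        """Estimate d(sum(f(x)))/dx element-wise with central differences."""
        x = x.astype("float64")          # work on a copy of the input
        grad = numpy.zeros_like(x)
        flat_x, flat_g = x.ravel(), grad.ravel()
        for i in range(flat_x.size):
            orig = flat_x[i]
            flat_x[i] = orig + delta
            plus = numpy.sum(f(x))
            flat_x[i] = orig - delta
            minus = numpy.sum(f(x))
            flat_x[i] = orig
            flat_g[i] = (plus - minus) / (2 * delta)
        return grad

    # e.g. for f(x) = x ** 2 the estimate approaches 2 * x
    x = numpy.array([[1.0, 2.0], [3.0, 4.0]])
    print(numeric_gradient(lambda t: t ** 2, x))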
+import paddle.v2.framework.core as core
-import unittest
 import numpy
-import paddle.v2.framework.core as core
 from paddle.v2.framework.op import Operator
 
@@ -24,7 +23,7 @@ class OpTestMeta(type):
             scope = core.Scope()
             kwargs = dict()
             places = [core.CPUPlace()]
-            if core.is_compile_gpu() and core.Operator.support_gpu(self.type):
+            if core.is_compile_gpu():
                 places.append(core.GPUPlace(0))
 
             for place in places:
@@ -53,6 +52,8 @@ class OpTestMeta(type):
                     kwargs[attr_name] = self.attrs[attr_name]
 
                 op = Operator(self.type, **kwargs)
+                if isinstance(place, core.GPUPlace) and not op.support_gpu():
+                    return
 
                 op.infer_shape(scope)
...
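op_test_util.py now always tests on CPUPlace, appends GPUPlace(0) whenever Paddle was compiled with GPU support, and skips the GPU place at run time for ops without a GPU kernel (the Operator is only constructed inside the loop). The same policy, condensed into a hypothetical helper for the case where the op is available up front:

    import paddle.v2.framework.core as core

    def places_for(op):
        """Return the device places an op test should run on."""
        places = [core.CPUPlace()]
        if core.is_compile_gpu() and op.support_gpu():
            places.append(core.GPUPlace(0))
        return places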
 import unittest
 import numpy
 from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
 
 
-class TestSGD(unittest.TestCase):
+class TestCrossEntropy(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
@@ -20,7 +21,18 @@ class TestSGD(unittest.TestCase):
         self.outputs = {'Y': numpy.array(Y).astype("float32")}
 
 
-# TODO(superjom) add gradient check
+class CrossEntropyGradOpTest(GradientChecker):
+    def test_softmax_grad(self):
+        op = create_op("onehot_cross_entropy")
+        batch_size = 100
+        class_num = 10
+        inputs = {
+            "X": numpy.random.uniform(
+                0.1, 1.0, [batch_size, class_num]).astype("float32"),
+            "label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
+        }
+        self.check_grad(op, inputs, set("X"), "Y")
+
+
 if __name__ == "__main__":
     unittest.main()
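For reference, assuming the onehot_cross_entropy forward computes the per-row negative log-likelihood of the labelled class, the analytic gradient that the checker compares against the numeric estimate is:

    Y_i = -\log X_{i,\ell_i}, \qquad
    \frac{\partial Y_i}{\partial X_{i,j}} =
      \begin{cases}
        -\dfrac{1}{X_{i,\ell_i}} & j = \ell_i, \\[4pt]
        0 & \text{otherwise,}
      \end{cases}

where \ell_i is the integer label of row i. Sampling X from uniform(0.1, 1.0) above keeps X_{i,\ell_i} away from zero, so the quotient stays well conditioned for the finite-difference comparison.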
 import unittest
 import numpy as np
-import paddle.v2.framework.core as core
-from paddle.v2.framework.op import Operator
+from gradient_checker import GradientChecker, create_op
 from op_test_util import OpTestMeta
 
@@ -25,62 +24,11 @@ class TestSoftmaxOp(unittest.TestCase):
         }
 
 
-class TestSoftmaxGradOp(unittest.TestCase):
-    def test_softmax_grad(self):
-        op = Operator('softmax', X="X", Y="Y")
-        backward_op = core.Operator.backward(op, set())
-        self.assertEqual(backward_op.type(), "softmax_grad")
-        expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).'''
-        self.assertEqual(expected, str(backward_op))
-
-        batch_size = 3
-        class_num = 5
-        # Initialize X and add 1e-2 for numerical stability
-        Y = np.random.rand(batch_size, class_num).astype(np.float32)
-        Y = Y + 1e-2
-        dY = np.random.rand(batch_size, class_num).astype(np.float32)
-
-        # Reference implementation of cross entropy with soft labels
-        def label_softmax_grad(Y, dY):
-            dX = Y * 0.0
-            for i in range(batch_size):
-                d = np.dot(Y[i, :], dY[i, :])
-                dX[i, :] = Y[i, :] * (dY[i, :] - d)
-            return dX
-
-        expected = label_softmax_grad(Y, dY)
-
-        scope = core.Scope()
-        places = []
-        places.append(core.CPUPlace())
-        if core.is_compile_gpu():
-            places.append(core.GPUPlace(0))
-
-        for place in places:
-            y = scope.new_var("Y")
-            y_tensor = y.get_tensor()
-            y_tensor.set_dims([batch_size, class_num])
-            y_tensor.alloc_float(place)
-            y_tensor.set(Y, place)
-
-            dy = scope.new_var("Y@GRAD")
-            dy_tensor = dy.get_tensor()
-            dy_tensor.set_dims([batch_size, class_num])
-            dy_tensor.alloc_float(place)
-            dy_tensor.set(dY, place)
-
-            x = scope.new_var("X")
-            dx = scope.new_var("X@GRAD")
-
-            tensor = scope.find_var("X@GRAD").get_tensor()
-            backward_op.infer_shape(scope)
-            self.assertEqual([batch_size, class_num], tensor.shape())
-
-            ctx = core.DeviceContext.create(place)
-            backward_op.run(scope, ctx)
-            actual = np.array(tensor)
-
-            np.testing.assert_almost_equal(actual, expected, decimal=3)
+class SoftmaxGradOpTest(GradientChecker):
+    def test_softmax(self):
+        op = create_op("softmax")
+        inputs = {"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")}
+        self.check_grad(op, inputs, set("X"), "Y")
 
 
 if __name__ == '__main__':
...
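Both label_softmax_grad helpers in this diff (the one added to gradient_checker.py and the one removed here) implement the row-wise softmax backward rule. A short derivation, with y the softmax of row x and L the scalar loss:

    y_j = \frac{e^{x_j}}{\sum_k e^{x_k}}, \qquad
    \frac{\partial y_j}{\partial x_i} = y_j(\delta_{ij} - y_i)
    \;\Longrightarrow\;
    \frac{\partial L}{\partial x_i}
      = \sum_j \frac{\partial L}{\partial y_j}\, y_j(\delta_{ij} - y_i)
      = y_i\Big(\frac{\partial L}{\partial y_i} - \sum_j y_j \frac{\partial L}{\partial y_j}\Big),

which is exactly dX[i, :] = Y[i, :] * (dY[i, :] - numpy.dot(Y[i, :], dY[i, :])) as coded above.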