提交 03789a7d 编写于 作者: Z zchen0211

batch norm fully tortured and passed

上级 8a07aff4
...@@ -208,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T> ...@@ -208,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T>
mode_ = CUDNN_BATCHNORM_SPATIAL; mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif #endif
std::vector<int> dims = {N, C, H, W, D}; std::vector<int> dims;
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C}; std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
......
...@@ -96,22 +96,25 @@ def create_or_get_tensor(scope, var_name, var, place): ...@@ -96,22 +96,25 @@ def create_or_get_tensor(scope, var_name, var, place):
return tensor return tensor
def set_output_grad(scope, outputs, place): def set_output_grad(scope, outputs, place, feed_dict=None):
def __set_tensor__(name): def __set_tensor__(name, data=None):
out_tensor = scope.find_var(name).get_tensor() out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor() grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype() out_dtype = out_tensor.dtype()
if out_dtype == core.DataType.FP64: if data is None:
data = np.ones(out_tensor.shape(), dtype=np.float64) if out_dtype == core.DataType.FP64:
elif out_dtype == core.DataType.FP32: data = np.ones(out_tensor.shape(), dtype=np.float64)
data = np.ones(out_tensor.shape(), dtype=np.float32) elif out_dtype == core.DataType.FP32:
else: data = np.ones(out_tensor.shape(), dtype=np.float32)
raise ValueError("Not supported data type " + str(out_dtype)) else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place) grad_tensor.set(data, place)
for output in outputs: for output in outputs:
__set_tensor__(output) data = None
if output in feed_dict:
data = feed_dict[output]
__set_tensor__(output, data)
class TestBatchNormOp(OpTest): class TestBatchNormOp(OpTest):
...@@ -119,7 +122,7 @@ class TestBatchNormOp(OpTest): ...@@ -119,7 +122,7 @@ class TestBatchNormOp(OpTest):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_python(self): def test_python(self):
data_format = "NCHW" data_format = "NHWC"
epsilon = 0.00001 epsilon = 0.00001
momentum = 0.9 momentum = 0.9
...@@ -214,7 +217,10 @@ class TestBatchNormOp(OpTest): ...@@ -214,7 +217,10 @@ class TestBatchNormOp(OpTest):
saved_variance = 1. / np.sqrt(var_ref + epsilon) saved_variance = 1. / np.sqrt(var_ref + epsilon)
# for gradient test # for gradient test
y_grad = np.ones(x_shape).astype(np.float32) # y_grad = np.ones(x_shape).astype(np.float32)
y_grad = np.zeros(x_shape).astype(np.float32)
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad( x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format) x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
...@@ -283,7 +289,8 @@ class TestBatchNormOp(OpTest): ...@@ -283,7 +289,8 @@ class TestBatchNormOp(OpTest):
set_output_grad( set_output_grad(
scope, scope,
["y_out", "mean", "variance", "saved_mean", "saved_variance"], ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place) place,
feed_dict={"y_out": y_grad})
batch_norm_op_grad.run(scope, ctx) batch_norm_op_grad.run(scope, ctx)
x_grad_tensor = create_or_get_tensor(scope, x_grad_tensor = create_or_get_tensor(scope,
...@@ -297,8 +304,6 @@ class TestBatchNormOp(OpTest): ...@@ -297,8 +304,6 @@ class TestBatchNormOp(OpTest):
None, place) None, place)
# check gradient output # check gradient output
print 'var x_grad tensor: ', str(place), np.array(x_grad_tensor)
print 'var x_grad by python: ', str(place), x_grad_ref
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册