Commit 03789a7d authored by Z zchen0211

batch norm fully tortured and passed

Parent 8a07aff4
@@ -208,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T>
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif
-    std::vector<int> dims = {N, C, H, W, D};
-    std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
+    std::vector<int> dims;
+    std::vector<int> strides;
+    if (tensor_format == TensorFormat::NCHW) {
+      dims = {N, C, H, W, D};
+      strides = {C * H * W * D, H * W * D, W * D, D, 1};
+    } else {
+      dims = {N, C, H, W, D};
+      strides = {H * W * C * D, 1, W * D * C, D * C, C};
+    }
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
......
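For context, a small Python sketch (not part of this commit) of the stride arithmetic the new NCHW/NHWC branch performs for cudnnSetTensorNdDescriptor: the dims are always passed in N, C, H, W, D order, and only the strides encode the memory layout. The sizes below are made-up example values.

    def strides_for(memory_order, sizes, report_order="NCHWD"):
        # Element strides of a contiguous array stored in `memory_order`,
        # reported in the axis order the descriptor expects (`report_order`).
        stride, s = {}, 1
        for axis in reversed(memory_order):
            stride[axis] = s
            s *= sizes[axis]
        return [stride[axis] for axis in report_order]

    sizes = {"N": 2, "C": 3, "H": 4, "W": 5, "D": 1}
    print(strides_for("NCHWD", sizes))  # NCHW layout: [C*H*W*D, H*W*D, W*D, D, 1] -> [60, 20, 5, 1, 1]
    print(strides_for("NHWDC", sizes))  # NHWC layout: [H*W*C*D, 1, W*D*C, D*C, C] -> [60, 1, 15, 3, 3]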
@@ -96,22 +96,25 @@ def create_or_get_tensor(scope, var_name, var, place):
     return tensor
-def set_output_grad(scope, outputs, place):
-    def __set_tensor__(name):
+def set_output_grad(scope, outputs, place, feed_dict=None):
+    def __set_tensor__(name, data=None):
         out_tensor = scope.find_var(name).get_tensor()
         grad_tensor = scope.var(grad_var_name(name)).get_tensor()
         out_dtype = out_tensor.dtype()
-        if out_dtype == core.DataType.FP64:
-            data = np.ones(out_tensor.shape(), dtype=np.float64)
-        elif out_dtype == core.DataType.FP32:
-            data = np.ones(out_tensor.shape(), dtype=np.float32)
-        else:
-            raise ValueError("Not supported data type " + str(out_dtype))
+        if data is None:
+            if out_dtype == core.DataType.FP64:
+                data = np.ones(out_tensor.shape(), dtype=np.float64)
+            elif out_dtype == core.DataType.FP32:
+                data = np.ones(out_tensor.shape(), dtype=np.float32)
+            else:
+                raise ValueError("Not supported data type " + str(out_dtype))
         grad_tensor.set(data, place)
     for output in outputs:
-        __set_tensor__(output)
+        data = None
+        if feed_dict is not None and output in feed_dict:
+            data = feed_dict[output]
+        __set_tensor__(output, data)
 class TestBatchNormOp(OpTest):
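A minimal pure-numpy sketch (not part of the patch; names are illustrative) of the semantics the extended helper now implements: a gradient supplied through feed_dict is used as-is, while every other output still gets an all-ones gradient of the matching dtype.

    import numpy as np

    def pick_grad(name, shape, dtype, feed_dict=None):
        # Mirrors __set_tensor__: explicitly fed data wins, otherwise default to ones.
        if feed_dict is not None and name in feed_dict:
            return feed_dict[name]
        return np.ones(shape, dtype=dtype)

    y_grad = np.zeros((2, 3, 4, 5), dtype=np.float32)
    y_grad[0, 0, 0, 0] = 1.
    print(pick_grad("y_out", (2, 3, 4, 5), np.float32, {"y_out": y_grad}).sum())  # 1.0 (custom grad)
    print(pick_grad("mean", (5,), np.float32, {"y_out": y_grad}))                 # all ones (default)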
@@ -119,7 +122,7 @@ class TestBatchNormOp(OpTest):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
     def test_python(self):
-        data_format = "NCHW"
+        data_format = "NHWC"
         epsilon = 0.00001
         momentum = 0.9
@@ -214,7 +217,10 @@ class TestBatchNormOp(OpTest):
         saved_variance = 1. / np.sqrt(var_ref + epsilon)
         # for gradient test
-        y_grad = np.ones(x_shape).astype(np.float32)
+        # y_grad = np.ones(x_shape).astype(np.float32)
+        y_grad = np.zeros(x_shape).astype(np.float32)
+        y_grad[0, 0, 0, 0] = 1.
+        # y_grad = np.random.random_sample(x_shape).astype(np.float32)
         x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
             x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
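_reference_grad itself is not shown in this diff; as an assumption about what it computes, here is the textbook batch-norm backward pass for channel-last (NHWC) data in numpy, which the one-hot y_grad above is fed into:

    import numpy as np

    def bn_backward_nhwc(x, dy, scale, mean, var, eps):
        # Standard training-mode batch-norm backward, reducing over N, H, W.
        axes = (0, 1, 2)
        x_hat = (x - mean) / np.sqrt(var + eps)
        grad_bias = dy.sum(axis=axes)
        grad_scale = (dy * x_hat).sum(axis=axes)
        grad_x = scale / np.sqrt(var + eps) * (
            dy - dy.mean(axis=axes) - x_hat * (dy * x_hat).mean(axis=axes))
        return grad_x, grad_scale, grad_bias

    x = np.random.random_sample((2, 3, 4, 5)).astype(np.float32)
    dy = np.zeros_like(x)
    dy[0, 0, 0, 0] = 1.                      # same one-hot upstream gradient as the test
    mu, var = x.mean(axis=(0, 1, 2)), x.var(axis=(0, 1, 2))
    dx, dscale, dbias = bn_backward_nhwc(x, dy, np.ones(5, np.float32), mu, var, 1e-5)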
@@ -283,7 +289,8 @@ class TestBatchNormOp(OpTest):
         set_output_grad(
             scope,
             ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
-            place)
+            place,
+            feed_dict={"y_out": y_grad})
         batch_norm_op_grad.run(scope, ctx)
         x_grad_tensor = create_or_get_tensor(scope,
@@ -297,8 +304,6 @@ class TestBatchNormOp(OpTest):
                                              None, place)
         # check gradient output
-        print 'var x_grad tensor: ', str(place), np.array(x_grad_tensor)
-        print 'var x_grad by python: ', str(place), x_grad_ref
         self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
         self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
         self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
......