diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h
index 1440b74290ce43a9e30d59ff5ad94e00eb13f9f1..32e956e15282a60554244cabbbb14af2f457b7ce 100644
--- a/paddle/fluid/operators/batch_norm_op.h
+++ b/paddle/fluid/operators/batch_norm_op.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/layout_utils.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/norm_utils.h"
 
@@ -41,127 +42,6 @@ template <typename T>
 using ConstEigenVectorArrayMap =
     Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
 
-template <typename DeviceContext, typename T>
-inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
-                                 const Tensor* input,
-                                 Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    // input
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[4];
-    in_dims_vec[2] = input->dims()[1];
-    in_dims_vec[3] = input->dims()[2];
-    in_dims_vec[4] = input->dims()[3];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-
-  } else if (dim == 2) {
-    // input
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[3];
-    in_dims_vec[2] = input->dims()[1];
-    in_dims_vec[3] = input->dims()[2];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-  } else if (dim == 1) {
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[2];
-    in_dims_vec[2] = input->dims()[1];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-  }
-}
-
-template <typename DeviceContext, typename T>
-inline void TransToChannelFirst(const framework::ExecutionContext& context,
-                                const Tensor* input,
-                                Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 4, 1, 2, 3};
-    math::Transpose<DeviceContext, T, 5> trans5;
-    trans5(dev_ctx, *input, transformed_input, axis);
-
-  } else if (dim == 2) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 3, 1, 2};
-    math::Transpose<DeviceContext, T, 4> trans4;
-    trans4(dev_ctx, *input, transformed_input, axis);
-  } else if (dim == 1) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 2, 1};
-    math::Transpose<DeviceContext, T, 3> trans3;
-    trans3(dev_ctx, *input, transformed_input, axis);
-  }
-}
-
-template <typename DeviceContext, typename T>
-inline void ResizeToChannelLast(const framework::ExecutionContext& context,
-                                const Tensor* input,
-                                Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[2];
-    in_dims_vec[2] = input->dims()[3];
-    in_dims_vec[3] = input->dims()[4];
-    in_dims_vec[4] = input->dims()[1];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-
-  } else if (dim == 2) {
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[2];
-    in_dims_vec[2] = input->dims()[3];
-    in_dims_vec[3] = input->dims()[1];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-  } else if (dim == 1) {
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[2];
-    in_dims_vec[2] = input->dims()[1];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-  }
-}
-
-template <typename DeviceContext, typename T>
-inline void TransToChannelLast(const framework::ExecutionContext& context,
-                               const Tensor* input, Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 2, 3, 4, 1};
-    math::Transpose<DeviceContext, T, 5> trans5;
-    trans5(dev_ctx, *input, transformed_input, axis);
-
-  } else if (dim == 2) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 2, 3, 1};
-    math::Transpose<DeviceContext, T, 4> trans4;
-    trans4(dev_ctx, *input, transformed_input, axis);
-  } else if (dim == 1) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 2, 1};
-    math::Transpose<DeviceContext, T, 3> trans3;
-    trans3(dev_ctx, *input, transformed_input, axis);
-  }
-}
-
 class BatchNormOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index 662fac9e77e023d2e1b173caa5a9769b56eaf0c4..364e3ab8d26c3f35f41f319b3d31b63964b93abe 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -20,6 +20,7 @@ limitations under the License. */
 #include
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/layout_utils.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/depthwise_conv.h"
 #include "paddle/fluid/operators/math/im2col.h"
@@ -138,102 +139,6 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
   return !(filter_1 && strides_1 && padding_0 && dilation_1);
 }
 
-template <typename DeviceContext, typename T>
-inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
-                                 const Tensor* input,
-                                 Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    // input
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[4];
-    in_dims_vec[2] = input->dims()[1];
-    in_dims_vec[3] = input->dims()[2];
-    in_dims_vec[4] = input->dims()[3];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-
-  } else if (dim == 2) {
-    // input
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[3];
-    in_dims_vec[2] = input->dims()[1];
-    in_dims_vec[3] = input->dims()[2];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-  }
-}
-
-template <typename DeviceContext, typename T>
-inline void ResizeToChannelLast(const framework::ExecutionContext& context,
-                                const Tensor* input,
-                                Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    // input
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[2];
-    in_dims_vec[2] = input->dims()[3];
-    in_dims_vec[3] = input->dims()[4];
-    in_dims_vec[4] = input->dims()[1];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-
-  } else if (dim == 2) {
-    // input
-    transformed_input->Resize(input->dims());
-
-    auto in_dims_vec = framework::vectorize(input->dims());
-    in_dims_vec[1] = input->dims()[2];
-    in_dims_vec[2] = input->dims()[3];
-    in_dims_vec[3] = input->dims()[1];
-    transformed_input->Resize(framework::make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
-  }
-}
-
-template <typename DeviceContext, typename T>
-inline void TransToChannelFirst(const framework::ExecutionContext& context,
-                                const Tensor* input,
-                                Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 4, 1, 2, 3};
-    math::Transpose<DeviceContext, T, 5> trans5;
-    trans5(dev_ctx, *input, transformed_input, axis);
-
-  } else if (dim == 2) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 3, 1, 2};
-    math::Transpose<DeviceContext, T, 4> trans4;
-    trans4(dev_ctx, *input, transformed_input, axis);
-  }
-}
-
-template <typename DeviceContext, typename T>
-inline void TransToChannelLast(const framework::ExecutionContext& context,
-                               const Tensor* input, Tensor* transformed_input) {
-  int dim = input->dims().size() - 2;
-  if (dim == 3) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 2, 3, 4, 1};
-    math::Transpose<DeviceContext, T, 5> trans5;
-    trans5(dev_ctx, *input, transformed_input, axis);
-
-  } else if (dim == 2) {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> axis{0, 2, 3, 1};
-    math::Transpose<DeviceContext, T, 4> trans4;
-    trans4(dev_ctx, *input, transformed_input, axis);
-  }
-}
 // Define Op classes in .h file so that other conv
 // operator implementations can reuse the code.
 class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/layout_utils.h b/paddle/fluid/operators/layout_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..52fa7fd1079a7d80becf4ef01e8d4543695ede87
--- /dev/null
+++ b/paddle/fluid/operators/layout_utils.h
@@ -0,0 +1,155 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
+                                 const Tensor* input,
+                                 Tensor* transformed_input) {
+  int dim = input->dims().size() - 2;
+  if (dim == 3) {
+    // input
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[4];
+    in_dims_vec[2] = input->dims()[1];
+    in_dims_vec[3] = input->dims()[2];
+    in_dims_vec[4] = input->dims()[3];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+
+  } else if (dim == 2) {
+    // input
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[3];
+    in_dims_vec[2] = input->dims()[1];
+    in_dims_vec[3] = input->dims()[2];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+  } else if (dim == 1) {
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[2];
+    in_dims_vec[2] = input->dims()[1];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+  }
+}
+
+template <typename DeviceContext, typename T>
+inline void ResizeToChannelLast(const framework::ExecutionContext& context,
+                                const Tensor* input,
+                                Tensor* transformed_input) {
+  int dim = input->dims().size() - 2;
+  if (dim == 3) {
+    // input
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[2];
+    in_dims_vec[2] = input->dims()[3];
+    in_dims_vec[3] = input->dims()[4];
+    in_dims_vec[4] = input->dims()[1];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+
+  } else if (dim == 2) {
+    // input
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[2];
+    in_dims_vec[2] = input->dims()[3];
+    in_dims_vec[3] = input->dims()[1];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+  } else if (dim == 1) {
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[2];
+    in_dims_vec[2] = input->dims()[1];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+  }
+}
+
+template <typename DeviceContext, typename T>
+inline void TransToChannelFirst(const framework::ExecutionContext& context,
+                                const Tensor* input,
+                                Tensor* transformed_input) {
+  VLOG(5) << "Why am I called?";
+  int dim = input->dims().size() - 2;
+  if (dim == 3) {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    std::vector<int> axis{0, 4, 1, 2, 3};
+    math::Transpose<DeviceContext, T, 5> trans5;
+    trans5(dev_ctx, *input, transformed_input, axis);
+
+  } else if (dim == 2) {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    std::vector<int> axis{0, 3, 1, 2};
+    math::Transpose<DeviceContext, T, 4> trans4;
+    trans4(dev_ctx, *input, transformed_input, axis);
+  } else if (dim == 1) {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    std::vector<int> axis{0, 2, 1};
+    math::Transpose<DeviceContext, T, 3> trans3;
+    trans3(dev_ctx, *input, transformed_input, axis);
+  }
+}
+
+template <typename DeviceContext, typename T>
+inline void TransToChannelLast(const framework::ExecutionContext& context,
+                               const Tensor* input, Tensor* transformed_input) {
+  int dim = input->dims().size() - 2;
+  if (dim == 3) {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    std::vector<int> axis{0, 2, 3, 4, 1};
+    math::Transpose<DeviceContext, T, 5> trans5;
+    trans5(dev_ctx, *input, transformed_input, axis);
+
+  } else if (dim == 2) {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    std::vector<int> axis{0, 2, 3, 1};
+    math::Transpose<DeviceContext, T, 4> trans4;
+    trans4(dev_ctx, *input, transformed_input, axis);
+  } else if (dim == 1) {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    std::vector<int> axis{0, 2, 1};
+    math::Transpose<DeviceContext, T, 3> trans3;
+    trans3(dev_ctx, *input, transformed_input, axis);
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
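
Aside (illustration only, not part of the patch): the two axis orders used by `TransToChannelFirst` and `TransToChannelLast` are inverse permutations, so a kernel can transpose a channel-last input to channel-first, reuse the existing channel-first computation, and transpose the result back. A minimal NumPy sketch of the 5-D case, just to show the permutations (the array shape is made up):

```python
import numpy as np

# A channel-last NDHWC tensor: (batch, depth, height, width, channels).
x_ndhwc = np.random.rand(2, 3, 4, 5, 6)

# TransToChannelFirst uses axis order {0, 4, 1, 2, 3}: NDHWC -> NCDHW.
x_ncdhw = np.transpose(x_ndhwc, (0, 4, 1, 2, 3))
assert x_ncdhw.shape == (2, 6, 3, 4, 5)

# TransToChannelLast uses axis order {0, 2, 3, 4, 1}: NCDHW -> NDHWC.
x_roundtrip = np.transpose(x_ncdhw, (0, 2, 3, 4, 1))
assert np.array_equal(x_roundtrip, x_ndhwc)  # the two permutations are inverses
```
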
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
index 2af0b31d6fc26c59803f29dcdc54979491767dd2..324d4cf71103678ed17af8b1a0fd410ddd39ed7b 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
@@ -168,5 +168,59 @@ class TestBatchNorm(unittest.TestCase):
             self.assertTrue(np.allclose(y1, y2))
 
 
+class TestBatchNormChannelLast(unittest.TestCase):
+    def setUp(self):
+        self.original_dtyep = paddle.get_default_dtype()
+        paddle.set_default_dtype("float64")
+        self.places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+            self.places.append(fluid.CUDAPlace(0))
+
+    def tearDown(self):
+        paddle.set_default_dtype(self.original_dtyep)
+
+    def test_1d(self):
+        for p in self.places:
+            with fluid.dygraph.guard(p):
+                x = paddle.randn([2, 6, 4])
+                net1 = paddle.nn.BatchNorm1d(4, data_format="NLC")
+                net2 = paddle.nn.BatchNorm1d(4)
+                net2.weight = net1.weight
+                net2.bias = net1.bias
+                y1 = net1(x)
+                channel_first_x = paddle.transpose(x, [0, 2, 1])
+                y2 = net2(channel_first_x)
+                y2 = paddle.transpose(y2, [0, 2, 1])
+                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
+
+    def test_2d(self):
+        for p in self.places:
+            with fluid.dygraph.guard(p):
+                x = paddle.randn([2, 6, 6, 4])
+                net1 = paddle.nn.BatchNorm2d(4, data_format="NHWC")
+                net2 = paddle.nn.BatchNorm2d(4)
+                net2.weight = net1.weight
+                net2.bias = net1.bias
+                y1 = net1(x)
+                channel_first_x = paddle.transpose(x, [0, 3, 1, 2])
+                y2 = net2(channel_first_x)
+                y2 = paddle.transpose(y2, [0, 2, 3, 1])
+                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
+
+    def test_3d(self):
+        for p in self.places:
+            with fluid.dygraph.guard(p):
+                x = paddle.randn([2, 6, 6, 6, 4])
+                net1 = paddle.nn.BatchNorm3d(4, data_format="NDHWC")
+                net2 = paddle.nn.BatchNorm3d(4)
+                net2.weight = net1.weight
+                net2.bias = net1.bias
+                y1 = net1(x)
+                channel_first_x = paddle.transpose(x, [0, 4, 1, 2, 3])
+                y2 = net2(channel_first_x)
+                y2 = paddle.transpose(y2, [0, 2, 3, 4, 1])
+                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index 0beedd96eb70e52706cf8d2b58ffaf0e611a2855..9b78368259127f5aa070344d1f172df91bc18830 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -138,7 +138,7 @@ def batch_norm(x,
         epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
         momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
         training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False.
-        data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW" or "NCDHW". Defalut "NCHW".
+        data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default "NCHW".
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
 
     Returns:
@@ -174,13 +174,13 @@ def batch_norm(x,
     mean_out = running_mean
     variance_out = running_var
 
-    true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW']
+    true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']
     if data_format not in true_data_format:
         raise ValueError(
-            "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', but receive {}".
-            format(data_format))
+            "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
+            "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
 
-    data_format = 'NCHW'
+    data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
 
     if in_dygraph_mode():
         # for dygraph need tuple
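
Aside (illustration only, not part of the patch): assuming the 2.0 functional signature shown in part above, `batch_norm(x, running_mean, running_var, weight, bias, ...)`, a rough usage sketch for the new channel-last path might look like the following; the shapes and variable names are illustrative:

```python
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.randn([2, 6, 4])          # "NLC": length 6, with 4 channels last

num_features = 4
running_mean = paddle.zeros([num_features])
running_var = paddle.ones([num_features])
weight = paddle.ones([num_features])
bias = paddle.zeros([num_features])

# With this change, 'NLC' is accepted and routed to the channel-last kernel.
y = F.batch_norm(x, running_mean, running_var, weight, bias,
                 training=True, data_format="NLC")
print(y.shape)  # [2, 6, 4] -- the output keeps the input layout
```
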
diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index e9a2df3dc6ecfc5bdc8526d978c4c3009a986574..ad8dc9b64e78a9592018bba773761776d31a91cc 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -719,14 +719,15 @@
             If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable.
             If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
-        data_format(str, optional): Specify the input data format, may be "NC", "NCL". Defalut "NCL".
+        data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Default "NCL".
         track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
             True will track global mean and variance used for inference. When inference, track_running_stats must be
             True. Default: True.
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
 
     Shape:
-        - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length).
+        - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length) when data_format is "NC" or "NCL",
+            (batch, length, num_features) when data_format is "NLC".
         - output: 3-D tensor with same shape as input x.
 
     Returns:
@@ -755,8 +756,11 @@
     def _check_data_format(self, input):
         if input == 'NCHW' or input == 'NC' or input == 'NCL':
             self._data_format = 'NCHW'
+        elif input == "NHWC" or input == 'NLC':
+            self._data_format = "NHWC"
         else:
-            raise ValueError('expected NC , NCL or None for data_format input')
+            raise ValueError(
+                'expected NC, NCL, NLC or None for data_format input')
 
     def _check_input_dim(self, input):
         if len(input.shape) != 2 and len(input.shape) != 3:
@@ -812,14 +816,15 @@
             If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable.
             If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
-        data_format(str, optional): Specify the input data format, the data format can be "NCHW". Default: NCHW.
+        data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
         track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
             True will track global mean and variance used for inference. When inference, track_running_stats must be
             True. Default: True.
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
 
     Shape:
-        - x: 4-D tensor with shape: (batch, num_features, height, weight).
+        - x: 4-D tensor with shape: (batch, num_features, height, width) when data_format is "NCHW",
+            or (batch, height, width, num_features) when data_format is "NHWC".
         - output: 4-D tensor with same shape as input x.
 
     Returns:
@@ -847,8 +852,10 @@
     def _check_data_format(self, input):
         if input == 'NCHW':
             self._data_format = input
+        elif input == "NHWC":
+            self._data_format = input
         else:
-            raise ValueError('expected NCHW for data_format input')
+            raise ValueError('expected NCHW or NHWC for data_format input')
 
     def _check_input_dim(self, input):
         if len(input.shape) != 4:
@@ -904,14 +911,15 @@
             If it is set to None or one attribute of ParamAttr, batch_norm
             will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable.
             If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
-        data_format(str, optional): Specify the input data format, the data format can be "NCDHW". Default: NCDHW.
+        data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC". Default: NCDHW.
         track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
             True will track global mean and variance used for inference. When inference, track_running_stats must be
             True. Default: True.
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
 
     Shape:
-        - x: 5-D tensor with shape: (batch, num_features, dims, height, weight).
+        - x: 5-D tensor with shape: (batch, num_features, dims, height, width) when data_format is "NCDHW",
+            or (batch, dims, height, width, num_features) when data_format is "NDHWC".
         - output: 5-D tensor with same shape as input x.
 
     Returns:
@@ -939,8 +947,11 @@
     def _check_data_format(self, input):
         if input == 'NCHW' or input == 'NCDHW':
             self._data_format = 'NCHW'
+        elif input == "NHWC" or input == "NDHWC":
+            self._data_format = 'NHWC'
         else:
-            raise ValueError('expected NCDHW or None for data_format input')
+            raise ValueError(
+                'expected NCDHW, NDHWC or None for data_format input')
 
     def _check_input_dim(self, input):
         if len(input.shape) != 5:
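
Aside (illustration only, not part of the patch): each layer's `_check_data_format` collapses the user-facing layout strings onto the two layouts the C++ kernel actually distinguishes, 'NCHW' for channel-first and 'NHWC' for channel-last, and the functional API does the same via the `data_format[1] == 'C'` test above. A tiny standalone sketch of that mapping; the helper name is made up for illustration:

```python
def to_kernel_layout(data_format):
    """Collapse user-facing layout strings to the two layouts the op knows."""
    channel_first = ('NC', 'NCL', 'NCHW', 'NCDHW')
    channel_last = ('NLC', 'NHWC', 'NDHWC')
    if data_format in channel_first:
        return 'NCHW'
    if data_format in channel_last:
        return 'NHWC'
    raise ValueError("unsupported data_format: {}".format(data_format))

assert to_kernel_layout('NLC') == 'NHWC'    # BatchNorm1d, channel-last
assert to_kernel_layout('NDHWC') == 'NHWC'  # BatchNorm3d, channel-last
assert to_kernel_layout('NCDHW') == 'NCHW'  # BatchNorm3d, channel-first
```
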