Unverified · Commit 429c0b62 authored by Feiyu Chan, committed by GitHub

support channel last in BatchNorm*d (#27961)

1. Support channel last in BatchNorm*d (#27875)
2. Fix a bug in the batch_norm_op CUDA kernel by extracting ResizeToChannelFirst(Last) and TransToChannelFirst(Last) into operators/layout_utils.h
Parent commit 39c31a20
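With this change the BatchNorm*d layers accept channel-last inputs directly through their data_format argument, instead of requiring a manual transpose to the channel-first layout. A minimal usage sketch (illustrative only; it mirrors the dygraph API exercised by the tests added in this commit):

import paddle

x = paddle.randn([2, 6, 6, 4])                     # NHWC: 4 channels on the last axis
bn = paddle.nn.BatchNorm2d(4, data_format="NHWC")  # channel-last is now accepted
y = bn(x)                                          # same shape as x, normalized per channel

# Reference computation in the default channel-first layout, for comparison
x_cf = paddle.transpose(x, [0, 3, 1, 2])           # NHWC -> NCHW
y_cf = paddle.nn.BatchNorm2d(4)(x_cf)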
...@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.h" #include "paddle/fluid/operators/norm_utils.h"
...@@ -41,127 +42,6 @@ template <typename T> ...@@ -41,127 +42,6 @@ template <typename T>
using ConstEigenVectorArrayMap = using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>; Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[3];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 3, 1, 2};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
} else if (dim == 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 1};
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, *input, transformed_input, axis);
}
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 1};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
} else if (dim == 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 1};
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, *input, transformed_input, axis);
}
}
class BatchNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
......
...@@ -20,6 +20,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
...@@ -138,102 +139,6 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim, ...@@ -138,102 +139,6 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
return !(filter_1 && strides_1 && padding_0 && dilation_1); return !(filter_1 && strides_1 && padding_0 && dilation_1);
} }
template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[3];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 3, 1, 2};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 1};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
}
}
// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
                                 const Tensor* input,
                                 Tensor* transformed_input) {
  int dim = input->dims().size() - 2;
  if (dim == 3) {
    // input
    transformed_input->Resize(input->dims());
    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[4];
    in_dims_vec[2] = input->dims()[1];
    in_dims_vec[3] = input->dims()[2];
    in_dims_vec[4] = input->dims()[3];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  } else if (dim == 2) {
    // input
    transformed_input->Resize(input->dims());
    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[3];
    in_dims_vec[2] = input->dims()[1];
    in_dims_vec[3] = input->dims()[2];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  } else if (dim == 1) {
    transformed_input->Resize(input->dims());
    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  }
}

template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
                                const Tensor* input,
                                Tensor* transformed_input) {
  int dim = input->dims().size() - 2;
  if (dim == 3) {
    // input
    transformed_input->Resize(input->dims());
    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[3];
    in_dims_vec[3] = input->dims()[4];
    in_dims_vec[4] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  } else if (dim == 2) {
    // input
    transformed_input->Resize(input->dims());
    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[3];
    in_dims_vec[3] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  } else if (dim == 1) {
    transformed_input->Resize(input->dims());
    auto in_dims_vec = framework::vectorize(input->dims());
    in_dims_vec[1] = input->dims()[2];
    in_dims_vec[2] = input->dims()[1];
    transformed_input->Resize(framework::make_ddim(in_dims_vec));
    transformed_input->mutable_data<T>(context.GetPlace());
  }
}

template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context,
                                const Tensor* input,
                                Tensor* transformed_input) {
  VLOG(5) << "Why am I called?";
  int dim = input->dims().size() - 2;
  if (dim == 3) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> axis{0, 4, 1, 2, 3};
    math::Transpose<DeviceContext, T, 5> trans5;
    trans5(dev_ctx, *input, transformed_input, axis);
  } else if (dim == 2) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> axis{0, 3, 1, 2};
    math::Transpose<DeviceContext, T, 4> trans4;
    trans4(dev_ctx, *input, transformed_input, axis);
  } else if (dim == 1) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> axis{0, 2, 1};
    math::Transpose<DeviceContext, T, 3> trans3;
    trans3(dev_ctx, *input, transformed_input, axis);
  }
}

template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
                               const Tensor* input, Tensor* transformed_input) {
  int dim = input->dims().size() - 2;
  if (dim == 3) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> axis{0, 2, 3, 4, 1};
    math::Transpose<DeviceContext, T, 5> trans5;
    trans5(dev_ctx, *input, transformed_input, axis);
  } else if (dim == 2) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> axis{0, 2, 3, 1};
    math::Transpose<DeviceContext, T, 4> trans4;
    trans4(dev_ctx, *input, transformed_input, axis);
  } else if (dim == 1) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> axis{0, 2, 1};
    math::Transpose<DeviceContext, T, 3> trans3;
    trans3(dev_ctx, *input, transformed_input, axis);
  }
}

}  // namespace operators
}  // namespace paddle
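The four helpers above differ only in the axis permutation they hand to math::Transpose: channel-first moves the last axis to position 1, channel-last moves axis 1 to the end, and the two permutations are inverses of each other. The same permutations can be sanity-checked from Python with paddle.transpose (an illustrative check, not part of the commit):

import numpy as np
import paddle

x = paddle.randn([2, 3, 4, 5, 6])                  # NCDHW (channel-first, three spatial dims)

# Same permutations as TransToChannelLast / TransToChannelFirst for dim == 3.
to_last = paddle.transpose(x, [0, 2, 3, 4, 1])     # NCDHW -> NDHWC
back = paddle.transpose(to_last, [0, 4, 1, 2, 3])  # NDHWC -> NCDHW

print(to_last.shape)                               # [2, 4, 5, 6, 3]
print(np.allclose(x.numpy(), back.numpy()))        # True: the permutations are inverses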
...@@ -168,5 +168,59 @@ class TestBatchNorm(unittest.TestCase):
        self.assertTrue(np.allclose(y1, y2))
class TestBatchNormChannelLast(unittest.TestCase):
    def setUp(self):
        self.original_dtype = paddle.get_default_dtype()
        paddle.set_default_dtype("float64")
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
            self.places.append(fluid.CUDAPlace(0))

    def tearDown(self):
        paddle.set_default_dtype(self.original_dtype)

    def test_1d(self):
        for p in self.places:
            with fluid.dygraph.guard(p):
                x = paddle.randn([2, 6, 4])
                net1 = paddle.nn.BatchNorm1d(4, data_format="NLC")
                net2 = paddle.nn.BatchNorm1d(4)
                net2.weight = net1.weight
                net2.bias = net1.bias
                y1 = net1(x)
                channel_first_x = paddle.transpose(x, [0, 2, 1])
                y2 = net2(channel_first_x)
                y2 = paddle.transpose(y2, [0, 2, 1])
                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)

    def test_2d(self):
        for p in self.places:
            with fluid.dygraph.guard(p):
                x = paddle.randn([2, 6, 6, 4])
                net1 = paddle.nn.BatchNorm2d(4, data_format="NHWC")
                net2 = paddle.nn.BatchNorm2d(4)
                net2.weight = net1.weight
                net2.bias = net1.bias
                y1 = net1(x)
                channel_first_x = paddle.transpose(x, [0, 3, 1, 2])
                y2 = net2(channel_first_x)
                y2 = paddle.transpose(y2, [0, 2, 3, 1])
                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)

    def test_3d(self):
        for p in self.places:
            with fluid.dygraph.guard(p):
                x = paddle.randn([2, 6, 6, 6, 4])
                net1 = paddle.nn.BatchNorm3d(4, data_format="NDHWC")
                net2 = paddle.nn.BatchNorm3d(4)
                net2.weight = net1.weight
                net2.bias = net1.bias
                y1 = net1(x)
                channel_first_x = paddle.transpose(x, [0, 4, 1, 2, 3])
                y2 = net2(channel_first_x)
                y2 = paddle.transpose(y2, [0, 2, 3, 4, 1])
                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
if __name__ == '__main__':
    unittest.main()
...@@ -138,7 +138,7 @@ def batch_norm(x,
        epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
        momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
        training(bool, optional): True means training mode, in which statistics are computed from the batch data while the global mean and variance are tracked. False means inference mode, in which the global mean and variance computed during training are used. Default: False.
        data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
...@@ -174,13 +174,13 @@ def batch_norm(x,
    mean_out = running_mean
    variance_out = running_var

    true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']
    if data_format not in true_data_format:
        raise ValueError(
            "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
            "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
    data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'

    if in_dygraph_mode():
        # for dygraph need tuple
......
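The last statement above folds the seven accepted layout strings down to the two layouts the kernels understand: any string whose second character is 'C' is treated as channel-first ('NCHW'), everything else as channel-last ('NHWC'). A standalone sketch of that mapping (hypothetical helper name, not the library code):

accepted = ['NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']

def normalize_data_format(data_format):
    # Hypothetical helper mirroring the check in batch_norm() above.
    if data_format not in accepted:
        raise ValueError("data_format must be one of {}, but received {}"
                         .format(accepted, data_format))
    # 'NC*' (channel right after batch) -> channel-first kernel layout.
    return 'NCHW' if data_format[1] == 'C' else 'NHWC'

assert normalize_data_format('NDHWC') == 'NHWC'
assert normalize_data_format('NCL') == 'NCHW'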
...@@ -719,14 +719,15 @@ class BatchNorm1d(_BatchNormBase):
        If it is set to None or one attribute of ParamAttr, batch_norm
        will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
        If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
        data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Default: "NCL".
        track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
            True will track global mean and variance used for inference. When inference, track_running_stats must be
            True. Default: True.
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length) when data_format is "NC" or "NCL",
          (batch, length, num_features) when data_format is "NLC".
        - output: 3-D tensor with same shape as input x.

    Returns:
...@@ -755,8 +756,11 @@ class BatchNorm1d(_BatchNormBase):
    def _check_data_format(self, input):
        if input == 'NCHW' or input == 'NC' or input == 'NCL':
            self._data_format = 'NCHW'
        elif input == "NHWC" or input == 'NLC':
            self._data_format = "NHWC"
        else:
            raise ValueError(
                'expected NC, NCL, NLC or None for data_format input')
    def _check_input_dim(self, input):
        if len(input.shape) != 2 and len(input.shape) != 3:
...@@ -812,14 +816,15 @@ class BatchNorm2d(_BatchNormBase):
        If it is set to None or one attribute of ParamAttr, batch_norm
        will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
        If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
        data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
        track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
            True will track global mean and variance used for inference. When inference, track_running_stats must be
            True. Default: True.
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: 4-D tensor with shape: (batch, num_features, height, width) when data_format is "NCHW",
          or (batch, height, width, num_features) when data_format is "NHWC".
        - output: 4-D tensor with same shape as input x.

    Returns:
...@@ -847,8 +852,10 @@ class BatchNorm2d(_BatchNormBase):
    def _check_data_format(self, input):
        if input == 'NCHW':
            self._data_format = input
        elif input == "NHWC":
            self._data_format = input
        else:
            raise ValueError('expected NCHW or NHWC for data_format input')
    def _check_input_dim(self, input):
        if len(input.shape) != 4:
...@@ -904,14 +911,15 @@ class BatchNorm3d(_BatchNormBase):
        If it is set to None or one attribute of ParamAttr, batch_norm
        will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
        If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
        data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC". Default: NCDHW.
        track_running_stats(bool, optional): Whether to use global mean and variance. In train period,
            True will track global mean and variance used for inference. When inference, track_running_stats must be
            True. Default: True.
        name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: 5-D tensor with shape: (batch, num_features, dims, height, width) when data_format is "NCDHW",
          or (batch, dims, height, width, num_features) when data_format is "NDHWC".
        - output: 5-D tensor with same shape as input x.

    Returns:
...@@ -939,8 +947,11 @@ class BatchNorm3d(_BatchNormBase):
    def _check_data_format(self, input):
        if input == 'NCHW' or input == 'NCDHW':
            self._data_format = 'NCHW'
        elif input == "NHWC" or input == "NDHWC":
            self._data_format = 'NHWC'
        else:
            raise ValueError(
                'expected NCDHW, NDHWC or None for data_format input')
    def _check_input_dim(self, input):
        if len(input.shape) != 5:
......