未验证 提交 6b475981 编写于 作者: Q QI JUN 提交者: GitHub

add data layout (#6832)

* add data layout

* fix ci
上级 ad6d6e9c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace framework {
// Tags describing how the dimensions of a tensor are ordered in memory.
// The numeric values are fixed at 0, 1 and 2 (the enumerators after the
// first take consecutive values implicitly).
enum DataLayout {
  // Channel is the last dimension. Value 0.
  kNHWC = 0,
  // Channel is dimension 1, directly after the batch dimension. Value 1.
  kNCHW,
  // Value 2.
  // NOTE(review): presumably a wildcard meaning "no particular layout is
  // required"; StringToDataLayout never returns it -- confirm intended use.
  kAnyLayout
};
// Parses the textual name of a layout into its DataLayout enumerator.
//
// Accepted spellings are exactly "NHWC"/"nhwc" (-> kNHWC) and "NCHW"/"nchw"
// (-> kNCHW); mixed-case spellings are not recognised. Any other string is
// reported through PADDLE_THROW with the offending text in the message.
inline DataLayout StringToDataLayout(const std::string& str) {
  const bool names_nhwc = (str == "NHWC") || (str == "nhwc");
  if (names_nhwc) {
    return DataLayout::kNHWC;
  }

  const bool names_nchw = (str == "NCHW") || (str == "nchw");
  if (names_nchw) {
    return DataLayout::kNCHW;
  }

  // Neither supported spelling matched.
  PADDLE_THROW("Unknown storage order string: %s", str);
}
} // namespace framework
} // namespace paddle
...@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/batch_norm_op.h" #include "paddle/operators/batch_norm_op.h"
#include "paddle/framework/data_layout.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T> template <typename T>
using EigenArrayMap = using EigenArrayMap =
...@@ -60,15 +62,15 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -60,15 +62,15 @@ class BatchNormOp : public framework::OperatorWithKernel {
"Variance and VarianceOut should share the same memory"); "Variance and VarianceOut should share the same memory");
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
const TensorFormat tensor_format = const DataLayout data_layout = framework::StringToDataLayout(
StringToTensorFormat(ctx->Attrs().Get<std::string>("tensor_format")); ctx->Attrs().Get<std::string>("data_layout"));
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"Input X must have 2 to 5 dimensions."); "Input X must have 2 to 5 dimensions.");
const int C = const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
...@@ -90,7 +92,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -90,7 +92,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("is_test", "").SetDefault(false); AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9); AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "").SetDefault(1e-5); AddAttr<float>("epsilon", "").SetDefault(1e-5);
AddAttr<std::string>("tensor_format", "").SetDefault("NCHW"); AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddInput("X", "The input tensor"); AddInput("X", "The input tensor");
AddInput("Scale", AddInput("Scale",
"Scale is a 1-dimensional tensor of size C " "Scale is a 1-dimensional tensor of size C "
...@@ -141,9 +143,9 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -141,9 +143,9 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const std::string tensor_format_str = const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
ctx.Attr<std::string>("tensor_format"); const DataLayout data_layout =
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
...@@ -151,8 +153,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -151,8 +153,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
"The Input dim size should be between 2 and 5"); "The Input dim size should be between 2 and 5");
const int N = x_dims[0]; const int N = x_dims[0];
const int C = const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C; const int sample_size = x->numel() / N / C;
auto *y = ctx.Output<Tensor>("Y"); auto *y = ctx.Output<Tensor>("Y");
...@@ -177,8 +179,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -177,8 +179,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean_e.setZero(); saved_mean_e.setZero();
saved_variance_e.setZero(); saved_variance_e.setZero();
switch (tensor_format) { switch (data_layout) {
case TensorFormat::NCHW: { case DataLayout::kNCHW: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C); ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) { for (int nc = 0; nc < N * C; ++nc) {
saved_mean_e(nc % C) += x_arr.col(nc).sum(); saved_mean_e(nc % C) += x_arr.col(nc).sum();
...@@ -191,7 +193,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -191,7 +193,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_variance_e /= N * sample_size; saved_variance_e /= N * sample_size;
break; break;
} }
case TensorFormat::NHWC: { case DataLayout::kNHWC: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size); ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
for (int i = 0; i < N * sample_size; ++i) { for (int i = 0; i < N * sample_size; ++i) {
saved_mean_e += x_arr.col(i); saved_mean_e += x_arr.col(i);
...@@ -205,7 +207,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -205,7 +207,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
break; break;
} }
default: default:
PADDLE_THROW("Unknown storage order: %s", tensor_format_str); PADDLE_THROW("Unknown storage order: %s", data_layout_str);
} }
EigenVectorArrayMap<T> running_mean_arr( EigenVectorArrayMap<T> running_mean_arr(
...@@ -247,8 +249,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -247,8 +249,8 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
Eigen::Array<T, Eigen::Dynamic, 1> new_bias = Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
bias_arr - mean_arr * inv_std * scale_arr; bias_arr - mean_arr * inv_std * scale_arr;
switch (tensor_format) { switch (data_layout) {
case TensorFormat::NCHW: { case DataLayout::kNCHW: {
EigenArrayMap<T> y_arr(y->mutable_data<T>(ctx.GetPlace()), sample_size, EigenArrayMap<T> y_arr(y->mutable_data<T>(ctx.GetPlace()), sample_size,
N * C); N * C);
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C); ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
...@@ -257,7 +259,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -257,7 +259,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
} }
break; break;
} }
case TensorFormat::NHWC: { case DataLayout::kNHWC: {
EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C, EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size) = N * sample_size) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N * sample_size).colwise() * (ConstEigenArrayMap<T>(x->data<T>(), C, N * sample_size).colwise() *
...@@ -267,7 +269,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -267,7 +269,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
break; break;
} }
default: default:
PADDLE_THROW("Unknown storage order: %d", tensor_format); PADDLE_THROW("Unknown storage order: %d", data_layout);
} }
} }
}; };
...@@ -290,11 +292,11 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -290,11 +292,11 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), "");
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
const TensorFormat tensor_format = const DataLayout data_layout = framework::StringToDataLayout(
StringToTensorFormat(ctx->Attrs().Get<std::string>("tensor_format")); ctx->Attrs().Get<std::string>("data_layout"));
const int C = const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
...@@ -333,9 +335,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -333,9 +335,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const auto *saved_mean = ctx.Input<Tensor>("SavedMean"); const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
// SavedVariance have been reverted in forward operator // SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance"); const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string tensor_format_str = const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
ctx.Attr<std::string>("tensor_format"); const DataLayout data_layout =
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); framework::StringToDataLayout(data_layout_str);
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
...@@ -344,8 +346,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -344,8 +346,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
"The Input dim size should be between 2 and 5"); "The Input dim size should be between 2 and 5");
const int N = x_dims[0]; const int N = x_dims[0];
const int C = const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C; const int sample_size = x->numel() / N / C;
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C); ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
...@@ -376,8 +378,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -376,8 +378,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size); const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
switch (tensor_format) { switch (data_layout) {
case TensorFormat::NCHW: { case DataLayout::kNCHW: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C); ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C); ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
...@@ -400,7 +402,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -400,7 +402,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
} }
break; break;
} }
case TensorFormat::NHWC: { case DataLayout::kNHWC: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size); ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size); ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C, EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
...@@ -425,7 +427,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -425,7 +427,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
break; break;
} }
default: default:
PADDLE_THROW("Unknown storage order: %s", tensor_format_str); PADDLE_THROW("Unknown storage order: %s", data_layout_str);
} }
} }
}; };
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/batch_norm_op.h" #include "paddle/operators/batch_norm_op.h"
#include "paddle/framework/data_layout.h"
#include <cfloat> #include <cfloat>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
...@@ -22,12 +23,12 @@ namespace paddle { ...@@ -22,12 +23,12 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T> template <typename T>
using CudnnDataType = platform::CudnnDataType<T>; using CudnnDataType = platform::CudnnDataType<T>;
void ExtractNCWHD(const framework::DDim &dims, void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
const TensorFormat &tensor_format, int *N, int *C, int *H, int *N, int *C, int *H, int *W, int *D) {
int *W, int *D) {
*N = dims[0]; *N = dims[0];
if (dims.size() == 2) { if (dims.size() == 2) {
*C = dims[1]; *C = dims[1];
...@@ -35,13 +36,13 @@ void ExtractNCWHD(const framework::DDim &dims, ...@@ -35,13 +36,13 @@ void ExtractNCWHD(const framework::DDim &dims,
*W = 1; *W = 1;
*D = 1; *D = 1;
} else { } else {
*C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1]; *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
*H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1]; *H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = dims.size() > 3 *W = dims.size() > 3
? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2]) ? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2])
: 1; : 1;
*D = dims.size() > 4 *D = dims.size() > 4
? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3]) ? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3])
: 1; : 1;
} }
} }
...@@ -56,9 +57,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -56,9 +57,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon")); double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const std::string tensor_format_str = const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
ctx.Attr<std::string>("tensor_format"); const DataLayout data_layout =
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); framework::StringToDataLayout(data_layout_str);
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
...@@ -67,7 +68,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -67,7 +68,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5"); "The Input dim size should be between 2 and 5");
int N, C, H, W, D; int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D); ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t data_desc_;
...@@ -93,7 +94,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -93,7 +94,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
VLOG(1) << "Setting descriptors."; VLOG(1) << "Setting descriptors.";
std::vector<int> dims; std::vector<int> dims;
std::vector<int> strides; std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) { if (data_layout == DataLayout::kNCHW) {
dims = {N, C, H, W, D}; dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1}; strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else { } else {
...@@ -180,9 +181,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -180,9 +181,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace."); "It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon")); double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string tensor_format_str = const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
ctx.Attr<std::string>("tensor_format"); const DataLayout data_layout =
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y")); const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
...@@ -192,7 +193,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -192,7 +193,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5"); "The Input dim size should be between 2 and 5");
int N, C, H, W, D; int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D); ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C); PADDLE_ENFORCE_EQ(scale->dims()[0], C);
...@@ -219,7 +220,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -219,7 +220,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
std::vector<int> dims; std::vector<int> dims;
std::vector<int> strides; std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) { if (data_layout == DataLayout::kNCHW) {
dims = {N, C, H, W, D}; dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1}; strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else { } else {
......
...@@ -19,21 +19,6 @@ limitations under the License. */ ...@@ -19,21 +19,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Dimension ordering selector used by the batch-norm operator code.
// The numeric values are fixed at 0 and 1 (the second enumerator takes the
// next consecutive value implicitly).
enum TensorFormat {
  // Channel is the last dimension. Value 0.
  NHWC = 0,
  // Channel is dimension 1, directly after the batch dimension. Value 1.
  NCHW
};
// Converts a layout name into the matching TensorFormat enumerator.
//
// Only the all-upper-case and all-lower-case spellings are recognised:
// "NHWC"/"nhwc" yields NHWC and "NCHW"/"nchw" yields NCHW. Every other
// string is reported through PADDLE_THROW, which includes the rejected text.
inline TensorFormat StringToTensorFormat(const std::string& str) {
  if (str == "NHWC" || str == "nhwc") return TensorFormat::NHWC;
  if (str == "NCHW" || str == "nchw") return TensorFormat::NCHW;

  // No supported spelling matched.
  PADDLE_THROW("Unknown storage order string: %s", str);
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> { class BatchNormKernel : public framework::OpKernel<T> {
public: public:
......
...@@ -208,7 +208,7 @@ class TestBatchNormOp(OpTest): ...@@ -208,7 +208,7 @@ class TestBatchNormOp(OpTest):
print 'python: NHWC, NCHW, backward checking passed' print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self): def test_forward_backward(self):
def test_with_place(place, tensor_format, shape): def test_with_place(place, data_layout, shape):
# attr # attr
epsilon = 0.00001 epsilon = 0.00001
momentum = 0.9 momentum = 0.9
...@@ -292,7 +292,7 @@ class TestBatchNormOp(OpTest): ...@@ -292,7 +292,7 @@ class TestBatchNormOp(OpTest):
SavedVariance="saved_variance", SavedVariance="saved_variance",
# attrs # attrs
is_test=False, is_test=False,
tensor_format=tensor_format, data_layout=data_layout,
momentum=momentum, momentum=momentum,
epsilon=epsilon) epsilon=epsilon)
...@@ -311,7 +311,7 @@ class TestBatchNormOp(OpTest): ...@@ -311,7 +311,7 @@ class TestBatchNormOp(OpTest):
atol = 1e-4 atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out, self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol) "variance_out", atol)
print "op test forward passed: ", str(place), tensor_format print "op test forward passed: ", str(place), data_layout
# run backward # run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set()) batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
...@@ -336,7 +336,7 @@ class TestBatchNormOp(OpTest): ...@@ -336,7 +336,7 @@ class TestBatchNormOp(OpTest):
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
print "op test backward passed: ", str(place), tensor_format print "op test backward passed: ", str(place), data_layout
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册