未验证 提交 ebd0f512 编写于 作者: H hong 提交者: GitHub

Move bn to pten (#39347)

* add bn cpu version; test=develop

* move batch norm to pten

* move batch norm to pten; test=develop

* fix bug; test=develop

* fix func::tranpose depend bug; test=develop

* fix compile bugs; test=develop

* fix use_op batch_norm bug; test=develop

* fix cudnn bn add relu test; test=develop

* fix pten context build and double grad bug; test= develop

* remve useless code; test=develop

* add batch norm gpu fp16 support; test=develop

* fix test bn op bug; test=develop

* remove output dtype set; test=develop

* fix bug; test=develop

* fix bug; test=develop

* fix applay pass to program bug; test=develop

* revert to develop; test=develop

* fix rocm bug; test=develop

* revert operator to develop; test=develop

* fix pre_commit; test=develop

* fix statci check error; test=develop

* resolve conflict; test=develop

* ana batch norm bug;

* revert batch norm op

* resolve conlict

* fix nan inf and speed bug; test=develop

* fix bug; test=develop

* fix error; test=develop

* test expand op; test=develop

* fix bug; test=develop

* resolve confilct

* resolve confilct; test=develop

* polish code; test=develop

* polish code; test=develop

* change mutable data to ctx alloc; test=develop

* make format same with ci; test=develop

* fix format error with ci; test=develop
上级 c16f85f9
......@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include <string>
#include <unordered_set>
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
#include <random>
#include <unordered_set>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
......@@ -25,7 +26,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/place.h"
USE_OP(batch_norm);
USE_OP_ITSELF(batch_norm);
USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
USE_OP(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
......
......@@ -2215,8 +2215,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
vector_int_attr.end());
pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
}
// TODO(YuanRisheng) Need support vector<int64_t> attr
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
......
......@@ -1289,15 +1289,3 @@ REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp,
ops::BatchNormDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_norm_grad_grad, ops::BatchNormDoubleGradOp,
ops::BatchNormDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad_grad,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchNormDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -32,7 +32,7 @@ namespace platform = paddle::platform;
namespace op = paddle::operators;
using Tensor = paddle::framework::Tensor;
USE_OP(batch_norm);
USE_OP_ITSELF(batch_norm);
USE_CUDA_ONLY_OP(fused_bn_add_activation);
USE_CUDA_ONLY_OP(fused_bn_add_activation_grad);
......
......@@ -17,6 +17,8 @@
#include <string>
#include <vector>
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
namespace paddle {
namespace operators {
......@@ -202,8 +204,7 @@ class InplaceABNOpGradMaker : public framework::SingleGradOpMaker<T> {
};
template <typename DeviceContext, typename T>
class InplaceABNKernel
: public paddle::operators::BatchNormKernel<DeviceContext, T> {
class InplaceABNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
......@@ -213,7 +214,33 @@ class InplaceABNKernel
auto activation =
GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
BatchNormKernel<DeviceContext, T>::Compute(ctx);
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* mean = ctx.Input<Tensor>("Mean");
auto* variance = ctx.Input<Tensor>("Variance");
auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
auto* mean_out = ctx.Output<Tensor>("MeanOut");
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
auto* reserve_space = ctx.Output<Tensor>("ReserveSpace");
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::BatchNormKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *scale, *bias, *mean, *variance, momentum, epsilon, data_layout,
is_test, use_global_stats, trainable_statistics, fuse_with_relu, y,
mean_out, variance_out, saved_mean, saved_variance, reserve_space);
auto cur_y = EigenVector<T>::Flatten(*y);
InplaceABNActivation<DeviceContext, T> functor;
......@@ -222,8 +249,7 @@ class InplaceABNKernel
};
template <typename DeviceContext, typename T>
class InplaceABNGradKernel
: public paddle::operators::BatchNormGradKernel<DeviceContext, T> {
class InplaceABNGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* y = ctx.Input<Tensor>("Y");
......@@ -244,7 +270,52 @@ class InplaceABNGradKernel
InplaceABNActivation<DeviceContext, T> functor;
functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);
BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
// BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* saved_mean = ctx.Input<Tensor>("SavedMean");
auto* saved_variance = ctx.Input<Tensor>("SavedVariance");
auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
auto* scale_grad = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* bias_grad = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
auto* mean = ctx.Input<Tensor>("ReserveSpace");
auto* variance = ctx.Input<Tensor>("ReserveSpace");
paddle::optional<const Tensor&> space_opt = paddle::none;
paddle::optional<const Tensor&> mean_opt = paddle::none;
paddle::optional<const Tensor&> variance_opt = paddle::none;
if (reserve_space != nullptr) {
space_opt = *reserve_space;
}
if (mean != nullptr) {
mean_opt = *mean;
}
if (variance != nullptr) {
variance_opt = *variance;
}
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::BatchNormGradRawKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt,
mean_opt, variance_opt, momentum, epsilon, data_layout, is_test,
use_global_stats, trainable_statistics, fuse_with_relu, true, d_x,
scale_grad, bias_grad);
}
};
......
......@@ -15,14 +15,15 @@ limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/inplace_abn_op.h"
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class InplaceABNKernel
: public paddle::operators::SyncBatchNormKernel<DeviceContext, T>,
public paddle::operators::BatchNormKernel<DeviceContext, T> {
: public paddle::operators::SyncBatchNormKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* y = ctx.Output<Tensor>("Y");
......@@ -36,7 +37,33 @@ class InplaceABNKernel
if (ctx.Attr<bool>("use_sync_bn")) {
SyncBatchNormKernel<DeviceContext, T>::Compute(ctx);
} else {
BatchNormKernel<DeviceContext, T>::Compute(ctx);
// BatchNormKernel<DeviceContext, T>::Compute(ctx);
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* mean = ctx.Input<Tensor>("Mean");
auto* variance = ctx.Input<Tensor>("Variance");
auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
auto* mean_out = ctx.Output<Tensor>("MeanOut");
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
auto* reserve_space = ctx.Output<Tensor>("ReserveSpace");
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::BatchNormKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *scale, *bias, *mean, *variance, momentum, epsilon, data_layout,
is_test, use_global_stats, trainable_statistics, fuse_with_relu, y,
mean_out, variance_out, saved_mean, saved_variance, reserve_space);
}
auto cur_y = EigenVector<T>::Flatten(*y);
......@@ -49,8 +76,7 @@ class InplaceABNKernel
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
template <typename DeviceContext, typename T>
class InplaceABNGradKernel
: public paddle::operators::SyncBatchNormGradKernel<DeviceContext, T>,
public paddle::operators::BatchNormGradKernel<DeviceContext, T> {
: public paddle::operators::SyncBatchNormGradKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* y = ctx.Input<Tensor>("Y");
......@@ -74,7 +100,50 @@ class InplaceABNGradKernel
if (ctx.Attr<bool>("use_sync_bn")) {
SyncBatchNormGradKernel<DeviceContext, T>::Compute(ctx);
} else {
BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* saved_mean = ctx.Input<Tensor>("SavedMean");
auto* saved_variance = ctx.Input<Tensor>("SavedVariance");
auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
auto* scale_grad = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* bias_grad = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
auto* mean = ctx.Input<Tensor>("ReserveSpace");
auto* variance = ctx.Input<Tensor>("ReserveSpace");
paddle::optional<const Tensor&> space_opt = paddle::none;
paddle::optional<const Tensor&> mean_opt = paddle::none;
paddle::optional<const Tensor&> variance_opt = paddle::none;
if (reserve_space != nullptr) {
space_opt = *reserve_space;
}
if (mean != nullptr) {
mean_opt = *mean;
}
if (variance != nullptr) {
variance_opt = *variance;
}
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::BatchNormGradRawKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt,
mean_opt, variance_opt, momentum, epsilon, data_layout, is_test,
use_global_stats, trainable_statistics, fuse_with_relu, true, d_x,
scale_grad, bias_grad);
}
}
};
......
......@@ -389,11 +389,12 @@ __global__ void DoubleGradComputeDDYWithGlobal(
}
template <typename DeviceContext, typename T>
void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
void NormDoubleGradFunctor(const DeviceContext &ctx,
const DataLayout data_layout, const Tensor *X,
const Tensor *Scale, const Tensor *dY,
const Tensor *Saved_mean,
const Tensor *Saved_variance, const double epsilon,
const Tensor *Saved_variance, const Tensor *Mean,
const Tensor *Variance, const double epsilon,
const bool use_global_stats, const Tensor *ddX,
const Tensor *ddScale, const Tensor *ddBias,
Tensor *dX, Tensor *dScale, Tensor *ddY) {
......@@ -404,8 +405,7 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data<T>());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_constant;
phi::funcs::SetConstant<DeviceContext, T> set_constant;
auto &x_dims = X->dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
......@@ -416,7 +416,7 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
Tensor scale_tmp;
if (!Scale) {
scale_tmp.mutable_data<T>({C}, ctx.GetPlace());
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
set_constant(ctx, &scale_tmp, static_cast<T>(1));
}
const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
#ifdef __HIPCC__
......@@ -424,15 +424,15 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
int max_threads = ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(C, max_blocks);
int grid1 = (num + block - 1) / block;
const T *mean_data, *variance_data;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *running_mean = Mean;
const auto *running_var = Variance;
const auto *running_mean_data = running_mean->template data<T>();
const auto *running_var_data = running_var->template data<T>();
mean_data = running_mean_data;
......@@ -440,34 +440,35 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
} else {
const T *smean_data = Saved_mean->data<T>();
const T *svariance_data = Saved_variance->data<T>();
mean_data = smean_data;
variance_data = svariance_data;
}
if (dX) {
T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, dX, static_cast<T>(0));
set_constant(ctx, dX, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDXWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
T, DataLayout::kNHWC><<<grid1, block, 0, ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
} else {
DoubleGradComputeDXWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
T, DataLayout::kNCHW><<<grid1, block, 0, ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDX<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNHWC><<<grid, block, 0, ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
ddscale_data, N, C, sample_size, epsilon, dx_data);
} else {
DoubleGradComputeDX<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNCHW><<<grid, block, 0, ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
ddscale_data, N, C, sample_size, epsilon, dx_data);
}
......@@ -475,28 +476,28 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
}
if (dScale) {
T *dscale_data = dScale->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, dScale, static_cast<T>(0));
set_constant(ctx, dScale, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDScaleWithGlobal<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNHWC><<<grid, block, 0, ctx.stream()>>>(
ddx_data, variance_data, dy_data, epsilon, N, C, sample_size,
dscale_data);
} else {
DoubleGradComputeDScaleWithGlobal<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNCHW><<<grid, block, 0, ctx.stream()>>>(
ddx_data, variance_data, dy_data, epsilon, N, C, sample_size,
dscale_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDScale<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNHWC><<<grid, block, 0, ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, N, C,
sample_size, epsilon, dscale_data);
} else {
DoubleGradComputeDScale<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNCHW><<<grid, block, 0, ctx.stream()>>>(
x_data, mean_data, variance_data, ddx_data, dy_data, N, C,
sample_size, epsilon, dscale_data);
}
......@@ -504,28 +505,28 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
}
if (ddY) {
T *ddy_data = ddY->mutable_data<T>(ctx.GetPlace());
set_constant(dev_ctx, ddY, static_cast<T>(0));
set_constant(ctx, ddY, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDDYWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
T, DataLayout::kNHWC><<<grid1, block, 0, ctx.stream()>>>(
ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
ddscale_data, epsilon, C, sample_size, num, ddy_data);
} else {
DoubleGradComputeDDYWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
T, DataLayout::kNCHW><<<grid1, block, 0, ctx.stream()>>>(
ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
ddscale_data, epsilon, C, sample_size, num, ddy_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDDY<
T, block, DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNHWC><<<grid, block, 0, ctx.stream()>>>(
x_data, mean_data, variance_data, ddscale_data, ddbias_data,
ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data);
} else {
DoubleGradComputeDDY<
T, block, DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
T, block, DataLayout::kNCHW><<<grid, block, 0, ctx.stream()>>>(
x_data, mean_data, variance_data, ddscale_data, ddbias_data,
ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data);
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
bool is_inplace,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad);
template <typename T, typename Context>
void BatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad);
template <typename T, typename Context>
void BatchNormDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x_grad_grad,
const DenseTensor& scale_grad_grad,
const DenseTensor& bias_grad_grad,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* y_grad_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BatchNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
DenseTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space);
} // namespace phi
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace phi {
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T, typename Context>
void BatchNormKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout_str,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
DenseTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space) {
bool test_mode = is_test && (!trainable_statistics);
bool global_stats = test_mode || use_global_stats;
auto data_layout = paddle::framework::StringToDataLayout(data_layout_str);
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument(
"The size of input X's dimensions should be larger than 1."
"But received: the size of input X's dimensions is [%d]",
x_dims.size()));
PADDLE_ENFORCE_LE(
x_dims.size(),
5,
phi::errors::InvalidArgument(
"The size of input X's dimensions should be less than 6."
"But received: the size of input X's dimensionss is [%d]",
x_dims.size()));
const int N = x_dims[0];
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int sample_size = x.numel() / N / C;
// alloc memory
ctx.template Alloc<T>(y);
ctx.template Alloc<T>(mean_out);
ctx.template Alloc<T>(variance_out);
ctx.template Alloc<T>(saved_mean);
ctx.template Alloc<T>(saved_variance);
// input dimension is 2 and the format is NCHW. The input can be regarded
// as NHWC format
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}
if (!global_stats) {
// saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e(ctx.template Alloc<T>(saved_mean), C);
EigenVectorArrayMap<T> saved_variance_e(
ctx.template Alloc<T>(saved_variance), C);
saved_mean_e.setZero();
saved_variance_e.setZero();
EigenVectorArrayMap<T> running_mean_arr(ctx.template Alloc<T>(mean_out), C);
EigenVectorArrayMap<T> running_var_arr(ctx.template Alloc<T>(variance_out),
C);
if ((N * sample_size) == 1) {
// Only 1 element in normalization dimension,
// we skip the batch norm calculation, let y = x.
paddle::framework::TensorCopy(x, ctx.GetPlace(), y);
return;
}
switch (data_layout) {
case DataLayout::kNCHW: {
ConstEigenArrayMap<T> x_arr(x.data<T>(), sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) {
saved_mean_e(nc % C) += x_arr.col(nc).sum();
}
saved_mean_e /= N * sample_size;
for (int nc = 0; nc < N * C; ++nc) {
saved_variance_e(nc % C) +=
(x_arr.col(nc) - saved_mean_e(nc % C)).matrix().squaredNorm();
}
saved_variance_e /= N * sample_size;
break;
}
case DataLayout::kNHWC: {
ConstEigenArrayMap<T> x_arr(x.data<T>(), C, N * sample_size);
for (int i = 0; i < N * sample_size; ++i) {
saved_mean_e += x_arr.col(i);
}
saved_mean_e /= N * sample_size;
for (int i = 0; i < N * sample_size; ++i) {
saved_variance_e +=
(x_arr.col(i) - saved_mean_e) * (x_arr.col(i) - saved_mean_e);
}
saved_variance_e /= N * sample_size;
break;
}
default:
PADDLE_THROW(phi::errors::InvalidArgument("Unknown storage order: %s",
data_layout_str));
}
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr =
running_var_arr * momentum + saved_variance_e * (1. - momentum);
}
// use SavedMean and SavedVariance to do normalize
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
if (global_stats) {
ConstEigenVectorArrayMap<T> var_arr(variance.data<T>(), C);
inv_std = (var_arr + epsilon).sqrt().inverse();
} else {
EigenVectorArrayMap<T> saved_inv_std(saved_variance->data<T>(), C);
// inverse SavedVariance first, gradient will use it too.
saved_inv_std = (saved_inv_std + epsilon).inverse().sqrt();
inv_std = saved_inv_std;
}
ConstEigenVectorArrayMap<T> mean_arr(
global_stats ? mean.data<T>() : saved_mean->data<T>(), C);
// ((x - est_mean) * (inv_var) * scale + bias
// formula transform ====>
// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
ConstEigenVectorArrayMap<T> scale_arr(scale.data<T>(), C);
ConstEigenVectorArrayMap<T> bias_arr(bias.data<T>(), C);
Eigen::Array<T, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
bias_arr - mean_arr * inv_std * scale_arr;
switch (data_layout) {
case DataLayout::kNCHW: {
EigenArrayMap<T> y_arr(ctx.template Alloc<T>(y), sample_size, N * C);
ConstEigenArrayMap<T> x_arr(x.data<T>(), sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) {
y_arr.col(nc) = x_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
}
break;
}
case DataLayout::kNHWC: {
EigenArrayMap<T>(ctx.template Alloc<T>(y), C, N * sample_size) =
(ConstEigenArrayMap<T>(x.data<T>(), C, N * sample_size).colwise() *
new_scale)
.colwise() +
new_bias;
break;
}
default:
PADDLE_THROW(phi::errors::InvalidArgument("Unknown storage order: %d",
data_layout));
}
}
} // namespace phi
PD_REGISTER_KERNEL(
batch_norm, CPU, ALL_LAYOUT, phi::BatchNormKernel, float, double) {}
此差异已折叠。
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using Tensor = DenseTensor;
template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const DeviceContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = phi::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(phi::make_ddim(in_dims_vec));
context.template Alloc<T>(transformed_input);
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = phi::vectorize(input->dims());
in_dims_vec[1] = input->dims()[3];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(phi::make_ddim(in_dims_vec));
context.template Alloc<T>(transformed_input);
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = phi::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(phi::make_ddim(in_dims_vec));
context.template Alloc<T>(transformed_input);
}
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const DeviceContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = phi::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(phi::make_ddim(in_dims_vec));
context.template Alloc<T>(transformed_input);
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = phi::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(phi::make_ddim(in_dims_vec));
context.template Alloc<T>(transformed_input);
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = phi::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(phi::make_ddim(in_dims_vec));
context.template Alloc<T>(transformed_input);
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const DeviceContext& context,
const Tensor* input,
Tensor* transformed_input) {
VLOG(5) << "Why am I called?";
int dim = input->dims().size() - 2;
if (dim == 3) {
std::vector<int> axis{0, 4, 1, 2, 3};
funcs::Transpose<DeviceContext, T, 5> trans5;
trans5(context, *input, transformed_input, axis);
} else if (dim == 2) {
std::vector<int> axis{0, 3, 1, 2};
funcs::Transpose<DeviceContext, T, 4> trans4;
trans4(context, *input, transformed_input, axis);
} else if (dim == 1) {
std::vector<int> axis{0, 2, 1};
funcs::Transpose<DeviceContext, T, 3> trans3;
trans3(context, *input, transformed_input, axis);
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const DeviceContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
std::vector<int> axis{0, 2, 3, 4, 1};
funcs::Transpose<DeviceContext, T, 5> trans5;
trans5(context, *input, transformed_input, axis);
} else if (dim == 2) {
std::vector<int> axis{0, 2, 3, 1};
funcs::Transpose<DeviceContext, T, 4> trans4;
trans4(context, *input, transformed_input, axis);
} else if (dim == 1) {
std::vector<int> axis{0, 2, 1};
funcs::Transpose<DeviceContext, T, 3> trans3;
trans3(context, *input, transformed_input, axis);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("batch_norm",
{"X", "Scale", "Bias", "Mean", "Variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"Y",
"MeanOut",
"VarianceOut",
"SavedMean",
"SavedVariance",
"ReserveSpace"});
}
KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"batch_norm_grad",
{GradVarName("Y"),
"X",
"Scale",
"Bias",
"SavedMean",
"SavedVariance",
"ReserveSpace",
"Mean",
"Variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
}
KernelSignature BatchNormGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("batch_norm_grad_grad",
{"DDX",
"DDScale",
"DDBias",
"DY",
"X",
"Scale",
"SavedMean",
"SavedVariance",
"Mean",
"Variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"DX", "DScale", "DDY"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(batch_norm, phi::BatchNormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(batch_norm_grad,
phi::BatchNormGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(batch_norm_grad_grad,
phi::BatchNormGradGradOpArgumentMapping);
......@@ -520,6 +520,7 @@ def predict_static(args, data):
paddle.enable_static()
exe = fluid.Executor(args.place)
# load inference model
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
args.model_save_dir,
......
......@@ -162,6 +162,7 @@ class TestIRPassBase(unittest.TestCase):
for k, v in self.get_strategy().items():
setattr(build_strategy, k, v)
self.check_before_applied(main2, startup2)
apply_build_strategy(main2, startup2, build_strategy,
{"use_cuda": self.use_cuda})
self.check_after_applied(main2, startup2)
......
......@@ -320,7 +320,7 @@ class TestBatchNormOpInference(unittest.TestCase):
def test_check_output(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
......@@ -342,13 +342,13 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
def test_check_output(self):
places = []
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
places.append(place)
for place in places:
for data_format in ["NCHW", "NHWC"]:
#for data_format in ["NCHW", "NHWC"]:
for data_format in ["NCHW"]:
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
......@@ -517,7 +517,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
......@@ -657,7 +657,7 @@ class TestDygraphBatchNormAPIError(unittest.TestCase):
class TestDygraphBatchNormTrainableStats(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [4, 10, 4, 4]
......@@ -678,7 +678,7 @@ class TestDygraphBatchNormTrainableStats(unittest.TestCase):
def test_static(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
exe = fluid.Executor(p)
......@@ -716,4 +716,6 @@ class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase):
if __name__ == '__main__':
import paddle
paddle.enable_static()
unittest.main()
......@@ -28,7 +28,7 @@ import paddle
class TestBatchNorm(unittest.TestCase):
def test_name(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
with fluid.dygraph.guard(p):
......@@ -36,7 +36,7 @@ class TestBatchNorm(unittest.TestCase):
def test_error(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
#paddle.disable_static()
......@@ -83,7 +83,7 @@ class TestBatchNorm(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [4, 10, 4, 4]
......@@ -135,7 +135,7 @@ class TestBatchNorm(unittest.TestCase):
def test_static(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
exe = fluid.Executor(p)
......@@ -177,7 +177,7 @@ class TestBatchNormChannelLast(unittest.TestCase):
else:
paddle.set_default_dtype("float64")
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def tearDown(self):
......@@ -247,7 +247,7 @@ class TestBatchNormChannelLast(unittest.TestCase):
class TestBatchNormUseGlobalStats(unittest.TestCase):
def setUp(self):
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
self.init_test()
......@@ -300,4 +300,6 @@ class TestBatchNormUseGlobalStatsCase3(TestBatchNormUseGlobalStats):
if __name__ == '__main__':
import paddle
paddle.enable_static()
unittest.main()
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle
import paddle.fluid.core as core
......@@ -1001,4 +1002,5 @@ create_test_cudnn_channel_last_fp16_class(
TestWithDilation_AsyPadding, grad_check=False)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -231,4 +231,5 @@ class TestExpandV2API(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -23,6 +23,7 @@ import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import compiler
import paddle.fluid.unique_name as unique_name
import paddle
class TestInplaceANBOpTraining(unittest.TestCase):
......@@ -138,14 +139,14 @@ class TestInplaceANBOpTraining(unittest.TestCase):
outs[0].name if not only_forward else None,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
bn_fetches = exe.run(program=comp_prog1,
bn_fetches = exe.run(program=main,
feed={'input': data},
fetch_list=fetch_name)
fetch_outs.append(bn_fetches)
fetch_names.append(fetch_name)
for bn_val, inplace_abn_val, name1, name2 in zip(*(fetch_outs +
fetch_names)):
for bn_val, inplace_abn_val, name1, name2 in zip(*(
fetch_outs + fetch_names)):
self.assertTrue(
np.allclose(
bn_val, inplace_abn_val, atol=1e-2),
......@@ -156,6 +157,7 @@ class TestInplaceANBOpTraining(unittest.TestCase):
def test_op(self):
use_cudas = [False, True] if core.is_compiled_with_cuda() else [False]
#use_cudas = [False]
for use_cuda in use_cudas:
place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
layouts = ["NCHW", "NHWC"]
......@@ -186,4 +188,5 @@ class TestInplaceANBOpTraining(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -21,6 +21,7 @@ import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
import paddle
from decorator_helper import prog_scope
......@@ -167,4 +168,5 @@ class TestBatchNormDoubleGradCheckCase6(TestBatchNormDoubleGradCheckCase5):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -24,6 +24,7 @@ from simple_nets import init_data, simple_fc_net, fc_with_batchnorm
import seresnext_net
from test_parallel_executor_transformer import transformer, get_feed_data_reader, DeviceType
from fake_reader import fake_imdb_reader
import paddle
def lstm_net(use_feed):
......@@ -309,4 +310,5 @@ class TestProgramPruneBackward(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -507,4 +507,5 @@ class TestReshapeZeroTensor(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册