Unverified commit 38e5cd00 authored by gouzil, committed by GitHub

[fluid] decoupling abn op (#53826)

Parent a63fb4c8
@@ -17,17 +17,159 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
namespace paddle {
namespace operators {
-class InplaceABNOp : public paddle::operators::BatchNormOp {
+class InplaceABNOp : public framework::OperatorWithKernel {
public:
-using paddle::operators::BatchNormOp::BatchNormOp;
+using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNorm");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "BatchNorm");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchNorm");
OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "BatchNorm");
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance", "BatchNorm");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm");
bool is_test = ctx->Attrs().Get<bool>("is_test");
bool trainable_stats = ctx->Attrs().Get<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
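// In pure inference mode (is_test with no trainable statistics) the running
// statistics are not updated, so the four statistics outputs checked below
// are only required for training.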
if (!test_mode) {
OP_INOUT_CHECK(
ctx->HasOutput("MeanOut"), "Output", "MeanOut", "BatchNorm");
OP_INOUT_CHECK(
ctx->HasOutput("VarianceOut"), "Output", "VarianceOut", "BatchNorm");
OP_INOUT_CHECK(
ctx->HasOutput("SavedMean"), "Output", "SavedMean", "BatchNorm");
OP_INOUT_CHECK(ctx->HasOutput("SavedVariance"),
"Output",
"SavedVariance",
"BatchNorm");
}
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0],
ctx->Outputs("MeanOut")[0],
platform::errors::InvalidArgument(
"Mean and MeanOut should share the same memory"));
PADDLE_ENFORCE_EQ(
ctx->Inputs("Variance")[0],
ctx->Outputs("VarianceOut")[0],
platform::errors::InvalidArgument(
"Variance and VarianceOut should share the same memory"));
const auto x_dims = ctx->GetInputDim("X");
for (int i = 0; i < x_dims.size(); i++) {
PADDLE_ENFORCE_EQ(
(x_dims[i] == -1) || (x_dims[i] > 0),
true,
platform::errors::InvalidArgument(
"Each dimension of input tensor is expected to be -1 or a "
"positive number, but received %d. Input's shape is [%s].",
x_dims[i],
x_dims));
}
const DataLayout data_layout =
phi::StringToDataLayout(ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
auto mom = ctx->Inputs("MomentumTensor");
PADDLE_ENFORCE_EQ(mom.size(),
1,
platform::errors::InvalidArgument(
"The input tensor MomentumTensor's size must be 1"
"But received: MomentumTensor's size is [%d]",
mom.size()));
}
PADDLE_ENFORCE_GE(x_dims.size(),
2,
platform::errors::InvalidArgument(
"ShapeError: the dimension of input "
"X must greater than or equal to 2. But received: "
"the shape of input "
"X = [%s], the dimension of input X =[%d]",
x_dims,
x_dims.size()));
PADDLE_ENFORCE_LE(x_dims.size(),
5,
platform::errors::InvalidArgument(
"ShapeError: the dimension of input X "
"must smaller than or equal to 5. But received: the "
"shape of input X "
"= [%s], the dimension of input X = [%d]",
x_dims,
x_dims.size()));
VLOG(4) << ctx->IsRunMKLDNNKernel();
VLOG(4) << data_layout;
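// Channel count C sits at axis 1 for NCHW; the oneDNN path treats dims as
// NCHW here as well, otherwise NHWC puts channels on the last axis.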
const int64_t C = ((ctx->IsRunMKLDNNKernel() == true) ||
(data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
auto scale_dim = ctx->GetInputDim("Scale");
auto bias_dim = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(
scale_dim.size(),
1UL,
platform::errors::InvalidArgument(
"ShapeError: the dimension of scale must equal to 1."
"But received: the shape of scale is [%s], the dimension "
"of scale is [%d]",
scale_dim,
scale_dim.size()));
PADDLE_ENFORCE_EQ(
bias_dim.size(),
1UL,
platform::errors::InvalidArgument(
"ShapeError: the dimension of bias must equal to 1."
"But received: the shape of bias is [%s],the dimension "
"of bias is [%d]",
bias_dim,
bias_dim.size()));
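// At compile time the scale/bias shapes may still be undetermined (their
// element count can be non-positive), so the per-channel size checks below
// are deferred until the shapes are concrete at runtime.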
bool check = true;
if ((!ctx->IsRuntime()) &&
(phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0],
C,
platform::errors::InvalidArgument(
"ShapeError: the shape of scale must equal to [%d]"
"But received: the shape of scale is [%d]",
C,
scale_dim[0]));
PADDLE_ENFORCE_EQ(bias_dim[0],
C,
platform::errors::InvalidArgument(
"ShapeError: the shape of bias must equal to [%d]"
"But received: the shape of bias is [%d]",
C,
bias_dim[0]));
}
ctx->SetOutputDim("Y", x_dims);
ctx->ShareLoD("X", "Y");
VLOG(4) << x_dims;
ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C});
if (!test_mode) {
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
}
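// ReserveSpace's size is only known when the kernel runs, so its dimension
// is left dynamic (-1).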
if (ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
}
protected:
phi::KernelKey GetExpectedKernelType(
@@ -65,10 +207,9 @@ class InplaceABNOp : public paddle::operators::BatchNormOp {
}
};
-class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
+class InplaceABNGradOp : public framework::OperatorWithKernel {
public:
-using paddle::operators::BatchNormGradOp::BatchNormGradOp;
+using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
// check input
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad");
@@ -155,10 +296,82 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
}
};
-class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker {
+class InplaceABNOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
-BatchNormOpMaker::Make();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "")
.SetDefault(1e-5)
.AddCustomChecker([](const float& epsilon) {
PADDLE_ENFORCE_GE(
epsilon,
0.0f,
platform::errors::InvalidArgument(
"'epsilon' should be greater or equal than 0.0."));
PADDLE_ENFORCE_LE(
epsilon,
0.001f,
platform::errors::InvalidArgument(
"'epsilon' should be less or equal than 0.001."));
});
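// e.g. the default epsilon of 1e-5 passes both checks, while 0.01 would be
// rejected by the 0.001 upper bound.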
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddInput("X", "The input tensor");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size C "
"that is applied to the output");
AddInput("Bias",
"Bias is a 1-dimensional tensor of size C "
"that is applied to the output");
AddInput("Mean",
"The global mean (for training) or "
"estimated mean (for testing)");
AddInput("Variance",
"The global variance (for training) "
"or estimated Variance (for testing)");
AddInput(
"MomentumTensor",
"(phi::DenseTensor<float32>, optional) If provided, batch_norm will "
"use this as momentum, this has a higher priority than "
"attr(momentum), the shape of this tensor MUST BE [1].")
.AsDispensable();
AddOutput("Y", "result after normalization");
AddOutput("MeanOut",
"Share memory with Mean. "
"Store the global mean when training");
AddOutput("VarianceOut",
"Share memory with Variance. "
"Store the global Variance when training");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("ReserveSpace",
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel")
.AsDispensable()
.AsExtra();
AddAttr<bool>("use_global_stats",
"(bool, default false) Whether to use global mean and "
"variance. In inference or test mode, set use_global_stats "
"to true or is_test true. the behavior is equivalent. "
"In train mode, when setting use_global_stats True, the "
"global mean and variance are also used during train time, "
"the BN acts as scaling and shiffting.")
.SetDefault(false);
AddAttr<bool>(
"trainable_statistics",
"(bool, default false) Whether to calculate mean and variance "
"in test mode. If setting true in test mode, mean and variace "
"will be calculated by current batch statistics.")
.SetDefault(false);
AddAttr<std::string>(
"activation",
"(enum string, default identity, can be identity|elu|leaky-relu) "
@@ -174,6 +387,17 @@ class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker {
"(bool, default false) Whether use synchronize batch "
"normalization.")
.SetDefault(false);
AddComment(R"DOC(
Batch Normalization.
Batch Norm has been implemented as discussed in the paper:
https://arxiv.org/pdf/1502.03167.pdf
Can be used as a normalizer function for conv2d and fully_connected operations.
The required data format for this layer is one of the following:
1. NHWC `[batch, in_height, in_width, in_channels]`
2. NCHW `[batch, in_channels, in_height, in_width]`
)DOC");
}
};
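// For reference, a sketch of the per-channel transform described in the DOC
// string above (the standard batch-norm formula, not this file's kernel code):
//   y = scale * (x - mean) / sqrt(variance + epsilon) + bias
// followed by the optional in-place activation (identity, elu, or leaky-relu).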
@@ -358,6 +582,16 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
}
};
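// Forwards X's dtype and variable type to Y during var-type inference.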
class InplaceABNOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
} // namespace operators
} // namespace paddle
@@ -367,7 +601,7 @@ DECLARE_INPLACE_OP_INFERER(InplaceAbnOpInplaceInferer, {"X", "Y"});
REGISTER_OPERATOR(inplace_abn,
ops::InplaceABNOp,
ops::InplaceABNOpMaker,
-ops::BatchNormOpInferVarType,
+ops::InplaceABNOpInferVarType,
ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>,
InplaceAbnOpInplaceInferer)
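// InplaceAbnOpInplaceInferer (declared via DECLARE_INPLACE_OP_INFERER above)
// pairs X with Y, allowing the output to reuse the input's buffer in place.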
......