Unverified commit 561b7278, authored by Sławomir Siwek, committed by GitHub

[PHI] Migrate batch_norm_grad kernel (#48288)

Parent 5f995d3f
@@ -350,7 +350,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
         true,
         platform::errors::InvalidArgument(
             "Using global stats during training is not supported "
-            "in gradient op kernel of batch_norm_mkldnn_op now."));
+            "in oneDNN version of batch_norm_gradient kernel now."));
   }
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormGrad");
...
@@ -113,7 +113,7 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
         true,
         platform::errors::InvalidArgument(
             "Using global stats during training is not supported "
-            "in gradient op kernel of batch_norm_mkldnn_op now."));
+            "in oneDNN version of batch_norm_gradient kernel now."));
   }
   OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad");
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace operators {
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
using paddle::platform::MKLDNNDeviceContext;
template <typename T>
class BatchNormMKLDNNHandler : public phi::funcs::OneDNNHandlerNoCachingT<
T,
dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward> {
public:
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const dnnl::engine mkldnn_engine,
const Tensor *in_x,
const Tensor *scale,
const Tensor *out_grad)
: phi::funcs::OneDNNHandlerNoCachingT<T,
dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward>(
mkldnn_engine, ctx.GetPlace()) {
auto scale_tz = phi::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE_EQ(
scale_tz.size(),
1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
const float epsilon = ctx.Attr<float>("epsilon");
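    // the forward primitive descriptor supplies the mean/variance/weights
    // memory descriptors reused below and is the hint oneDNN requires when
    // building the backward descriptor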
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training,
in_x->mem_desc(),
epsilon,
dnnl::normalization_flags::use_scale_shift);
this->AcquireBackwardPrimitiveDescriptor(
dnnl::prop_kind::backward,
out_grad->mem_desc(),
in_x->mem_desc(),
epsilon,
dnnl::normalization_flags::use_scale_shift);
}
std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(const Tensor *scale,
const Tensor *shift) {
auto scale_tz = phi::vectorize(scale->dims());
const unsigned int C = scale_tz[0];
PADDLE_ENFORCE_EQ(
scale_tz.size(),
1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
auto scaleshift_memory =
this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
// MKLDNN requires a single piece of memory for scale and shift/bias data
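    // (layout: C scale values followed by C shift values, matching the
    // 2*C-element weights_desc of the forward primitive descriptor)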
auto mem_data_handle =
reinterpret_cast<T *>(scaleshift_memory->get_data_handle());
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
return scaleshift_memory;
}
std::shared_ptr<dnnl::memory> AcquireDiffScaleShiftMemory(
T *diff_scaleshift_data) {
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
diff_scaleshift_data);
}
std::shared_ptr<dnnl::memory> AcquireMeanMemory(
const phi::DenseTensor *mean) {
const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->mean_desc(), phi::funcs::to_void_cast<T>(mean_data));
}
std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor *mean) {
T *mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data);
}
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
const phi::DenseTensor *variance) {
const T *variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->variance_desc(),
phi::funcs::to_void_cast<T>(variance_data));
}
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
phi::DenseTensor *variance) {
T *variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data);
}
};
template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<phi::DenseTensor>("X");
const auto *scale = ctx.Input<phi::DenseTensor>("Scale");
const auto *shift = ctx.Input<phi::DenseTensor>("Bias");
const auto *batch_mean = ctx.Input<phi::DenseTensor>("SavedMean");
const auto *batch_variance = ctx.Input<phi::DenseTensor>("SavedVariance");
const auto *diff_y =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
auto *diff_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto *diff_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Scale"));
auto *diff_shift =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Bias"));
BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, scale, diff_y);
// MKLDNN requires a single piece of memory for scale and shift/bias data
const unsigned int C = phi::vectorize(scale->dims())[0];
const size_t scaleshift_size = 2 * C;
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
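    // oneDNN fills this buffer with d_scale followed by d_shift (C values each)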
auto src_memory = handler.AcquireSrcMemory(x);
auto mean_memory = handler.AcquireMeanMemory(batch_mean);
auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
// finally create batch_norm backward primitive
auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
batch_norm_bwd_p->execute(
astream,
{{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_MEAN, *mean_memory},
{DNNL_ARG_VARIANCE, *variance_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory},
{DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
{DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}});
astream.wait();
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
// copy back diff scale/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, C), diff_scale_data);
std::copy(
std::next(it, C), std::end(diff_scaleshift_data), diff_shift_data);
// set memory descriptor of out tensor
diff_x->set_mem_desc(diff_src_memory->get_desc());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(batch_norm_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNGradOpKernel<float>);
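
The scale/shift packing that this kernel relies on, and that the PHI kernel below keeps unchanged, can be sketched in isolation. A minimal stand-alone illustration; the helper names (PackScaleShift, UnpackDiffScaleShift) are hypothetical and not part of the patch:

#include <algorithm>
#include <cstddef>
#include <vector>

// oneDNN's use_scale_shift flag expects one contiguous buffer of 2*C
// elements for channel count C: the C scale values first, then the C
// shift values.
template <typename T>
std::vector<T> PackScaleShift(const T* scale, const T* shift, std::size_t C) {
  std::vector<T> buf(2 * C);
  std::copy(scale, scale + C, buf.begin());      // [0, C)  <- scale
  std::copy(shift, shift + C, buf.begin() + C);  // [C, 2C) <- shift
  return buf;
}

// Inverse step after the backward primitive has run: split the combined
// d_scale/d_shift buffer back into the two gradient tensors.
template <typename T>
void UnpackDiffScaleShift(const T* buf, T* d_scale, T* d_shift,
                          std::size_t C) {
  std::copy(buf, buf + C, d_scale);
  std::copy(buf + C, buf + 2 * C, d_shift);
}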
@@ -27,7 +27,6 @@ register_unity_group(
   bilateral_slice_op.cc)
 register_unity_group(
   cc
-  mkldnn/batch_norm_mkldnn_op.cc
   bilinear_tensor_product_op.cc
   bmm_op.cc
   bpr_loss_op.cc
...
@@ -47,7 +47,7 @@ bool constexpr is_int8() {
 template <typename T>
 constexpr bool is_bfloat16() {
-  return std::is_same<T, phi::dtype::bfloat16>::value;
+  return std::is_same<T, dtype::bfloat16>::value;
 }
 static void AppendActivation(const OneDNNContext& dev_ctx,
@@ -102,7 +102,7 @@ static void AppendActivation(const OneDNNContext& dev_ctx,
   PADDLE_ENFORCE_NE(
       activation_type,
       activation_map.end(),
-      phi::errors::InvalidArgument(
+      errors::InvalidArgument(
           "Activation '%s' not found in oneDNN algorithms mapper",
           fuse_activation));
@@ -810,7 +810,7 @@ class SoftmaxOneDNNHandler
     PADDLE_ENFORCE_EQ(
         x->dims(),
         out->dims(),
-        phi::errors::InvalidArgument(
+        errors::InvalidArgument(
             "The shape of input and output tensor must be identical."));
     const int canonical_axis = funcs::CanonicalAxis(axis, x->dims().size());
@@ -1145,7 +1145,7 @@ class PReluOneDNNHandler
                      const bool is_test)
       : OneDNNHandlerNoCachingT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
             engine, cpu_place) {
-    auto weights_dims = phi::vectorize(weights.dims());
+    auto weights_dims = vectorize(weights.dims());
     // weights must have same size as X only for "element" case
     if (weights.dims().size() != x.dims().size()) {
       auto new_weights_dims = std::vector<int64_t>(x.dims().size(), 1);
@@ -1304,21 +1304,52 @@ class BatchNormOneDNNHandler
         flags);
   }
+  BatchNormOneDNNHandler(const dnnl::engine engine,
+                         Place cpu_place,
+                         const float epsilon,
+                         const DenseTensor* in_x,
+                         const DenseTensor* scale,
+                         const DenseTensor* out_grad)
+      : OneDNNHandlerNoCachingT<T,
+                                dnnl::batch_normalization_forward,
+                                dnnl::batch_normalization_backward>(engine,
+                                                                    cpu_place) {
+    auto scale_tz = vectorize<int64_t>(scale->dims());
+    PADDLE_ENFORCE_EQ(
+        scale_tz.size(),
+        1,
+        errors::InvalidArgument(
+            "Dims of scale tensor must be 1, but received scale's size is %d",
+            scale_tz.size()));
+    this->AcquireForwardPrimitiveDescriptor(
+        dnnl::prop_kind::forward_training,
+        in_x->mem_desc(),
+        epsilon,
+        dnnl::normalization_flags::use_scale_shift);
+    this->AcquireBackwardPrimitiveDescriptor(
+        dnnl::prop_kind::backward,
+        out_grad->mem_desc(),
+        in_x->mem_desc(),
+        epsilon,
+        dnnl::normalization_flags::use_scale_shift);
+  }
   std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
       const DenseTensor* scale, const DenseTensor* shift) {
-    auto scale_tz = phi::vectorize(scale->dims());
+    auto scale_tz = vectorize(scale->dims());
     const unsigned int C = scale_tz[0];
     PADDLE_ENFORCE_EQ(
         scale_tz.size(),
         1,
-        phi::errors::InvalidArgument(
+        errors::InvalidArgument(
             "Dims of scale tensor must be 1, but received scale's size is %d",
             scale_tz.size()));
     auto scaleshift_memory =
         this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
-    // MKLDNN requires a single piece of memory for scale and shift/bias data
+    // oneDNN requires a single piece of memory for scale and shift/bias data
     auto mem_data_handle =
         reinterpret_cast<T*>(scaleshift_memory->get_data_handle());
     std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
@@ -1692,7 +1723,7 @@ static std::vector<int64_t> GetInputStrides(const OneDNNContext& dev_ctx,
   auto& MatrixDimsFromVector =
       input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
-  phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
+  MatDescriptor mat_dim = CreateMatrixDescriptor(
       MatrixDimsFromVector(new_dims), 0, transpose_input);
   std::vector<int64_t> strides;
@@ -1728,8 +1759,7 @@ static bool IsOutputFused(const OneDNNContext& dev_ctx) {
 }
 template <typename XT, typename YT, typename OT>
-class MatmulOneDNNHandler
-    : public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
+class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
  public:
   MatmulOneDNNHandler(const OneDNNContext& dev_ctx,
                       const std::vector<int64_t>& x_org_dims,
@@ -1739,8 +1769,8 @@ class MatmulOneDNNHandler
                       const std::vector<int64_t>& x_strides_override,
                       const std::vector<int64_t>& y_strides_override,
                       bool is_output_fused)
-      : phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(
-            dev_ctx.GetEngine(), dev_ctx.GetPlace()) {
+      : OneDNNHandlerNoCachingT<XT, dnnl::matmul>(dev_ctx.GetEngine(),
+                                                  dev_ctx.GetPlace()) {
     // M X K * K X N
     std::vector<int64_t> x_dims(x_org_dims);
     std::vector<int64_t> y_dims(y_org_dims);
...
@@ -36,7 +36,6 @@ using ConstEigenVectorArrayMap =
 template <typename T, typename Context>
 void BatchNormGradRawKernel(const Context& ctx,
                             const DenseTensor& x,
                             const DenseTensor& scale,
                             const DenseTensor& bias,
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& mean,
const paddle::optional<DenseTensor>& variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const DenseTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool is_inplace,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
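  // the handler builds both oneDNN primitive descriptors from x, scale and
  // dy; the forward descriptor acts as the hint required by the backward one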
funcs::BatchNormOneDNNHandler<T> handler(
dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad);
const unsigned int C = vectorize(scale.dims())[0];
const size_t scaleshift_size = 2 * C;
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
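  // the backward primitive writes d_scale into [0, C) and d_shift into
  // [C, 2*C) of this buffer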
auto src_memory = handler.AcquireSrcMemory(&x);
auto mean_memory = handler.AcquireMeanMemory(&saved_mean);
auto variance_memory = handler.AcquireVarianceMemory(&saved_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(&y_grad);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias);
auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
batch_norm_bwd_p->execute(
astream,
{{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_MEAN, *mean_memory},
{DNNL_ARG_VARIANCE, *variance_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory},
{DNNL_ARG_SCALE_SHIFT, *scaleshift_memory},
{DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}});
astream.wait();
T* diff_scale_data = dev_ctx.template Alloc<T>(scale_grad);
T* diff_shift_data = dev_ctx.template Alloc<T>(bias_grad);
// copy back diff scale/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, C), diff_scale_data);
std::copy(std::next(it, C), std::end(diff_scaleshift_data), diff_shift_data);
// set memory descriptor of out tensor
x_grad->set_mem_desc(diff_src_memory->get_desc());
}
template <typename T, typename Context>
void BatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& mean,
const paddle::optional<DenseTensor>& variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const DenseTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
BatchNormGradRawKernel<T, Context>(dev_ctx,
x,
scale,
bias,
mean,
variance,
saved_mean,
saved_variance,
reserve_space,
y_grad,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
/*is_inplace*/ false,
x_grad,
scale_grad,
bias_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(
batch_norm_grad, OneDNN, ONEDNN, phi::BatchNormGradKernel, float) {}
PD_REGISTER_KERNEL(
batch_norm_grad_raw, OneDNN, ONEDNN, phi::BatchNormGradRawKernel, float) {}
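
For reference, the tensors returned by the backward primitive correspond to the standard batch-normalization gradients in training mode, computed per channel over the N values of that channel (background, not part of the patch). With $\hat{x}_i = (x_i - \mu)/\sqrt{\sigma^2 + \varepsilon}$ and $g_i = \partial L/\partial y_i$:

$$\frac{\partial L}{\partial \gamma} = \sum_i g_i \hat{x}_i, \qquad \frac{\partial L}{\partial \beta} = \sum_i g_i, \qquad \frac{\partial L}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}} \Bigl( g_i - \frac{1}{N}\sum_j g_j - \frac{\hat{x}_i}{N}\sum_j g_j \hat{x}_j \Bigr)$$

The first two quantities arrive through DNNL_ARG_DIFF_SCALE_SHIFT (scale gradient first, shift gradient second) and the last through DNNL_ARG_DIFF_SRC.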