未验证 提交 f1be9cf1 编写于 作者: Q qipengh 提交者: GitHub

[MLU]add sync_batch_norm op (#44176)

上级 75aaa08a
......@@ -149,6 +149,10 @@ if (WITH_ASCEND_CL)
op_library(sync_batch_norm_op)
endif()
if (WITH_MLU)
op_library(sync_batch_norm_op)
endif()
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
op_library(eye_op DEPS ${OP_HEADER_DEPS})
op_library(recurrent_op DEPS ${OP_HEADER_DEPS})
......
......@@ -259,15 +259,16 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() {
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4(
active_desc_,
act_mode,
CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN,
ceof,
1.0f /*sliced_dim*/,
1.67326319217681884765625 /*selu_alpha*/,
1.05070102214813232421875 /*selu_lambda*/));
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSetActivationDescriptor_v5(active_desc_,
act_mode,
CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN,
ceof,
1.0f /*sliced_dim*/,
1.67326319217681884765625 /*selu_alpha*/,
1.05070102214813232421875 /*selu_lambda*/,
false /*is_elu_mode*/));
}
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
......@@ -278,14 +279,15 @@ MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const float selu_lambda) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSetActivationDescriptor_v4(active_desc_,
cnnlSetActivationDescriptor_v5(active_desc_,
act_mode,
CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN,
ceof,
sliced_dim,
selu_alpha,
selu_lambda));
selu_lambda,
false /*is_elu_mode*/));
}
const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const {
......@@ -2350,6 +2352,36 @@ MLURNNDesc::~MLURNNDesc() {
workspace_size));
}
/* static */ void MLUCnnl::Pow(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
const void* input1,
const cnnlTensorDescriptor_t input2_desc,
const void* input2,
const cnnlTensorDescriptor_t output_desc,
void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetPowWorkspaceSize(
handle, input1_desc, input2_desc, output_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlPow(handle,
prefer,
input1_desc,
input1,
input2_desc,
input2,
workspace_ptr,
workspace_size,
output_desc,
output));
}
/* static */ void MLUCnnl::PowR(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
......@@ -4895,5 +4927,180 @@ MLURNNDesc::~MLURNNDesc() {
grads_image));
}
/* static */ void MLUCnnl::SyncBatchNormStats(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const float eps,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormStats(
handle, x_desc, x, eps, mean_desc, mean, invstd_desc, invstd));
}
/* static */ void MLUCnnl::SyncBatchNormGatherStatsWithCounts(
const ExecutionContext& ctx,
float momentum,
float eps,
const cnnlTensorDescriptor_t mean_all_desc,
const void* mean_all,
const cnnlTensorDescriptor_t invstd_all_desc,
const void* invstd_all,
const cnnlTensorDescriptor_t moving_mean_desc,
void* moving_mean,
const cnnlTensorDescriptor_t moving_var_desc,
void* moving_var,
const cnnlTensorDescriptor_t count_all_desc,
const void* count_all,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSyncBatchNormGatherStatsWithCounts(handle,
mean_all_desc,
mean_all,
invstd_all_desc,
invstd_all,
moving_mean_desc,
moving_mean,
moving_var_desc,
moving_var,
momentum,
eps,
count_all_desc,
count_all,
mean_desc,
mean,
invstd_desc,
invstd));
}
/* static */ void MLUCnnl::SyncBatchNormElemt(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t bias_desc,
const void* bias,
const cnnlTensorDescriptor_t y_desc,
void* y) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormElemt(handle,
x_desc,
x,
mean_desc,
mean,
invstd_desc,
invstd,
weight_desc,
weight,
bias_desc,
bias,
y_desc,
y));
}
/* static */ void MLUCnnl::SyncBatchnormBackwardReduce(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t desc_dz,
const void* dz,
const cnnlTensorDescriptor_t desc_x,
const void* x,
const cnnlTensorDescriptor_t desc_mean,
const void* mean,
const cnnlTensorDescriptor_t desc_invstd,
const void* invstd,
const cnnlTensorDescriptor_t desc_dweight,
void* dweight,
const cnnlTensorDescriptor_t desc_dbias,
void* dbias,
const cnnlTensorDescriptor_t desc_sum_dy,
void* sum_dy,
const cnnlTensorDescriptor_t desc_sum_dy_xmu,
void* sum_dy_xmu,
const bool needs_input_grad0,
const bool needs_input_grad1,
const bool needs_input_grad2) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSyncBatchnormBackwardReduce(handle,
desc_dz,
dz,
desc_x,
x,
desc_mean,
mean,
desc_invstd,
invstd,
desc_dweight,
dweight,
desc_dbias,
dbias,
desc_sum_dy,
sum_dy,
desc_sum_dy_xmu,
sum_dy_xmu,
needs_input_grad0,
needs_input_grad1,
needs_input_grad2));
}
/* static */ void MLUCnnl::SyncBatchNormBackwardElemt(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t diff_y_desc,
const void* diff_y,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t sum_dy_desc,
const void* sum_dy,
const cnnlTensorDescriptor_t sum_dy_xmu_desc,
const void* sum_dy_xmu,
const cnnlTensorDescriptor_t count_desc,
const void* count,
const cnnlTensorDescriptor_t diff_x_desc,
void* diff_x) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormBackwardElemtV2(handle,
diff_y_desc,
diff_y,
x_desc,
x,
mean_desc,
mean,
invstd_desc,
invstd,
weight_desc,
weight,
sum_dy_desc,
sum_dy,
sum_dy_xmu_desc,
sum_dy_xmu,
count_desc,
count,
diff_x_desc,
diff_x));
}
} // namespace operators
} // namespace paddle
......@@ -1276,6 +1276,15 @@ class MLUCnnl {
const cnnlTensorDescriptor_t output_desc,
void* output);
static void Pow(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
const void* input1,
const cnnlTensorDescriptor_t input2_desc,
const void* input2,
const cnnlTensorDescriptor_t output_desc,
void* output);
static void PowR(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
......@@ -2030,8 +2039,152 @@ class MLUCnnl {
const void* boxes,
const cnnlTensorDescriptor_t grads_image_desc,
void* grads_image);
static void SyncBatchNormStats(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const float eps,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd);
static void SyncBatchNormGatherStatsWithCounts(
const ExecutionContext& ctx,
float momentum,
float eps,
const cnnlTensorDescriptor_t mean_all_desc,
const void* mean_all,
const cnnlTensorDescriptor_t invstd_all_desc,
const void* invstd_all,
const cnnlTensorDescriptor_t moving_mean_desc,
void* moving_mean,
const cnnlTensorDescriptor_t moving_var_desc,
void* moving_var,
const cnnlTensorDescriptor_t count_all_desc,
const void* count_all,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd);
static void SyncBatchNormElemt(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t bias_desc,
const void* bias,
const cnnlTensorDescriptor_t y_desc,
void* y);
static void SyncBatchnormBackwardReduce(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t desc_dz,
const void* dz,
const cnnlTensorDescriptor_t desc_x,
const void* x,
const cnnlTensorDescriptor_t desc_mean,
const void* mean,
const cnnlTensorDescriptor_t desc_invstd,
const void* invstd,
const cnnlTensorDescriptor_t desc_dweight,
void* dweight,
const cnnlTensorDescriptor_t desc_dbias,
void* dbias,
const cnnlTensorDescriptor_t desc_sum_dy,
void* sum_dy,
const cnnlTensorDescriptor_t desc_sum_dy_xmu,
void* sum_dy_xmu,
const bool needs_input_grad0,
const bool needs_input_grad1,
const bool needs_input_grad2);
static void SyncBatchNormBackwardElemt(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t diff_y_desc,
const void* diff_y,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t sum_dy_desc,
const void* sum_dy,
const cnnlTensorDescriptor_t sum_dy_xmu_desc,
const void* sum_dy_xmu,
const cnnlTensorDescriptor_t count_desc,
const void* count,
const cnnlTensorDescriptor_t diff_x_desc,
void* diff_x);
};
const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>>
TransPermMap = {
// trans_mode, (forward_perm, backward_perm)
{"3D_NCHW2NHWC", {{0, 2, 1}, {0, 2, 1}}},
{"4D_NCHW2NHWC", {{0, 2, 3, 1}, {0, 3, 1, 2}}},
{"5D_NCHWD2NDHWC", {{0, 4, 2, 3, 1}, {0, 4, 2, 3, 1}}},
{"5D_NHWDC2NDHWC", {{0, 3, 1, 2, 4}, {0, 2, 3, 4, 1}}}};
inline void SetMLUTransposePerm(const framework::DDim& dims,
const DataLayout& data_layout,
std::vector<int>* forward_perm,
std::vector<int>* backward_perm,
std::vector<int>* out_shape) {
const int dim_size = dims.size();
PADDLE_ENFORCE_EQ((dim_size >= 3) && (dim_size <= 5),
true,
platform::errors::InvalidArgument(
"MLUTransposePerm func only support (dim_size >= 3) && "
"(dim_size <= 5), but now dim_size is %d.",
dim_size));
PADDLE_ENFORCE_EQ(
(data_layout == DataLayout::kNCHW) || (data_layout == DataLayout::kNHWC),
true,
platform::errors::InvalidArgument(
"MLUTransposePerm func only support DataLayout: kNCHW or kNHWC, but "
"now data_layout is %s.",
data_layout));
// case 1: NCHW of Paddle != NHWC of MLU when dims==3,4
// case 2: NHWDC and NCHWD of Paddle != NDHWC of MLU when dims==5
std::string map_key = "";
if (data_layout == DataLayout::kNCHW) {
switch (dim_size) {
case 3:
map_key = "3D_NCHW2NHWC";
break;
case 4:
map_key = "4D_NCHW2NHWC";
break;
case 5:
map_key = "5D_NCHWD2NDHWC";
break;
}
} else if (data_layout == DataLayout::kNHWC && dim_size == 5) {
map_key = "5D_NHWDC2NDHWC";
}
assert(map_key != "");
forward_perm->assign(TransPermMap.at(map_key).first.begin(),
TransPermMap.at(map_key).first.end());
backward_perm->assign(TransPermMap.at(map_key).second.begin(),
TransPermMap.at(map_key).second.end());
auto in_dims = phi::vectorize(dims);
for (size_t i = 0; i < in_dims.size(); i++) {
out_shape->push_back(in_dims[forward_perm->at(i)]);
}
}
template <typename T>
inline void TransposeFromMLUTensor(const ExecutionContext& ctx,
const std::vector<int> perm,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the Licnse. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/collective_helper.h"
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/device/mlu/cncl_helper.h"
#endif
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
#define GET_LAYOUT_OFFSET 2
using Tensor = framework::Tensor;
static std::vector<cnnlTensorLayout_t> supported_input_layout = {
CNNL_LAYOUT_NC, CNNL_LAYOUT_NLC, CNNL_LAYOUT_NHWC, CNNL_LAYOUT_NDHWC};
template <typename T>
class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
float epsilon = ctx.Attr<float>("epsilon");
float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str);
PADDLE_ENFORCE_EQ(use_global_stats,
false,
platform::errors::InvalidArgument(
"sync_batch_norm doesn't support "
"to set use_global_stats True. Please use batch_norm "
"in this case."));
const auto *x = ctx.Input<Tensor>("X");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *variance = ctx.Input<Tensor>("Variance");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
auto *y = ctx.Output<Tensor>("Y");
const auto &x_dims = x->dims();
PADDLE_ENFORCE_GE(x_dims.size(),
2,
platform::errors::InvalidArgument(
"The Input dim size should be larger than 1."));
PADDLE_ENFORCE_LE(x_dims.size(),
5,
platform::errors::InvalidArgument(
"The Input dim size should be less than 6."));
int N, C, H, W, D;
ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<MPDType>(ctx.GetPlace());
variance_out->mutable_data<MPDType>(ctx.GetPlace());
saved_mean->mutable_data<MPDType>(ctx.GetPlace());
saved_variance->mutable_data<MPDType>(ctx.GetPlace());
Tensor trans_x;
Tensor trans_y;
std::vector<int> forward_perm;
std::vector<int> backward_perm;
std::vector<int> trans_shape;
const bool need_transpose =
((layout == DataLayout::kNCHW && x_dims.size() != 2) ||
x_dims.size() == 5);
if (need_transpose) {
SetMLUTransposePerm(
x_dims, layout, &forward_perm, &backward_perm, &trans_shape);
trans_x.mutable_data<T>(phi::make_ddim(trans_shape), ctx.GetPlace());
trans_y.mutable_data<T>(phi::make_ddim(trans_shape), ctx.GetPlace());
MLUCnnlTensorDesc desc_x(*x);
MLUCnnlTensorDesc desc_trans_x(
trans_shape.size(), trans_shape.data(), ToCnnlDataType(x->dtype()));
MLUCnnl::Transpose(ctx,
forward_perm,
x_dims.size(),
desc_x.get(),
GetBasePtr(x),
desc_trans_x.get(),
GetBasePtr(&trans_x));
} else {
trans_x = *x;
trans_y = *y;
}
MLUCnnlTensorDesc desc_trans(
trans_x,
supported_input_layout[x_dims.size() - GET_LAYOUT_OFFSET],
ToCnnlDataType<T>());
bool test_mode = is_test && (!trainable_stats);
if (test_mode) { // inference
MLUCnnlTensorDesc desc_weight_bias_mean_var(*bias);
MLUCnnl::FusedBatchNorm(ctx,
false /*is_training*/,
desc_trans.get(),
GetBasePtr(&trans_x),
desc_weight_bias_mean_var.get(),
GetBasePtr(scale),
GetBasePtr(bias),
GetBasePtr(mean),
GetBasePtr(variance),
epsilon,
momentum,
desc_trans.get(),
GetBasePtr(&trans_y),
nullptr,
nullptr,
nullptr,
nullptr);
} else { // training
if (ctx.HasInput("MomentumTensor")) {
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
Tensor mom_cpu;
paddle::framework::TensorCopySync(
*mom_tensor, platform::CPUPlace(), &mom_cpu);
momentum = mom_cpu.data<float>()[0];
}
Tensor local_mean, local_var;
local_mean.mutable_data<MPDType>(mean->dims(), ctx.GetPlace());
local_var.mutable_data<MPDType>(variance->dims(), ctx.GetPlace());
MLUCnnlTensorDesc desc_mean_var(*mean_out);
// cacl local_mean and local_var
MLUCnnl::SyncBatchNormStats(ctx,
desc_trans.get(),
GetBasePtr(&trans_x),
epsilon,
desc_mean_var.get(),
GetBasePtr(&local_mean),
desc_mean_var.get(),
GetBasePtr(&local_var));
Tensor input_count;
input_count.mutable_data<T>(phi::make_ddim({1}), ctx.GetPlace());
FillMLUTensorWithHostValue<T>(
ctx, static_cast<T>(x->numel() / C), &input_count);
Tensor count_all;
Tensor mean_all(mean->dtype());
Tensor invstd_all(variance->dtype());
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.cncl_comm();
if (comm) {
auto *comm = paddle::platform::CNCLCommContext::Instance()
.Get(0, ctx.GetPlace())
->comm();
int count;
PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCommCount(&count, comm));
count_all.mutable_data<T>(phi::make_ddim({count}), ctx.GetPlace());
cnclDataType_t dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(count_all.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&input_count),
GetBasePtr(&count_all),
1,
dtype,
comm,
stream));
mean_all.mutable_data<MPDType>(phi::make_ddim({count, mean->numel()}),
ctx.GetPlace());
invstd_all.mutable_data<MPDType>(
phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
auto cncl_dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(mean_all.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_mean),
GetBasePtr(&mean_all),
local_mean.numel(),
cncl_dtype,
comm,
stream));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_var),
GetBasePtr(&invstd_all),
local_var.numel(),
cncl_dtype,
comm,
stream));
} else {
count_all = input_count;
mean_all.ShareDataWith(local_mean);
invstd_all.ShareDataWith(local_var);
mean_all.Resize(phi::make_ddim({1, local_mean.numel()}));
invstd_all.Resize(phi::make_ddim({1, local_var.numel()}));
}
MLUCnnlTensorDesc desc_all_mean_invstd(
invstd_all, CNNL_LAYOUT_NC, ToCnnlDataType<MPDType>());
MLUCnnlTensorDesc desc_moving_mean_var(*mean_out);
MLUCnnlTensorDesc desc_saved_mean_var(*saved_mean);
MLUCnnlTensorDesc desc_count_all(count_all);
MLUCnnl::SyncBatchNormGatherStatsWithCounts(ctx,
momentum,
epsilon,
desc_all_mean_invstd.get(),
GetBasePtr(&mean_all),
desc_all_mean_invstd.get(),
GetBasePtr(&invstd_all),
desc_moving_mean_var.get(),
GetBasePtr(mean_out),
desc_moving_mean_var.get(),
GetBasePtr(variance_out),
desc_count_all.get(),
GetBasePtr(&count_all),
desc_saved_mean_var.get(),
GetBasePtr(saved_mean),
desc_saved_mean_var.get(),
GetBasePtr(saved_variance));
MLUCnnlTensorDesc desc_other_param(*saved_mean);
MLUCnnl::SyncBatchNormElemt(ctx,
desc_trans.get(),
GetBasePtr(&trans_x),
desc_other_param.get(),
GetBasePtr(saved_mean),
desc_other_param.get(),
GetBasePtr(saved_variance),
desc_other_param.get(),
GetBasePtr(scale),
desc_other_param.get(),
GetBasePtr(bias),
desc_trans.get(),
GetBasePtr(&trans_y));
}
if (need_transpose) {
MLUCnnlTensorDesc desc_y(*y);
MLUCnnlTensorDesc desc_trans_y(trans_y);
MLUCnnl::Transpose(ctx,
backward_perm,
trans_y.dims().size(),
desc_trans_y.get(),
GetBasePtr(&trans_y),
desc_y.get(),
GetBasePtr(y));
}
}
};
template <typename T>
class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str);
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_inv_var = ctx.Input<Tensor>("SavedVariance");
const Tensor *x;
if (ctx.HasInput("Y")) {
PADDLE_ENFORCE_EQ(true,
false,
platform::errors::InvalidArgument(
"sync_batch_norm_grad doesn't support input Y"));
} else {
x = ctx.Input<Tensor>("X");
}
const auto &x_dims = x->dims();
PADDLE_ENFORCE_GE(x_dims.size(),
2,
platform::errors::InvalidArgument(
"The Input X dim size should be larger than 1."));
PADDLE_ENFORCE_LE(x_dims.size(),
5,
platform::errors::InvalidArgument(
"The Input X dim size should be less than 6."));
int N, C, H, W, D;
ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
PADDLE_ENFORCE_EQ(scale->dims()[0],
C,
platform::errors::InvalidArgument(
"Expected first dim for input parameter(scale) of "
"OP(sync_batch_norm) be (%d), but given (%d).",
C,
scale->dims()[0]));
d_x->mutable_data<T>(ctx.GetPlace());
if (d_scale && d_bias) {
d_scale->mutable_data<MPDType>(ctx.GetPlace());
d_bias->mutable_data<MPDType>(ctx.GetPlace());
}
PADDLE_ENFORCE_EQ(scale->dims().size(),
1UL,
platform::errors::InvalidArgument(
"Expected rank for input parameter(scale) of "
"OP(sync_batch_norm) be (1), but given (%d).",
scale->dims().size()));
Tensor trans_x;
Tensor trans_dy;
Tensor trans_dx;
std::vector<int> forward_perm;
std::vector<int> backward_perm;
std::vector<int> trans_shape;
const bool need_transpose =
((layout == DataLayout::kNCHW && x_dims.size() != 2) ||
x_dims.size() == 5);
if (need_transpose) {
SetMLUTransposePerm(
x_dims, layout, &forward_perm, &backward_perm, &trans_shape);
trans_x.mutable_data<T>(phi::make_ddim(trans_shape), ctx.GetPlace());
trans_dy.mutable_data<T>(phi::make_ddim(trans_shape), ctx.GetPlace());
trans_dx.mutable_data<T>(phi::make_ddim(trans_shape), ctx.GetPlace());
MLUCnnlTensorDesc desc_x(*x);
MLUCnnlTensorDesc desc_trans_x(
trans_shape.size(), trans_shape.data(), ToCnnlDataType(x->dtype()));
MLUCnnl::Transpose(ctx,
forward_perm,
x_dims.size(),
desc_x.get(),
GetBasePtr(x),
desc_trans_x.get(),
GetBasePtr(&trans_x));
MLUCnnl::Transpose(ctx,
forward_perm,
x_dims.size(),
desc_x.get(),
GetBasePtr(d_y),
desc_trans_x.get(),
GetBasePtr(&trans_dy));
} else {
trans_x = *x;
trans_dy = *d_y;
trans_dx = *d_x;
}
MLUCnnlTensorDesc desc_trans(
trans_x,
supported_input_layout[x_dims.size() - GET_LAYOUT_OFFSET],
ToCnnlDataType<T>());
Tensor sum_dy, sum_dy_xmu;
sum_dy.mutable_data<MPDType>(bias->dims(), ctx.GetPlace());
sum_dy_xmu.mutable_data<MPDType>(bias->dims(), ctx.GetPlace());
MLUCnnlTensorDesc desc_other_param(*bias);
MLUCnnl::SyncBatchnormBackwardReduce(
ctx,
desc_trans.get(),
GetBasePtr(&trans_dy),
desc_trans.get(),
GetBasePtr(&trans_x),
desc_other_param.get(),
GetBasePtr(saved_mean),
desc_other_param.get(),
GetBasePtr(saved_inv_var),
d_scale ? desc_other_param.get() : nullptr,
d_scale ? GetBasePtr(d_scale) : nullptr,
d_bias ? desc_other_param.get() : nullptr,
d_bias ? GetBasePtr(d_bias) : nullptr,
desc_other_param.get(),
GetBasePtr(&sum_dy),
desc_other_param.get(),
GetBasePtr(&sum_dy_xmu),
true /*compute sum_dy, sum_dy_xmu*/,
d_scale ? true : false /*compute d_scale*/,
d_bias ? true : false /*compute d_bias*/);
Tensor numel_count;
numel_count.mutable_data<int32_t>(phi::make_ddim({1}), ctx.GetPlace());
FillMLUTensorWithHostValue<int32_t>(
ctx, static_cast<int32_t>(x->numel() / C), &numel_count);
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.cncl_comm();
if (comm) {
auto *comm = paddle::platform::CNCLCommContext::Instance()
.Get(0, ctx.GetPlace())
->comm();
cnclDataType_t dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(numel_count.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&numel_count),
GetBasePtr(&numel_count),
1,
dtype,
cnclSum,
comm,
stream));
auto cncl_dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(sum_dy.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy),
GetBasePtr(&sum_dy),
sum_dy.numel(),
cncl_dtype,
cnclSum,
comm,
stream));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy_xmu),
GetBasePtr(&sum_dy_xmu),
sum_dy_xmu.numel(),
cncl_dtype,
cnclSum,
comm,
stream));
}
if (d_x) {
MLUCnnlTensorDesc desc_count(numel_count);
MLUCnnl::SyncBatchNormBackwardElemt(ctx,
desc_trans.get(),
GetBasePtr(&trans_dy),
desc_trans.get(),
GetBasePtr(&trans_x),
desc_other_param.get(),
GetBasePtr(saved_mean),
desc_other_param.get(),
GetBasePtr(saved_inv_var),
desc_other_param.get(),
GetBasePtr(scale),
desc_other_param.get(),
GetBasePtr(&sum_dy),
desc_other_param.get(),
GetBasePtr(&sum_dy_xmu),
desc_count.get(),
GetBasePtr(&numel_count),
desc_trans.get(),
GetBasePtr(&trans_dx));
if (need_transpose) {
MLUCnnlTensorDesc desc_dx(*d_x);
MLUCnnlTensorDesc desc_trans_dx(trans_dx);
MLUCnnl::Transpose(ctx,
backward_perm,
trans_dx.dims().size(),
desc_trans_dx.get(),
GetBasePtr(&trans_dx),
desc_dx.get(),
GetBasePtr(d_x));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(sync_batch_norm,
ops::SyncBatchNormMLUKernel<float>,
ops::SyncBatchNormMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(sync_batch_norm_grad,
ops::SyncBatchNormMLUGradKernel<float>,
ops::SyncBatchNormMLUGradKernel<plat::float16>);
......@@ -50,5 +50,7 @@ if(WITH_MLU)
set_tests_properties(test_collective_allgather_api_mlu PROPERTIES TIMEOUT
120)
set_tests_properties(test_c_comm_init_op_mlu PROPERTIES TIMEOUT 120)
set_tests_properties(test_sync_batch_norm_op_mlu_baseline PROPERTIES TIMEOUT
120)
endif()
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
sys.path.append("..")
import signal
import time
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor
_set_use_system_allocator(False)
paddle.enable_static()
class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
def __init__(self):
self.global_ring_id = 0
self.dtype = np.float32
self.N = 8
self.C = 16
self.H = 32
self.W = 32
self.dshape = [self.N, self.C, self.H, self.W]
self.atol = 1e-3
def get_model(self,
main,
startup,
place,
layout,
seed,
sync_bn=False,
only_forward=False):
"""Build program."""
use_cudnn = False
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
data = fluid.layers.data(name='input',
shape=self.dshape,
dtype=self.dtype,
append_batch_size=False)
conv = fluid.layers.conv2d(
input=data,
num_filters=32,
filter_size=1,
param_attr=fluid.ParamAttr(name='conv2d_weight'),
bias_attr=False,
use_cudnn=use_cudnn)
bn = fluid.layers.batch_norm(
conv,
param_attr=fluid.ParamAttr(name='bn_scale'),
bias_attr=fluid.ParamAttr(name='bn_bias'),
moving_mean_name='bn_moving_mean',
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward)
# if self.dtype == np.float16:
# bn = fluid.layers.cast(bn, 'float32')
sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid)
# if not sync_bn:
# out = out / core.get_mlu_device_count()
if not only_forward:
sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
sgd_opt.backward(out)
return [out, conv, bn]
if __name__ == "__main__":
# print('sync_batch_norm_op_mlu.py __main__')
runtime_main(TestSyncBatchNormOpTraining, "identity", 0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册