未验证 提交 f1be9cf1 编写于 作者: Q qipengh 提交者: GitHub

[MLU]add sync_batch_norm op (#44176)

上级 75aaa08a
......@@ -149,6 +149,10 @@ if (WITH_ASCEND_CL)
op_library(sync_batch_norm_op)
endif()
if (WITH_MLU)
op_library(sync_batch_norm_op)
endif()
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
op_library(eye_op DEPS ${OP_HEADER_DEPS})
op_library(recurrent_op DEPS ${OP_HEADER_DEPS})
......
......@@ -259,15 +259,16 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() {
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4(
active_desc_,
act_mode,
CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN,
ceof,
1.0f /*sliced_dim*/,
1.67326319217681884765625 /*selu_alpha*/,
1.05070102214813232421875 /*selu_lambda*/));
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSetActivationDescriptor_v5(active_desc_,
act_mode,
CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN,
ceof,
1.0f /*sliced_dim*/,
1.67326319217681884765625 /*selu_alpha*/,
1.05070102214813232421875 /*selu_lambda*/,
false /*is_elu_mode*/));
}
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
......@@ -278,14 +279,15 @@ MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const float selu_lambda) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSetActivationDescriptor_v4(active_desc_,
cnnlSetActivationDescriptor_v5(active_desc_,
act_mode,
CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN,
ceof,
sliced_dim,
selu_alpha,
selu_lambda));
selu_lambda,
false /*is_elu_mode*/));
}
const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const {
......@@ -2350,6 +2352,36 @@ MLURNNDesc::~MLURNNDesc() {
workspace_size));
}
/* static */ void MLUCnnl::Pow(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
const void* input1,
const cnnlTensorDescriptor_t input2_desc,
const void* input2,
const cnnlTensorDescriptor_t output_desc,
void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetPowWorkspaceSize(
handle, input1_desc, input2_desc, output_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlPow(handle,
prefer,
input1_desc,
input1,
input2_desc,
input2,
workspace_ptr,
workspace_size,
output_desc,
output));
}
/* static */ void MLUCnnl::PowR(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
......@@ -4895,5 +4927,180 @@ MLURNNDesc::~MLURNNDesc() {
grads_image));
}
/* static */ void MLUCnnl::SyncBatchNormStats(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const float eps,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormStats(
handle, x_desc, x, eps, mean_desc, mean, invstd_desc, invstd));
}
/* static */ void MLUCnnl::SyncBatchNormGatherStatsWithCounts(
const ExecutionContext& ctx,
float momentum,
float eps,
const cnnlTensorDescriptor_t mean_all_desc,
const void* mean_all,
const cnnlTensorDescriptor_t invstd_all_desc,
const void* invstd_all,
const cnnlTensorDescriptor_t moving_mean_desc,
void* moving_mean,
const cnnlTensorDescriptor_t moving_var_desc,
void* moving_var,
const cnnlTensorDescriptor_t count_all_desc,
const void* count_all,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSyncBatchNormGatherStatsWithCounts(handle,
mean_all_desc,
mean_all,
invstd_all_desc,
invstd_all,
moving_mean_desc,
moving_mean,
moving_var_desc,
moving_var,
momentum,
eps,
count_all_desc,
count_all,
mean_desc,
mean,
invstd_desc,
invstd));
}
/* static */ void MLUCnnl::SyncBatchNormElemt(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t bias_desc,
const void* bias,
const cnnlTensorDescriptor_t y_desc,
void* y) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormElemt(handle,
x_desc,
x,
mean_desc,
mean,
invstd_desc,
invstd,
weight_desc,
weight,
bias_desc,
bias,
y_desc,
y));
}
/* static */ void MLUCnnl::SyncBatchnormBackwardReduce(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t desc_dz,
const void* dz,
const cnnlTensorDescriptor_t desc_x,
const void* x,
const cnnlTensorDescriptor_t desc_mean,
const void* mean,
const cnnlTensorDescriptor_t desc_invstd,
const void* invstd,
const cnnlTensorDescriptor_t desc_dweight,
void* dweight,
const cnnlTensorDescriptor_t desc_dbias,
void* dbias,
const cnnlTensorDescriptor_t desc_sum_dy,
void* sum_dy,
const cnnlTensorDescriptor_t desc_sum_dy_xmu,
void* sum_dy_xmu,
const bool needs_input_grad0,
const bool needs_input_grad1,
const bool needs_input_grad2) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlSyncBatchnormBackwardReduce(handle,
desc_dz,
dz,
desc_x,
x,
desc_mean,
mean,
desc_invstd,
invstd,
desc_dweight,
dweight,
desc_dbias,
dbias,
desc_sum_dy,
sum_dy,
desc_sum_dy_xmu,
sum_dy_xmu,
needs_input_grad0,
needs_input_grad1,
needs_input_grad2));
}
/* static */ void MLUCnnl::SyncBatchNormBackwardElemt(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t diff_y_desc,
const void* diff_y,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t sum_dy_desc,
const void* sum_dy,
const cnnlTensorDescriptor_t sum_dy_xmu_desc,
const void* sum_dy_xmu,
const cnnlTensorDescriptor_t count_desc,
const void* count,
const cnnlTensorDescriptor_t diff_x_desc,
void* diff_x) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormBackwardElemtV2(handle,
diff_y_desc,
diff_y,
x_desc,
x,
mean_desc,
mean,
invstd_desc,
invstd,
weight_desc,
weight,
sum_dy_desc,
sum_dy,
sum_dy_xmu_desc,
sum_dy_xmu,
count_desc,
count,
diff_x_desc,
diff_x));
}
} // namespace operators
} // namespace paddle
......@@ -1276,6 +1276,15 @@ class MLUCnnl {
const cnnlTensorDescriptor_t output_desc,
void* output);
static void Pow(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
const void* input1,
const cnnlTensorDescriptor_t input2_desc,
const void* input2,
const cnnlTensorDescriptor_t output_desc,
void* output);
static void PowR(const ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t input1_desc,
......@@ -2030,8 +2039,152 @@ class MLUCnnl {
const void* boxes,
const cnnlTensorDescriptor_t grads_image_desc,
void* grads_image);
static void SyncBatchNormStats(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const float eps,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd);
static void SyncBatchNormGatherStatsWithCounts(
const ExecutionContext& ctx,
float momentum,
float eps,
const cnnlTensorDescriptor_t mean_all_desc,
const void* mean_all,
const cnnlTensorDescriptor_t invstd_all_desc,
const void* invstd_all,
const cnnlTensorDescriptor_t moving_mean_desc,
void* moving_mean,
const cnnlTensorDescriptor_t moving_var_desc,
void* moving_var,
const cnnlTensorDescriptor_t count_all_desc,
const void* count_all,
const cnnlTensorDescriptor_t mean_desc,
void* mean,
const cnnlTensorDescriptor_t invstd_desc,
void* invstd);
static void SyncBatchNormElemt(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t bias_desc,
const void* bias,
const cnnlTensorDescriptor_t y_desc,
void* y);
static void SyncBatchnormBackwardReduce(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t desc_dz,
const void* dz,
const cnnlTensorDescriptor_t desc_x,
const void* x,
const cnnlTensorDescriptor_t desc_mean,
const void* mean,
const cnnlTensorDescriptor_t desc_invstd,
const void* invstd,
const cnnlTensorDescriptor_t desc_dweight,
void* dweight,
const cnnlTensorDescriptor_t desc_dbias,
void* dbias,
const cnnlTensorDescriptor_t desc_sum_dy,
void* sum_dy,
const cnnlTensorDescriptor_t desc_sum_dy_xmu,
void* sum_dy_xmu,
const bool needs_input_grad0,
const bool needs_input_grad1,
const bool needs_input_grad2);
static void SyncBatchNormBackwardElemt(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t diff_y_desc,
const void* diff_y,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t mean_desc,
const void* mean,
const cnnlTensorDescriptor_t invstd_desc,
const void* invstd,
const cnnlTensorDescriptor_t weight_desc,
const void* weight,
const cnnlTensorDescriptor_t sum_dy_desc,
const void* sum_dy,
const cnnlTensorDescriptor_t sum_dy_xmu_desc,
const void* sum_dy_xmu,
const cnnlTensorDescriptor_t count_desc,
const void* count,
const cnnlTensorDescriptor_t diff_x_desc,
void* diff_x);
};
const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>>
TransPermMap = {
// trans_mode, (forward_perm, backward_perm)
{"3D_NCHW2NHWC", {{0, 2, 1}, {0, 2, 1}}},
{"4D_NCHW2NHWC", {{0, 2, 3, 1}, {0, 3, 1, 2}}},
{"5D_NCHWD2NDHWC", {{0, 4, 2, 3, 1}, {0, 4, 2, 3, 1}}},
{"5D_NHWDC2NDHWC", {{0, 3, 1, 2, 4}, {0, 2, 3, 4, 1}}}};
inline void SetMLUTransposePerm(const framework::DDim& dims,
const DataLayout& data_layout,
std::vector<int>* forward_perm,
std::vector<int>* backward_perm,
std::vector<int>* out_shape) {
const int dim_size = dims.size();
PADDLE_ENFORCE_EQ((dim_size >= 3) && (dim_size <= 5),
true,
platform::errors::InvalidArgument(
"MLUTransposePerm func only support (dim_size >= 3) && "
"(dim_size <= 5), but now dim_size is %d.",
dim_size));
PADDLE_ENFORCE_EQ(
(data_layout == DataLayout::kNCHW) || (data_layout == DataLayout::kNHWC),
true,
platform::errors::InvalidArgument(
"MLUTransposePerm func only support DataLayout: kNCHW or kNHWC, but "
"now data_layout is %s.",
data_layout));
// case 1: NCHW of Paddle != NHWC of MLU when dims==3,4
// case 2: NHWDC and NCHWD of Paddle != NDHWC of MLU when dims==5
std::string map_key = "";
if (data_layout == DataLayout::kNCHW) {
switch (dim_size) {
case 3:
map_key = "3D_NCHW2NHWC";
break;
case 4:
map_key = "4D_NCHW2NHWC";
break;
case 5:
map_key = "5D_NCHWD2NDHWC";
break;
}
} else if (data_layout == DataLayout::kNHWC && dim_size == 5) {
map_key = "5D_NHWDC2NDHWC";
}
assert(map_key != "");
forward_perm->assign(TransPermMap.at(map_key).first.begin(),
TransPermMap.at(map_key).first.end());
backward_perm->assign(TransPermMap.at(map_key).second.begin(),
TransPermMap.at(map_key).second.end());
auto in_dims = phi::vectorize(dims);
for (size_t i = 0; i < in_dims.size(); i++) {
out_shape->push_back(in_dims[forward_perm->at(i)]);
}
}
template <typename T>
inline void TransposeFromMLUTensor(const ExecutionContext& ctx,
const std::vector<int> perm,
......
此差异已折叠。
......@@ -50,5 +50,7 @@ if(WITH_MLU)
set_tests_properties(test_collective_allgather_api_mlu PROPERTIES TIMEOUT
120)
set_tests_properties(test_c_comm_init_op_mlu PROPERTIES TIMEOUT 120)
set_tests_properties(test_sync_batch_norm_op_mlu_baseline PROPERTIES TIMEOUT
120)
endif()
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
sys.path.append("..")
import signal
import time
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor
_set_use_system_allocator(False)
paddle.enable_static()
class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
def __init__(self):
self.global_ring_id = 0
self.dtype = np.float32
self.N = 8
self.C = 16
self.H = 32
self.W = 32
self.dshape = [self.N, self.C, self.H, self.W]
self.atol = 1e-3
def get_model(self,
main,
startup,
place,
layout,
seed,
sync_bn=False,
only_forward=False):
"""Build program."""
use_cudnn = False
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
data = fluid.layers.data(name='input',
shape=self.dshape,
dtype=self.dtype,
append_batch_size=False)
conv = fluid.layers.conv2d(
input=data,
num_filters=32,
filter_size=1,
param_attr=fluid.ParamAttr(name='conv2d_weight'),
bias_attr=False,
use_cudnn=use_cudnn)
bn = fluid.layers.batch_norm(
conv,
param_attr=fluid.ParamAttr(name='bn_scale'),
bias_attr=fluid.ParamAttr(name='bn_bias'),
moving_mean_name='bn_moving_mean',
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward)
# if self.dtype == np.float16:
# bn = fluid.layers.cast(bn, 'float32')
sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid)
# if not sync_bn:
# out = out / core.get_mlu_device_count()
if not only_forward:
sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
sgd_opt.backward(out)
return [out, conv, bn]
if __name__ == "__main__":
# print('sync_batch_norm_op_mlu.py __main__')
runtime_main(TestSyncBatchNormOpTraining, "identity", 0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册