Unverified · Commit f49b0cb9 · Authored by Q qipengh · Committed by GitHub

[MLU] fix sync_batch_norm and concat_grad op (#44586)

Parent 84d595fa
......@@ -121,6 +121,7 @@ class ConcatGradMLUKernel : public framework::OpKernel<T> {
out_grad->dims().size()));
// get output tensors whose names are not kEmptyVarName
std::vector<void*> outputs_vec;
std::vector<Tensor> tmp_outputs_vec;
std::vector<MLUCnnlTensorDesc> output_descs;
std::vector<cnnlTensorDescriptor_t> descs_vec;
for (size_t j = 0; j < outs.size(); ++j) {
......@@ -128,11 +129,15 @@ class ConcatGradMLUKernel : public framework::OpKernel<T> {
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
output_descs.emplace_back(MLUCnnlTensorDesc(*outs[j]));
descs_vec.push_back(output_descs.back().get());
outputs_vec.push_back(GetBasePtr(outs[j]));
} else {
outputs_vec.push_back(nullptr);
Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(ins[j]->dims(), ctx.GetPlace());
tmp_outputs_vec.push_back(tmp_tensor);
output_descs.emplace_back(MLUCnnlTensorDesc(*ins[j]));
outputs_vec.push_back(GetBasePtr(&(tmp_outputs_vec.back())));
}
descs_vec.push_back(output_descs.back().get());
}
MLUCnnlTensorDesc out_grad_desc(*out_grad);
......
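The concat_grad hunk above stops pushing nullptr for outputs whose gradient variable is kEmptyVarName; instead it keeps a tmp_outputs_vec of scratch tensors so that every output slot handed to the CNNL call has valid memory and a matching descriptor. Below is a minimal standalone C++ sketch of the same placeholder-output pattern, assuming plain std::vector buffers; CollectOutputPtrs, the parameter names, and the slot sizes are illustrative stand-ins for the Paddle Tensor / MLUCnnlTensorDesc machinery, not the actual kernel code.

```cpp
#include <cstddef>
#include <vector>

// Placeholder-output pattern: every slot passed to the device kernel must
// point at real memory, so slots whose gradient is not requested get a
// scratch buffer instead of nullptr (mirrors tmp_outputs_vec in the fix).
std::vector<void*> CollectOutputPtrs(
    const std::vector<std::vector<float>*>& wanted_outs,  // nullptr = gradient skipped
    const std::vector<std::size_t>& slot_sizes,
    std::vector<std::vector<float>>& scratch) {
  std::vector<void*> out_ptrs;
  out_ptrs.reserve(wanted_outs.size());
  for (std::size_t j = 0; j < wanted_outs.size(); ++j) {
    if (wanted_outs[j] != nullptr) {
      // Requested gradient: write directly into the caller's buffer.
      wanted_outs[j]->resize(slot_sizes[j]);
      out_ptrs.push_back(wanted_outs[j]->data());
    } else {
      // Skipped gradient: allocate a scratch destination. Reallocating
      // `scratch` only moves the inner vectors, whose heap buffers stay put,
      // so pointers already stored in out_ptrs remain valid.
      scratch.emplace_back(slot_sizes[j]);
      out_ptrs.push_back(scratch.back().data());
    }
  }
  return out_ptrs;
}

int main() {
  std::vector<float> g0;                                   // slot 0 requested
  std::vector<std::vector<float>*> outs = {&g0, nullptr};  // slot 1 skipped
  std::vector<std::vector<float>> scratch;
  auto ptrs = CollectOutputPtrs(outs, {4, 8}, scratch);
  return ptrs.size() == 2 ? 0 : 1;
}
```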
......@@ -23,7 +23,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
#define NO_USE_CNCL 0
#define GET_LAYOUT_OFFSET 2
using Tensor = framework::Tensor;
static std::vector<cnnlTensorLayout_t> supported_input_layout = {
CNNL_LAYOUT_NC, CNNL_LAYOUT_NLC, CNNL_LAYOUT_NHWC, CNNL_LAYOUT_NDHWC};
......@@ -165,6 +167,7 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
Tensor mean_all(mean->dtype());
Tensor invstd_all(variance->dtype());
#ifdef PADDLE_WITH_CNCL
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
......@@ -205,7 +208,9 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
cncl_dtype,
comm,
stream));
#else
if (NO_USE_CNCL) {
#endif
} else {
count_all = input_count;
mean_all.ShareDataWith(local_mean);
......@@ -404,6 +409,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
FillMLUTensorWithHostValue<int32_t>(
ctx, static_cast<int32_t>(x->numel() / C), &numel_count);
#ifdef PADDLE_WITH_CNCL
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
......@@ -440,6 +446,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
comm,
stream));
}
#endif
if (d_x) {
MLUCnnlTensorDesc desc_count(numel_count);
......
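The sync_batch_norm hunks wrap the CNCL all-reduce in #ifdef PADDLE_WITH_CNCL and introduce NO_USE_CNCL so that the shared `} else {` fallback still compiles when CNCL is absent: without CNCL the preprocessor leaves `if (NO_USE_CNCL) { } else { ... }`, and since the macro expands to 0 the local-statistics branch always runs. Here is a standalone sketch of that preprocessor pattern, with WITH_COLLECTIVE and all_reduce_stats() as hypothetical stand-ins for the Paddle/CNCL symbols; compile it without -DWITH_COLLECTIVE to exercise the fallback path.

```cpp
#include <cstdio>

// NO_USE_CNCL expands to 0, so in a build without the collective library the
// code below reduces to `if (0) { } else { local fallback }` and the shared
// else-block is reused by both configurations.
#define NO_USE_CNCL 0

void reduce_stats(float* mean, float* var, bool have_comm) {
  (void)have_comm;  // only consulted in the collective build
#ifdef WITH_COLLECTIVE
  if (have_comm) {
    all_reduce_stats(mean, var);  // collective path: combine stats across devices
#else
  if (NO_USE_CNCL) {
#endif
  } else {
    // single-device fallback: keep the locally computed statistics
    std::printf("local stats: mean=%f var=%f\n", *mean, *var);
  }
}

int main() {
  float m = 0.5f, v = 1.0f;
  reduce_stats(&m, &v, /*have_comm=*/false);
  return 0;
}
```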
......@@ -35,9 +35,9 @@ from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from op_test import OpTest, _set_use_system_allocator
from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor
from test_sync_batch_norm_op import create_or_get_tensor
_set_use_system_allocator(False)
paddle.enable_static()
......
......@@ -33,9 +33,9 @@ from paddle.fluid import core
from six import string_types
import paddle
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from op_test import OpTest, _set_use_system_allocator
from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor
from test_sync_batch_norm_op import create_or_get_tensor
_set_use_system_allocator(False)
paddle.enable_static()
......
#!/bin/bash
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch test_sync_batch_norm_op_mlu_baseline.py
......@@ -20,7 +20,7 @@ import os
import sys
sys.path.append("..")
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from op_test import OpTest, _set_use_system_allocator
from test_sync_batch_norm_base_mlu import TestDistBase
......
......@@ -29,8 +29,9 @@ import paddle.fluid as fluid
import paddle.nn as nn
from paddle.fluid import Program, program_guard
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from paddle.fluid.tests.unittests.test_dist_base import TestDistBase
sys.path.append("..")
from op_test import OpTest, _set_use_system_allocator
from test_dist_base import TestDistBase
paddle.enable_static()
......