From f3982a9dd791eacec0f40ebf84e61b48b5f31df6 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Tue, 13 Dec 2022 14:54:16 +0800 Subject: [PATCH] [NPU] fix FLAGS_npu_storage_format flag in python, test=develop (#48976) --- paddle/phi/core/flags.cc | 2 -- python/paddle/fluid/dygraph/varbase_patch_methods.py | 5 ++--- python/paddle/nn/functional/conv.py | 8 ++------ python/paddle/nn/layer/norm.py | 4 +--- 4 files changed, 5 insertions(+), 14 deletions(-) diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index f11e09cf890..ee3caeea367 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1040,7 +1040,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, "Predictor", "Choose default funciton type in JitLayer."); -#ifdef PADDLE_WITH_CUSTOM_DEVICE /** * Custom Device NPU related FLAG * Name: FLAGS_npu_storage_format @@ -1050,7 +1049,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, * Note: Enable NPU Storage Format for Ascend910 performance improvement. */ PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, ""); -#endif #ifdef PADDLE_WITH_CUDNN_FRONTEND /** diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index f52ba97066c..9ddebbab767 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import inspect import numpy as np import warnings @@ -42,6 +41,7 @@ import paddle.profiler as profiler from paddle.profiler.utils import in_profiler_mode from paddle import _C_ops, _legacy_C_ops from paddle.device import get_all_custom_device_type +from paddle.fluid.framework import _global_flags _grad_scalar = None @@ -381,8 +381,7 @@ def monkey_patch_varbase(): new_ivar = self._grad_ivar() # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op if ( - os.environ.get('FLAGS_npu_storage_format', None) - in [1, '1', True, 'True', 'true'] + _global_flags()['FLAGS_npu_storage_format'] and 'npu' in get_all_custom_device_type() ): new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1) diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index c479cabb4fb..d6f4ee12eea 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode from paddle.device import ( get_all_custom_device_type, @@ -152,8 +150,7 @@ def _conv_nd( bias = bias.reshape(new_shape) # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op if ( - os.environ.get('FLAGS_npu_storage_format', None) - in [1, '1', True, 'True', 'true'] + _global_flags()['FLAGS_npu_storage_format'] and 'npu' in get_all_custom_device_type() ): with no_grad(): @@ -753,8 +750,7 @@ def conv2d( ) # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op if ( - os.environ.get('FLAGS_npu_storage_format', None) - in [1, '1', True, 'True', 'true'] + _global_flags()['FLAGS_npu_storage_format'] and 'npu' in get_all_custom_device_type() ): with no_grad(): diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index f446970ee0e..e2842e1944d 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -28,7 +28,6 @@ # TODO: define normalization api import numbers -import os import warnings import numpy as np @@ -688,8 +687,7 @@ class _BatchNormBase(Layer): # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op if ( - os.environ.get('FLAGS_npu_storage_format', None) - in [1, '1', True, 'True', 'true'] + _global_flags()['FLAGS_npu_storage_format'] and 'npu' in get_all_custom_device_type() ): with no_grad(): -- GitLab