未验证 提交 f3982a9d 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix FLAGS_npu_storage_format flag in python, test=develop (#48976)

上级 29d9dbe3
...@@ -1040,7 +1040,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, ...@@ -1040,7 +1040,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor", "Predictor",
"Choose default funciton type in JitLayer."); "Choose default funciton type in JitLayer.");
#ifdef PADDLE_WITH_CUSTOM_DEVICE
/** /**
* Custom Device NPU related FLAG * Custom Device NPU related FLAG
* Name: FLAGS_npu_storage_format * Name: FLAGS_npu_storage_format
...@@ -1050,7 +1049,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, ...@@ -1050,7 +1049,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
* Note: Enable NPU Storage Format for Ascend910 performance improvement. * Note: Enable NPU Storage Format for Ascend910 performance improvement.
*/ */
PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, ""); PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, "");
#endif
#ifdef PADDLE_WITH_CUDNN_FRONTEND #ifdef PADDLE_WITH_CUDNN_FRONTEND
/** /**
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import inspect import inspect
import numpy as np import numpy as np
import warnings import warnings
...@@ -42,6 +41,7 @@ import paddle.profiler as profiler ...@@ -42,6 +41,7 @@ import paddle.profiler as profiler
from paddle.profiler.utils import in_profiler_mode from paddle.profiler.utils import in_profiler_mode
from paddle import _C_ops, _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.device import get_all_custom_device_type from paddle.device import get_all_custom_device_type
from paddle.fluid.framework import _global_flags
_grad_scalar = None _grad_scalar = None
...@@ -381,8 +381,7 @@ def monkey_patch_varbase(): ...@@ -381,8 +381,7 @@ def monkey_patch_varbase():
new_ivar = self._grad_ivar() new_ivar = self._grad_ivar()
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if ( if (
os.environ.get('FLAGS_npu_storage_format', None) _global_flags()['FLAGS_npu_storage_format']
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type() and 'npu' in get_all_custom_device_type()
): ):
new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1) new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode
from paddle.device import ( from paddle.device import (
get_all_custom_device_type, get_all_custom_device_type,
...@@ -152,8 +150,7 @@ def _conv_nd( ...@@ -152,8 +150,7 @@ def _conv_nd(
bias = bias.reshape(new_shape) bias = bias.reshape(new_shape)
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if ( if (
os.environ.get('FLAGS_npu_storage_format', None) _global_flags()['FLAGS_npu_storage_format']
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type() and 'npu' in get_all_custom_device_type()
): ):
with no_grad(): with no_grad():
...@@ -753,8 +750,7 @@ def conv2d( ...@@ -753,8 +750,7 @@ def conv2d(
) )
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if ( if (
os.environ.get('FLAGS_npu_storage_format', None) _global_flags()['FLAGS_npu_storage_format']
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type() and 'npu' in get_all_custom_device_type()
): ):
with no_grad(): with no_grad():
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
# TODO: define normalization api # TODO: define normalization api
import numbers import numbers
import os
import warnings import warnings
import numpy as np import numpy as np
...@@ -688,8 +687,7 @@ class _BatchNormBase(Layer): ...@@ -688,8 +687,7 @@ class _BatchNormBase(Layer):
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if ( if (
os.environ.get('FLAGS_npu_storage_format', None) _global_flags()['FLAGS_npu_storage_format']
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type() and 'npu' in get_all_custom_device_type()
): ):
with no_grad(): with no_grad():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册