Unverified · Commit f3982a9d authored by Qi Li, committed by GitHub

[NPU] fix FLAGS_npu_storage_format flag in python, test=develop (#48976)

Parent 29d9dbe3
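Summary: the Python side previously decided whether to use the private NPU storage format by reading the raw environment variable, so a value set through Paddle's exported-flag registry (e.g. via paddle.set_flags) was not picked up. This commit reads the flag through _global_flags() instead, and removes the PADDLE_WITH_CUSTOM_DEVICE guard around the C++ definition so the flag always exists in the registry (the new Python check looks the flag up before testing for an 'npu' device). A minimal before/after sketch of the check, with an illustrative variable name:

    # Before: only an explicitly exported environment variable is seen.
    import os
    use_acl_format = os.environ.get('FLAGS_npu_storage_format', None) in [1, '1', True, 'True', 'true']

    # After: the exported flag registry is consulted, which reflects both
    # `export FLAGS_npu_storage_format=1` and `paddle.set_flags({...})`.
    from paddle.fluid.framework import _global_flags
    use_acl_format = _global_flags()['FLAGS_npu_storage_format']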
......@@ -1040,7 +1040,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
                               "Predictor",
                               "Choose default funciton type in JitLayer.");
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
 /**
  * Custom Device NPU related FLAG
  * Name: FLAGS_npu_storage_format
......@@ -1050,7 +1049,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
  * Note: Enable NPU Storage Format for Ascend910 performance improvement.
  */
 PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, "");
-#endif
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
 /**
......
......@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import inspect
 import numpy as np
 import warnings
......@@ -42,6 +41,7 @@ import paddle.profiler as profiler
 from paddle.profiler.utils import in_profiler_mode
 from paddle import _C_ops, _legacy_C_ops
 from paddle.device import get_all_custom_device_type
+from paddle.fluid.framework import _global_flags
 _grad_scalar = None
......@@ -381,8 +381,7 @@ def monkey_patch_varbase():
         new_ivar = self._grad_ivar()
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
......
......@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode
 from paddle.device import (
     get_all_custom_device_type,
......@@ -152,8 +150,7 @@ def _conv_nd(
         bias = bias.reshape(new_shape)
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             with no_grad():
......@@ -753,8 +750,7 @@ def conv2d(
     )
     # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
     if (
-        os.environ.get('FLAGS_npu_storage_format', None)
-        in [1, '1', True, 'True', 'true']
+        _global_flags()['FLAGS_npu_storage_format']
         and 'npu' in get_all_custom_device_type()
     ):
         with no_grad():
......
......@@ -28,7 +28,6 @@
 # TODO: define normalization api
 import numbers
-import os
 import warnings
 import numpy as np
......@@ -688,8 +687,7 @@ class _BatchNormBase(Layer):
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             with no_grad():
......
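With the flag wired through the registry, it can be toggled from Python as well as from the environment. A hedged usage sketch (assumes an Ascend NPU custom-device build; paddle.set_flags is Paddle's public setter for exported flags, and if a particular build rejects it the environment-variable route still applies):

    import paddle

    # Turn on the private NPU storage format before building/running the model.
    # The checks in the diff above additionally require an 'npu' custom device
    # to be registered, otherwise the flag has no effect.
    paddle.set_flags({'FLAGS_npu_storage_format': True})

    # Equivalent environment-variable route, set before the process starts:
    #   export FLAGS_npu_storage_format=1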