未验证 提交 e5bc2eec 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] add FLAGS_npu_storage_format env to enable npu storage format, test=develop (#48774)

上级 c6a2b0fd
...@@ -1041,6 +1041,18 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, ...@@ -1041,6 +1041,18 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor", "Predictor",
"Choose default funciton type in JitLayer."); "Choose default funciton type in JitLayer.");
#ifdef PADDLE_WITH_CUSTOM_DEVICE
/**
 * Custom Device NPU related FLAG
 * Name: FLAGS_npu_storage_format
 * Since Version: 2.5.0
 * Value Range: bool, default=false
 * Example: FLAGS_npu_storage_format=true enables the private NPU storage
 *          format (e.g. for conv/batch-norm weights via npu_identity).
 * Note: Enable NPU Storage Format for Ascend910 performance improvement.
 */
PADDLE_DEFINE_EXPORTED_bool(
    npu_storage_format,
    false,
    "Enable NPU Storage Format for Ascend910 performance improvement.");
#endif
#ifdef PADDLE_WITH_CUDNN_FRONTEND #ifdef PADDLE_WITH_CUDNN_FRONTEND
/** /**
* CUDNNv8 related FLAG * CUDNNv8 related FLAG
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import inspect import inspect
import numpy as np import numpy as np
import warnings import warnings
...@@ -379,7 +380,11 @@ def monkey_patch_varbase(): ...@@ -379,7 +380,11 @@ def monkey_patch_varbase():
new_ivar = self._grad_ivar() new_ivar = self._grad_ivar()
# TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type(): if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1) new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
new_ivar = new_ivar._copy_to(core.CPUPlace(), True) new_ivar = new_ivar._copy_to(core.CPUPlace(), True)
if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS: if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode
from paddle.device import ( from paddle.device import (
get_all_custom_device_type, get_all_custom_device_type,
...@@ -149,7 +151,11 @@ def _conv_nd( ...@@ -149,7 +151,11 @@ def _conv_nd(
new_shape[channel_dim] = -1 new_shape[channel_dim] = -1
bias = bias.reshape(new_shape) bias = bias.reshape(new_shape)
# TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type(): if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
with no_grad(): with no_grad():
bias_storage = _C_ops.npu_identity( bias_storage = _C_ops.npu_identity(
bias, 3 bias, 3
...@@ -747,7 +753,11 @@ def conv2d( ...@@ -747,7 +753,11 @@ def conv2d(
+ [1 for i in range(len(x.shape) - channel_dim - 1)], + [1 for i in range(len(x.shape) - channel_dim - 1)],
) )
# TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type(): if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
with no_grad(): with no_grad():
bias_storage = _C_ops.npu_identity( bias_storage = _C_ops.npu_identity(
bias, 3 bias, 3
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
# TODO: define normalization api # TODO: define normalization api
import numbers import numbers
import os
import warnings import warnings
import numpy as np import numpy as np
...@@ -681,7 +682,11 @@ class _BatchNormBase(Layer): ...@@ -681,7 +682,11 @@ class _BatchNormBase(Layer):
self._variance.stop_gradient = True self._variance.stop_gradient = True
# TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op # TODO(qili93): temporary for ascend npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type(): if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
with no_grad(): with no_grad():
weight_trans = _C_ops.npu_identity( weight_trans = _C_ops.npu_identity(
self.weight, 3 self.weight, 3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册