Unverified commit ad81f22c, authored by Q qipengh and committed by GitHub

[MLU] support amp O1 of mlu (#40461)

Parent f748b433
......@@ -34,6 +34,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
return;
}
// NOTE(hqp): Special case for CPU->MLU, avoid stream sync.
if (platform::is_cpu_place(in.place()) && platform::is_mlu_place(dst_place)) {
paddle::framework::TensorCopy(
in, dst_place, *platform::DeviceContextPool::Instance().Get(dst_place),
out);
return;
}
// NOTE(yy): TransDataDevice should wait for computation of input.
if (!platform::is_cuda_pinned_place(in.place())) {
platform::DeviceContextPool::Instance().Get(in.place())->Wait();
......
......@@ -124,7 +124,7 @@ AmpOperators::AmpOperators()
OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
unsupported_ops_gpu_bf16.end());
// NOTE: GPU/NPU/XPU is compiled separately.
// NOTE: GPU/NPU/XPU/MLU is compiled separately.
#elif defined(PADDLE_WITH_ASCEND_CL)
auto unsupported_ops_npu_fp16 = std::get<2>(
OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
......@@ -143,6 +143,15 @@ AmpOperators::AmpOperators()
OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
unsupported_ops_xpu_bf16.end());
#elif defined(PADDLE_WITH_MLU)
auto unsupported_ops_mlu_fp16 = std::get<2>(
OpSupportedInfos("MLU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_mlu_fp16.begin(),
unsupported_ops_mlu_fp16.end());
auto unsupported_ops_mlu_bf16 = std::get<2>(
OpSupportedInfos("MLU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_mlu_bf16.begin(),
unsupported_ops_mlu_bf16.end());
#endif
VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
<< unsupported_fp16_ops_->size() << " "
......@@ -210,6 +219,7 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
if (paddle::platform::is_gpu_place(place) ||
paddle::platform::is_cuda_pinned_place(place) ||
paddle::platform::is_xpu_place(place) ||
paddle::platform::is_mlu_place(place) ||
paddle::platform::is_npu_place(place) ||
paddle::platform::is_npu_pinned_place(place)) {
// CUDAPinnedPlace is added for varbase created by dataloader
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
......@@ -20,6 +21,8 @@ namespace operators {
template <typename T>
class MLUBatchNormOpKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto &place = ctx.GetPlace();
......@@ -68,10 +71,10 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
// alloc memory
y->mutable_data<T>(place);
mean_out->mutable_data<T>(place);
variance_out->mutable_data<T>(place);
saved_mean->mutable_data<T>(place);
saved_variance->mutable_data<T>(place);
mean_out->mutable_data<MPDType>(place);
variance_out->mutable_data<MPDType>(place);
saved_mean->mutable_data<MPDType>(place);
saved_variance->mutable_data<MPDType>(place);
Tensor transformed_x;
Tensor transformed_y;
......@@ -132,6 +135,8 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
template <typename T>
class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
......@@ -154,10 +159,10 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<MLUDeviceContext>();
auto d_x_tmp =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(x->dims(), dev_ctx);
auto scale_grad_tmp =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(scale->dims(), dev_ctx);
auto scale_grad_tmp = ctx.AllocateTmpTensor<MPDType, MLUDeviceContext>(
scale->dims(), dev_ctx);
auto bias_grad_tmp =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(bias->dims(), dev_ctx);
ctx.AllocateTmpTensor<MPDType, MLUDeviceContext>(bias->dims(), dev_ctx);
if (d_x == nullptr) {
d_x = &d_x_tmp;
......@@ -171,8 +176,8 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
const auto &place = ctx.GetPlace();
d_x->mutable_data<T>(place);
d_scale->mutable_data<T>(place);
d_bias->mutable_data<T>(place);
d_scale->mutable_data<MPDType>(place);
d_bias->mutable_data<MPDType>(place);
use_global_stats = is_test || use_global_stats;
......
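The MPDType changes in this kernel keep the running statistics (mean_out, variance_out, saved_mean, saved_variance) and the scale/bias gradients in float32 even when the kernel itself runs with T = float16. A minimal Python-side sketch of the behaviour this enables, assuming an MLU-enabled build in which 'mlu:0' is a valid device string and that BatchNorm2D exposes its running statistics as `_mean`/`_variance` (both assumptions, not part of this diff):

```python
import paddle
import paddle.nn as nn

# Assumption: a PaddlePaddle wheel compiled with MLU support.
paddle.set_device('mlu:0')

bn = nn.BatchNorm2D(8)  # weight, bias and running statistics are created as float32
x = paddle.cast(paddle.randn([4, 8, 16, 16]), 'float16')

# This is the float16 kernel path that AMP O1 exercises: X is float16 while
# Scale/Bias/mean/variance stay float32 (the MPDType of float16).
y = bn(x)

print(y.dtype)             # expected: paddle.float16  (y->mutable_data<T>)
print(bn._mean.dtype)      # expected: paddle.float32  (mean_out->mutable_data<MPDType>)
print(bn._variance.dtype)  # expected: paddle.float32  (variance_out->mutable_data<MPDType>)
```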
......@@ -173,6 +173,9 @@ if core.is_compiled_with_xpu():
elif core.is_compiled_with_npu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'NPU', core.VarDesc.VarType.FP16)
elif core.is_compiled_with_mlu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'MLU', core.VarDesc.VarType.FP16)
else:
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'GPU', core.VarDesc.VarType.FP16)
......
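For reference, the query added in this branch can also be run interactively to see which operators fall back to float32 on MLU; the third element of the returned tuple is the one fed into the unsupported list above (the first two are not used here):

```python
import paddle
from paddle.fluid import core

# Only meaningful on a wheel compiled with MLU support, mirroring the new branch above.
if core.is_compiled_with_mlu():
    _, _, unsupported_fp16 = core.op_supported_infos('MLU', core.VarDesc.VarType.FP16)
    print(len(unsupported_fp16))          # number of ops without an FP16 MLU kernel
    print(sorted(unsupported_fp16)[:10])  # a small sample of them
```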
......@@ -271,13 +271,14 @@ def amp_guard(enable=True,
"current_tracer is None, maybe it is not in imperative mode.")
# check device_type:
# NOTE: Now, amp only supports gpu for float16 and bfloat16, xpu for float16, npu for float16.
# NOTE: Now, amp only supports gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16.
# Maybe we will support cpu for bfloat16.
if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place() or
tracer._expected_place.is_mlu_place() or
tracer._expected_place.is_npu_place()):
warnings.warn(
'amp_guard can only be enabled on CUDAPlace, XPUPlace, and NPUPlace, current place is %s, so it makes no effect.'
'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, and NPUPlace, current place is %s, so it makes no effect.'
% tracer._expected_place)
enable = False
# For npu:
......@@ -288,6 +289,10 @@ def amp_guard(enable=True,
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only support float16 amp.')
enable = False
# For mlu:
if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'):
warnings.warn('MLUPlace only support float16 amp.')
enable = False
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
......
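With this change amp_guard no longer disables itself on MLUPlace, so the public paddle.amp.auto_cast API (which wraps amp_guard) can be used on MLU in O1 mode. A hedged usage sketch, assuming the same MLU-enabled build and 'mlu:0' device string as above:

```python
import paddle

# Assumption: a PaddlePaddle wheel compiled with MLU support.
paddle.set_device('mlu:0')

conv = paddle.nn.Conv2D(3, 16, 3)
x = paddle.randn([2, 3, 32, 32])

# O1: white-list ops such as conv2d are auto-cast to float16; dtype='bfloat16'
# would still be rejected on MLU by the new check above.
with paddle.amp.auto_cast(enable=True, level='O1'):
    y = conv(x)

print(y.dtype)  # expected: paddle.float16 when AMP is actually enabled on MLU
```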
......@@ -106,9 +106,10 @@ class AmpScaler(object):
if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place() or
tracer._expected_place.is_mlu_place() or
tracer._expected_place.is_npu_place()):
warnings.warn(
'AmpScaler can only be enabled on CUDAPlace, XPUPlace and NPUPlace, current place is %s, so it makes no effect.'
'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace and NPUPlace, current place is %s, so it makes no effect.'
% tracer._expected_place)
enable = False
......
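AmpScaler is what paddle.amp.GradScaler builds on, so a typical O1 training step on MLU would look roughly like the sketch below (assuming the same MLU build and device string as above; the loss-scaling value is arbitrary):

```python
import paddle

paddle.set_device('mlu:0')  # assumption: MLU-enabled build

model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)  # backed by AmpScaler

for _ in range(3):
    x = paddle.randn([8, 16])
    with paddle.amp.auto_cast(level='O1'):
        loss = model(x).mean()
    scaled = scaler.scale(loss)    # scale the loss to avoid float16 gradient underflow
    scaled.backward()
    scaler.minimize(opt, scaled)   # unscale, skip the step on inf/nan, adjust the scale
    opt.clear_grad()
```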