Unverified commit d7a1a178, authored by J jjyaoao, committed by GitHub

delete paddle/fluid/operators/amp/*_npu.* (#52673)

* delete paddle/fluid/operators/*_npu.*

* try pass code-style
Parent 03afb41c
@@ -77,6 +77,7 @@ tools/nvcc_lazy
paddle/fluid/pybind/eager_op_function.cc
tools/nvcc_lazy
# these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op*.cc
paddle/fluid/operators/generated_sparse_op.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AllocFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* float_status = ctx.Output<phi::DenseTensor>("FloatStatus");
float_status->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("NPUAllocFloatStatus", {}, {*float_status});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
alloc_float_status,
ops::AllocFloatStatusKernel<paddle::platform::NPUDeviceContext, float>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace operators {
// NOTE(zhiqiu): The CheckFiniteAndUnscaleNPUKernel is different from the CUDA one.
// On NPU, we do not actually inspect the data of the input tensors;
// instead we use NPUGetFloatStatus to check whether nan/inf occurred on the
// device, and clear the status after this op.
// This may lead to wrong results if the input tensors were not computed on
// the NPU device but came from elsewhere, for example, from feeding.
template <typename T>
class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const auto xs = ctx.MultiInput<phi::DenseTensor>("X");
const auto* scale = ctx.Input<phi::DenseTensor>("Scale");
const auto* float_status = ctx.Input<phi::DenseTensor>("FloatStatus");
auto outs = ctx.MultiOutput<phi::DenseTensor>("Out");
auto* found_inf = ctx.Output<phi::DenseTensor>("FoundInfinite");
found_inf->mutable_data<bool>(ctx.GetPlace());
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// step1: inverse scale
phi::DenseTensor const_tensor;
const_tensor.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&const_tensor, static_cast<T>(1.0));
// Inverse(1.0/scale)
phi::DenseTensor* tmp_inverse_out = const_cast<phi::DenseTensor*>(scale);
phi::DenseTensor inverse_out(scale->type());
inverse_out.Resize(scale->dims());
inverse_out.mutable_data<T>(ctx.GetPlace());
const auto& runner_inverse =
NpuOpRunner("Div", {const_tensor, *scale}, {inverse_out}, {});
runner_inverse.Run(stream);
tmp_inverse_out = &inverse_out;
// NOTE(zhiqiu):
phi::DenseTensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());
// NOTE(zhiqiu): NPUGetFloatStatus updates data on input in-place.
// tmp is only a placeholder.
const auto& runner_float_status =
NpuOpRunner("NPUGetFloatStatus",
{*float_status},
{tmp},
{{"message", std::string("check_nan_and_inf")}});
runner_float_status.Run(stream);
phi::DenseTensor sum;
sum.mutable_data<float>({1}, ctx.GetPlace());
const auto& runner_reduce_sum =
NpuOpRunner("ReduceSumD",
{*float_status},
{sum},
{{"axes", std::vector<int>{0}}, {"keep_dims", true}});
runner_reduce_sum.Run(stream);
const auto& runner_greater =
NpuOpRunner("GreaterEqual", {sum, const_tensor}, {*found_inf}, {});
runner_greater.Run(stream);
// NOTE(zhiqiu): The normal logic is:
// out = in, if found_inf = true
// out = in/scale, if found_inf = false
// However, on NPU, in order to avoid a stream sync, we do not copy the
// found_inf data to the CPU to decide whether to unscale or not.
// Instead, we run the Mul regardless of found_inf.
// In practice, only a few steps during training contain nan/inf.
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(ctx.GetPlace());
const auto& runner_mul =
NpuOpRunner("Mul", {*x, *tmp_inverse_out}, {*out}, {});
runner_mul.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleNPUKernel<float>,
ops::CheckFiniteAndUnscaleNPUKernel<plat::float16>);
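The NOTE in the kernel above explains that, to avoid a stream synchronization, the multiplication by the inverse scale is issued unconditionally rather than branching on found_inf. As a point of reference, the semantics the op is meant to implement can be written as a small host-side sketch; the function below is purely illustrative (plain std::vector in place of DenseTensor, hypothetical name CheckFiniteAndUnscaleRef) and is not part of the Paddle API.
#include <cmath>
#include <vector>
// Hypothetical reference semantics of check_finite_and_unscale:
// scan all inputs for nan/inf; if none is found, divide every element
// by scale, otherwise pass the inputs through unchanged.
bool CheckFiniteAndUnscaleRef(const std::vector<std::vector<float>>& xs,
                              float scale,
                              std::vector<std::vector<float>>* outs) {
  bool found_inf = false;
  for (const auto& x : xs) {
    for (float v : x) {
      if (!std::isfinite(v)) found_inf = true;
    }
  }
  outs->clear();
  for (const auto& x : xs) {
    std::vector<float> out = x;
    if (!found_inf) {
      for (float& v : out) v /= scale;  // unscale only on an overflow-free step
    }
    outs->push_back(std::move(out));
  }
  return found_inf;  // corresponds to the FoundInfinite output
}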
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef _WIN32
#include <unistd.h>
#endif
#include <algorithm>
#include <cstdlib>
#include <memory>
#include <random>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP_ITSELF(check_finite_and_unscale);
USE_OP_DEVICE_KERNEL(check_finite_and_unscale, NPU);
struct InputVars {
std::string name;
phi::DenseTensor *tensor;
};
template <typename T>
void Compare(f::Scope *scope, const p::DeviceContext &ctx) {
const f::DDim dims = phi::make_ddim({2, 2});
auto place = ctx.GetPlace();
// init input
std::vector<InputVars> input_names = {
{"x", scope->Var("x")->GetMutable<phi::DenseTensor>()},
{"x1", scope->Var("x1")->GetMutable<phi::DenseTensor>()}};
auto *scale = scope->Var("scale")->GetMutable<phi::DenseTensor>();
// init output
auto *out = scope->Var("out")->GetMutable<phi::DenseTensor>();
auto *out1 = scope->Var("out1")->GetMutable<phi::DenseTensor>();
auto *found_inf = scope->Var("found_inf")->GetMutable<phi::DenseTensor>();
// Initialize input data
const int num_inputs = input_names.size();
size_t numel = static_cast<size_t>(phi::product(dims));
for (int i = 0; i < num_inputs; ++i) {
std::vector<T> init_xs;
for (size_t j = 0; j < numel; ++j) {
if (j == 0) {
init_xs.push_back(static_cast<T>(NAN));
} else {
init_xs.push_back(static_cast<T>(j + 1));
}
}
f::TensorFromVector(init_xs, ctx, input_names[i].tensor);
input_names[i].tensor->Resize(dims);
}
f::TensorFromVector(std::vector<T>{static_cast<T>(0.5)}, ctx, scale);
ctx.Wait();
// run
f::AttributeMap attrs;
auto op = f::OpRegistry::CreateOp(
"check_finite_and_unscale",
{{"X", {"x", "x1"}}, {"Scale", {"scale"}}},
{{"Out", {"out", "out1"}}, {"FoundInfinite", {"found_inf"}}},
attrs);
op->Run(*scope, place);
ctx.Wait();
// out0
std::vector<T> out_vec;
f::TensorToVector(*out, ctx, &out_vec);
EXPECT_EQ(out_vec.size(), static_cast<size_t>(4));
for (size_t j = 0; j < out_vec.size(); ++j) {
VLOG(3) << "out_vec[" << j << "]:" << out_vec[j];
}
ctx.Wait();
// out1
std::vector<T> out1_vec;
f::TensorToVector(*out1, ctx, &out1_vec);
EXPECT_EQ(out1_vec.size(), static_cast<size_t>(4));
for (size_t j = 0; j < out1_vec.size(); ++j) {
VLOG(3) << "out1_vec[" << j << "]:" << out1_vec[j];
}
ctx.Wait();
// out found_inf
phi::DenseTensor found_inf_tensor;
found_inf_tensor.Resize({1});
bool *found_inf_data =
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
f::TensorCopy(*found_inf, place, &found_inf_tensor);
EXPECT_TRUE(*found_inf_data);
ctx.Wait();
}
TEST(check_finite_and_unscale, NPU_fp32) {
f::Scope scope;
auto *ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(0));
Compare<float>(&scope, *ctx);
}
TEST(check_finite_and_unscale, NPU_fp16) {
f::Scope scope;
auto *ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(0));
Compare<p::float16>(&scope, *ctx);
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ClearFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* float_status = ctx.Input<phi::DenseTensor>("FloatStatus");
auto* float_status_out = ctx.Output<phi::DenseTensor>("FloatStatusOut");
// NOTE(zhiqiu): NPUClearFloatStatus modifies the input.
PADDLE_ENFORCE_EQ(float_status_out,
float_status,
platform::errors::PreconditionNotMet(
"The input(FloatStatus) and Output(FloatStatusOut) "
"should be the same."));
phi::DenseTensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());
const auto& runner =
NpuOpRunner("NPUClearFloatStatus", {tmp}, {*float_status_out});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
clear_float_status,
ops::ClearFloatStatusKernel<paddle::platform::NPUDeviceContext, float>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class GetFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* float_status = ctx.Input<phi::DenseTensor>("FloatStatus");
auto* float_status_out = ctx.Output<phi::DenseTensor>("FloatStatusOut");
// NPUGetFloatStatus modifies the input.
PADDLE_ENFORCE_EQ(float_status_out,
float_status,
platform::errors::PreconditionNotMet(
"The input(FloatStatus) and Output(FloatStatusOut) "
"should be the same."));
phi::DenseTensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// NPUGetFloatStatus updates data on input in-place.
// tmp is only a placeholder.
NpuOpRunner("NPUGetFloatStatus", {*float_status}, {tmp}).Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
get_float_status,
ops::GetFloatStatusKernel<paddle::platform::NPUDeviceContext, float>);
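Taken together, alloc_float_status, clear_float_status, get_float_status, and check_finite_and_unscale treat the 8-element float-status buffer as a device-side overflow flag: allocate it once, clear it before the region to be monitored, query it afterwards, and detect overflow by reducing the buffer and comparing the sum against 1.0 (the ReduceSumD + GreaterEqual pair in the check_finite_and_unscale kernel above). A minimal host-side model of that protocol, with a std::array standing in for the device buffer, might look like this (illustrative only, not Paddle code):
#include <array>
#include <numeric>
// Hypothetical host-side model of the NPU float-status protocol.
struct FloatStatusModel {
  std::array<float, 8> buf{};  // alloc_float_status: a zero-initialized buffer

  void Clear() { buf.fill(0.0f); }  // clear_float_status before the monitored region

  // After the monitored computation, the device leaves non-zero values in the
  // buffer if nan/inf occurred; overflow is detected by summing the buffer and
  // checking whether the sum reaches 1.0, mirroring ReduceSumD + GreaterEqual.
  bool FoundInfinite() const {
    float sum = std::accumulate(buf.begin(), buf.end(), 0.0f);
    return sum >= 1.0f;
  }
};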
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
DECLARE_int32(min_loss_scaling);
namespace paddle {
namespace operators {
template <typename T>
void Update(const platform::NPUDeviceContext& ctx,
const std::vector<bool> found_inf_vec,
const phi::DenseTensor* pre_loss_scaling_tensor,
const phi::DenseTensor* good_in_tensor,
const phi::DenseTensor* bad_in_tensor,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf,
const float incr_ratio,
const float decr_ratio,
phi::DenseTensor* updated_loss_scaling_tensor,
phi::DenseTensor* good_out_tensor,
phi::DenseTensor* bad_out_tensor) {
auto place = ctx.GetPlace();
auto stream = ctx.stream();
if (found_inf_vec[0]) {
// good_out_data = 0
auto g = good_out_tensor->mutable_data<int>(place);
platform::NPUMemsetAsync(static_cast<void*>(g),
0,
good_out_tensor->numel() * sizeof(int),
stream);
// bad_out_data = bad_in_data + 1
phi::DenseTensor factor_tensor(bad_out_tensor->dtype());
factor_tensor.mutable_data<int>({1}, place);
FillNpuTensorWithConstant<int>(&factor_tensor, static_cast<int>(1));
const auto& runner_p2 = NpuOpRunner(
"Add", {*bad_in_tensor, factor_tensor}, {*bad_out_tensor}, {});
runner_p2.Run(stream);
std::vector<int> bad_out_data;
paddle::framework::TensorToVector(*bad_out_tensor, ctx, &bad_out_data);
if (bad_out_data[0] >= decr_every_n_nan_or_inf) {
const auto& runner_p3 = NpuOpRunner("Power",
{*pre_loss_scaling_tensor},
{*updated_loss_scaling_tensor},
{{"power", static_cast<float>(1)},
{"scale", decr_ratio},
{"shift", static_cast<float>(0)}});
runner_p3.Run(stream);
std::vector<T> new_loss_scaling;
paddle::framework::TensorToVector(
*updated_loss_scaling_tensor, ctx, &new_loss_scaling);
float min_value = 1.0;
if (FLAGS_min_loss_scaling > 1) {
min_value = static_cast<float>(FLAGS_min_loss_scaling);
}
if (new_loss_scaling[0] < min_value) {
// updated_loss_scaling_data = 1
const auto& runner_p4 =
NpuOpRunner("Power",
{*pre_loss_scaling_tensor},
{*updated_loss_scaling_tensor},
{{"power", static_cast<float>(1)},
{"scale", static_cast<float>(0)},
{"shift", static_cast<float>(min_value)}});
runner_p4.Run(stream);
}
// bad_out_data = 0
auto b = bad_out_tensor->mutable_data<int>(place);
platform::NPUMemsetAsync(static_cast<void*>(b),
0,
bad_out_tensor->numel() * sizeof(int),
stream);
}
} else {
// bad_out_data = 0
auto b = bad_out_tensor->mutable_data<int>(place);
platform::NPUMemsetAsync(static_cast<void*>(b),
0,
bad_out_tensor->numel() * sizeof(int),
stream);
// good_out_data = good_in_data + 1
phi::DenseTensor factor_tensor(good_out_tensor->dtype());
factor_tensor.mutable_data<int>({1}, place);
FillNpuTensorWithConstant<int>(&factor_tensor, static_cast<int>(1));
const auto& runner_p2 = NpuOpRunner(
"Add", {*good_in_tensor, factor_tensor}, {*good_out_tensor}, {});
runner_p2.Run(stream);
std::vector<int> good_out_data;
paddle::framework::TensorToVector(*good_out_tensor, ctx, &good_out_data);
if (good_out_data[0] >= incr_every_n_steps) {
const auto& runner_p3 = NpuOpRunner("Power",
{*pre_loss_scaling_tensor},
{*updated_loss_scaling_tensor},
{{"power", static_cast<float>(1)},
{"scale", incr_ratio},
{"shift", static_cast<float>(0)}});
runner_p3.Run(stream);
std::vector<T> new_loss_scaling;
paddle::framework::TensorToVector(
*updated_loss_scaling_tensor, ctx, &new_loss_scaling);
if (!std::isfinite(new_loss_scaling[0])) {
// updated_loss_scaling_data = pre_loss_scaling_data
const auto& runner_p4 = NpuOpRunner("Power",
{*pre_loss_scaling_tensor},
{*updated_loss_scaling_tensor},
{{"power", static_cast<float>(1)},
{"scale", static_cast<float>(1)},
{"shift", static_cast<float>(0)}});
runner_p4.Run(stream);
}
// good_out_data = 0
auto g = good_out_tensor->mutable_data<int>(place);
platform::NPUMemsetAsync(static_cast<void*>(g),
0,
good_out_tensor->numel() * sizeof(int),
stream);
}
}
}
template <typename T>
class UpdateLossScalingFunctor {
public:
void operator()(const platform::NPUDeviceContext& dev_ctx,
const std::vector<bool> found_inf_vec,
const phi::DenseTensor* pre_loss_scaling_tensor,
const phi::DenseTensor* good_in_tensor,
const phi::DenseTensor* bad_in_tensor,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf,
const float incr_ratio,
const float decr_ratio,
phi::DenseTensor* updated_loss_scaling_tensor,
phi::DenseTensor* good_out_tensor,
phi::DenseTensor* bad_out_tensor) const {
Update<T>(dev_ctx,
found_inf_vec,
pre_loss_scaling_tensor,
good_in_tensor,
bad_in_tensor,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
updated_loss_scaling_tensor,
good_out_tensor,
bad_out_tensor);
}
};
template <typename T>
class LazyZerosNPU {
public:
void operator()(const platform::NPUDeviceContext& dev_ctx,
const std::vector<bool> found_inf_vec,
const std::vector<const phi::DenseTensor*>& xs,
const std::vector<phi::DenseTensor*>& outs) const {
if (!xs.size()) {
return;
}
auto place = dev_ctx.GetPlace();
auto stream = dev_ctx.stream();
phi::DenseTensor* zero_tensor = nullptr;
void* zero_ptr = nullptr;
if (found_inf_vec[0]) {
int max_num = -1;
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
int num = out->numel();
if (max_num < num) {
max_num = num;
zero_tensor = out;
}
}
zero_tensor->mutable_data<T>(place);
const auto& runner_zeros =
NpuOpRunner("ZerosLike", {*zero_tensor}, {*zero_tensor});
runner_zeros.Run(stream);
zero_tensor->check_memory_size();
zero_ptr = zero_tensor->data();
}
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
auto* x = xs[i];
auto dst_ptr = out->mutable_data<T>(place);
if (!found_inf_vec[0]) {
framework::TensorCopy(*x, place, dev_ctx, out);
} else if (zero_ptr != dst_ptr) {
auto size = out->numel() * phi::SizeOf(out->dtype());
memory::Copy(place, dst_ptr, place, zero_ptr, size, stream);
}
}
}
};
template <typename DeviceContext, typename T>
class UpdateLossScalingNPUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto xs = ctx.MultiInput<phi::DenseTensor>("X");
auto outs = ctx.MultiOutput<phi::DenseTensor>("Out");
const auto* found_inf = ctx.Input<phi::DenseTensor>("FoundInfinite");
PADDLE_ENFORCE_EQ(found_inf->numel(),
1,
platform::errors::InvalidArgument(
"FoundInfinite must have only one element."));
std::vector<bool> found_inf_vec;
paddle::framework::TensorToVector(
*found_inf, ctx.device_context(), &found_inf_vec);
LazyZerosNPU<T>{}(dev_ctx, found_inf_vec, xs, outs);
const bool stop_update = ctx.Attr<bool>("stop_update");
if (stop_update) {
return;
}
const auto* pre_loss_scaling =
ctx.Input<phi::DenseTensor>("PrevLossScaling");
const auto* good_in = ctx.Input<phi::DenseTensor>("InGoodSteps");
const auto* bad_in = ctx.Input<phi::DenseTensor>("InBadSteps");
auto* updated_loss_scaling = ctx.Output<phi::DenseTensor>("LossScaling");
auto* good_out = ctx.Output<phi::DenseTensor>("OutGoodSteps");
auto* bad_out = ctx.Output<phi::DenseTensor>("OutBadSteps");
updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace());
good_out->mutable_data<int>(dev_ctx.GetPlace());
bad_out->mutable_data<int>(dev_ctx.GetPlace());
const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
const int decr_every_n_nan_or_inf =
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
UpdateLossScalingFunctor<MPDType>{}(dev_ctx,
found_inf_vec,
pre_loss_scaling,
good_in,
bad_in,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
updated_loss_scaling,
good_out,
bad_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
update_loss_scaling,
ops::UpdateLossScalingNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::UpdateLossScalingNPUKernel<paddle::platform::NPUDeviceContext,
double>);
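For reference, the Update function above follows the standard dynamic loss-scaling policy: on an overflow step, reset the good-step counter, increment the bad-step counter, and once it reaches decr_every_n_nan_or_inf shrink the scale by decr_ratio (clamped below by FLAGS_min_loss_scaling, or 1.0); on a clean step, reset the bad-step counter, increment the good-step counter, and once it reaches incr_every_n_steps grow the scale by incr_ratio, keeping the previous scale if the result is no longer finite. A scalar sketch of that rule, free of any Paddle types and purely illustrative, is:
#include <cmath>
// Hypothetical scalar version of the loss-scaling update performed above.
void UpdateLossScalingRef(bool found_inf,
                          int incr_every_n_steps,
                          int decr_every_n_nan_or_inf,
                          float incr_ratio,
                          float decr_ratio,
                          float min_scale,  // 1.0, or FLAGS_min_loss_scaling if set above 1
                          float* scale,
                          int* good_steps,
                          int* bad_steps) {
  if (found_inf) {
    *good_steps = 0;
    if (++(*bad_steps) >= decr_every_n_nan_or_inf) {
      *scale *= decr_ratio;                        // shrink the loss scale
      if (*scale < min_scale) *scale = min_scale;  // never drop below the floor
      *bad_steps = 0;
    }
  } else {
    *bad_steps = 0;
    if (++(*good_steps) >= incr_every_n_steps) {
      float next = *scale * incr_ratio;            // grow the loss scale
      if (std::isfinite(next)) *scale = next;      // keep the old scale on float overflow
      *good_steps = 0;
    }
  }
}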