Unverified commit b4f67757, authored by Chen Weihang, committed by GitHub

[Phi] Move amp ops into phi (#45079)

* move check finite and unscale kernel into phi

* move infershape into phi

* move update_loss_scaling kernel into phi

* remove original kernels

* move update loss scaling infershape into phi

* add header for xpu and npu

* solve coverage failed

* fix npu test failed

* remove mutable data in cu file

* fix new executor failed

* add valid check for meta tensor output
Parent 88724a53
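With both kernels moved into phi, the executor no longer needs the fluid OpKernel registry for these ops; it can resolve them through the phi kernel factory, as the data-transfer change below does. A minimal sketch of that lookup (illustrative only, not part of this commit):

// Illustrative sketch: after this change both amp ops are discoverable through
// the phi kernel factory by op type.
#include "paddle/phi/core/kernel_factory.h"

bool AmpOpsHavePhiKernels() {
  auto& factory = phi::KernelFactory::Instance();
  return factory.HasCompatiblePhiKernel("check_finite_and_unscale") &&
         factory.HasCompatiblePhiKernel("update_loss_scaling");
}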
...@@ -127,7 +127,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -127,7 +127,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
const InferShapeContext& ctx_; const InferShapeContext& ctx_;
}; };
static inline void ValidCheck(const phi::MetaTensor& meta_tensor) {
PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
true,
phi::errors::InvalidArgument(
"The current CompatMetaTensor is not initialized."));
}
int64_t CompatMetaTensor::numel() const { int64_t CompatMetaTensor::numel() const {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_); auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->Get<Tensor>().numel(); return var->Get<Tensor>().numel();
...@@ -138,6 +146,7 @@ int64_t CompatMetaTensor::numel() const { ...@@ -138,6 +146,7 @@ int64_t CompatMetaTensor::numel() const {
} }
DDim CompatMetaTensor::dims() const { DDim CompatMetaTensor::dims() const {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_); auto* var = PADDLE_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -162,6 +171,7 @@ DDim CompatMetaTensor::dims() const { ...@@ -162,6 +171,7 @@ DDim CompatMetaTensor::dims() const {
} }
phi::DataType CompatMetaTensor::dtype() const { phi::DataType CompatMetaTensor::dtype() const {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_); auto* var = PADDLE_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -183,6 +193,7 @@ phi::DataType CompatMetaTensor::dtype() const { ...@@ -183,6 +193,7 @@ phi::DataType CompatMetaTensor::dtype() const {
} }
DataLayout CompatMetaTensor::layout() const { DataLayout CompatMetaTensor::layout() const {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_); auto* var = PADDLE_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -206,6 +217,7 @@ DataLayout CompatMetaTensor::layout() const { ...@@ -206,6 +217,7 @@ DataLayout CompatMetaTensor::layout() const {
} }
void CompatMetaTensor::set_dims(const DDim& dims) { void CompatMetaTensor::set_dims(const DDim& dims) {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_); auto* var = PADDLE_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -236,6 +248,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) { ...@@ -236,6 +248,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
} }
void CompatMetaTensor::set_dtype(phi::DataType dtype) { void CompatMetaTensor::set_dtype(phi::DataType dtype) {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_); auto* var = PADDLE_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -258,6 +271,7 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) { ...@@ -258,6 +271,7 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) {
} }
void CompatMetaTensor::set_layout(DataLayout layout) { void CompatMetaTensor::set_layout(DataLayout layout) {
ValidCheck(*this);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_); auto* var = PADDLE_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -281,6 +295,8 @@ void CompatMetaTensor::set_layout(DataLayout layout) { ...@@ -281,6 +295,8 @@ void CompatMetaTensor::set_layout(DataLayout layout) {
} }
void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) { void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(*this);
ValidCheck(meta_tensor);
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_); auto* var = PADDLE_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -299,6 +315,8 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) { ...@@ -299,6 +315,8 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
} }
void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) { void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
ValidCheck(*this);
ValidCheck(meta_tensor);
set_dims(meta_tensor.dims()); set_dims(meta_tensor.dims());
if (is_runtime_) { if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_); auto* var = PADDLE_GET(Variable*, var_);
...@@ -472,6 +490,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -472,6 +490,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(std::string, attr))); phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
break; break;
case framework::proto::AttrType::BOOLEAN:
infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(bool, attr)));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct " "Unsupported cast op attribute `%s` to Scalar when construct "
......
@@ -135,22 +135,15 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
   bool run_phi_kernel = false;

   // check if phi kernel exists
-  auto phi_kernel_map =
-      phi::KernelFactory::Instance().SelectKernelMap(op_with_kernel->Type());
-  if (phi_kernel_map.size() > 0) {
-    auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
-    VLOG(6) << "phi_kernel_key " << phi_kernel_key << "\n";
-    // this function is used to construct data transfer op
-    // we expect that it always has a valid phi kernel
-    // so no need to fallback to cpu kernel
-    PADDLE_ENFORCE_EQ(
-        op_with_kernel->PhiKernel()->IsValid(),
-        true,
-        platform::errors::PreconditionNotMet(
-            "the %s op has no valid phi kernel.", op_with_kernel->Type()));
-    run_phi_kernel = true;
-  }
+  if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
+          op_with_kernel->Type())) {
+    auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
+    VLOG(6) << "phi_kernel_key " << phi_kernel_key << "\n";
+    if (op_with_kernel->PhiKernel()->IsValid()) {
+      run_phi_kernel = true;
+    }
+  }

   // 3. Execute transfer op and construct OpFuncNode
   OpFuncNode new_op_func_node;
......
...@@ -2752,6 +2752,10 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2752,6 +2752,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->EmplaceBackAttr(std::move(phi::Scalar( phi_kernel_context->EmplaceBackAttr(std::move(phi::Scalar(
PADDLE_GET_CONST(std::string, attr_iter->second)))); PADDLE_GET_CONST(std::string, attr_iter->second))));
break; break;
case proto::AttrType::BOOLEAN:
phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(bool, attr_iter->second))));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct " "Unsupported cast op attribute `%s` to Scalar when construct "
......
...@@ -420,6 +420,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -420,6 +420,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(std::string, attr)))); std::move(phi::Scalar(PADDLE_GET_CONST(std::string, attr))));
break; break;
case framework::proto::AttrType::BOOLEAN:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(bool, attr))));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct " "Unsupported cast op attribute `%s` to Scalar when construct "
......
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -25,23 +28,6 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel { ...@@ -25,23 +28,6 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(),
ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(),
ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}
ctx->SetOutputDim("FoundInfinite", {1});
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -91,59 +77,18 @@ Otherwise, FoundInfinite will be 0 (False). ...@@ -91,59 +77,18 @@ Otherwise, FoundInfinite will be 0 (False).
} }
}; };
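The maker documentation ends above ("Otherwise, FoundInfinite will be 0 (False)."); the contract is to multiply every input by 1/Scale and report whether any element is NaN or Inf. A plain C++ reference sketch of that contract (simplified; the actual kernels below also zero or skip outputs once an infinity has been seen):

#include <cmath>
#include <vector>

// Simplified reference of check_finite_and_unscale: unscale every input by
// 1/scale and report whether any element is non-finite. Sketch only; it does
// not reproduce the per-tensor early exit or zero-filling of the real kernels.
bool CheckFiniteAndUnscaleRef(const std::vector<std::vector<float>>& xs,
                              float scale,
                              std::vector<std::vector<float>>* outs) {
  bool found_infinite = false;
  const float inverse_scale = 1.0f / scale;
  outs->assign(xs.size(), {});
  for (size_t i = 0; i < xs.size(); ++i) {
    (*outs)[i].reserve(xs[i].size());
    for (float v : xs[i]) {
      if (!std::isfinite(v)) found_infinite = true;
      (*outs)[i].push_back(v * inverse_scale);
    }
  }
  return found_infinite;  // corresponds to the FoundInfinite output
}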
template <typename T>
class CheckFiniteAndUnscaleCpuKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* scale = ctx.Input<framework::Tensor>("Scale");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const T* scale_data = scale->data<T>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
*found_inf_data = false;
framework::Tensor is_finite =
ctx.AllocateTmpTensor<bool, phi::CPUContext>({1}, dev_ctx);
bool* is_finite_data = is_finite.template data<bool>();
auto& dev = *ctx.template device_context<phi::CPUContext>().eigen_device();
T inverse_scale = Inverse<T>(*scale_data);
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(dev_ctx.GetPlace());
if (!(*found_inf_data)) {
framework::TensorIsfinite(*x, &is_finite);
*found_inf_data = !(*is_finite_data);
}
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*x);
if (!(*found_inf_data)) {
eigen_out.device(dev) = eigen_in * inverse_scale;
} else {
eigen_out.device(dev) = eigen_in * static_cast<T>(0);
}
}
return;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(check_finite_and_unscale,
CheckFiniteAndUnscaleInferShapeFunctor,
PD_INFER_META(phi::CheckFiniteAndUnscaleInferMeta));
 REGISTER_OPERATOR(
     check_finite_and_unscale,
     ops::CheckFiniteAndUnscaleOp,
     ops::CheckFiniteAndUnscaleOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    CheckFiniteAndUnscaleInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleCpuKernel<float>,
ops::CheckFiniteAndUnscaleCpuKernel<double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
*o = Inverse<T>(*s);
*found_inf = false;
}
template <typename T, typename MT>
__global__ void CheckFiniteAndUnscale(const T** xs,
const MT* scale,
int64_t size,
int64_t* starts,
bool* found_inf,
T** outs) {
const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x;
// copy starts array from global memory to shared memory
extern __shared__ int64_t s_starts[];
for (int i = threadIdx.x; i <= size; i += blockDim.x) {
s_starts[i] = starts[i];
}
__syncthreads();
const int64_t num = s_starts[size];
int xs_index = 0;
bool local_found_inf = false;
const MT local_scale = *scale;
for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
// get the "out" index of "id"
// For example:
// idx = 15, starts = [0, 10, 10, 20, 30]
// because 10 <= idx < 20 ==>
// the idx element locate in the 3rd tensor (notice the 2nd tensor size is
// 0)
int next_xs_index = xs_index;
while (idx >= s_starts[next_xs_index]) next_xs_index++;
xs_index = next_xs_index - 1;
// get in data and out data
const T* in = xs[xs_index];
T* out = outs[xs_index];
int64_t in_idx = idx - s_starts[xs_index];
// Unscale
MT val = static_cast<MT>(in[in_idx]) * local_scale;
T narrow_val = static_cast<T>(val);
out[in_idx] = narrow_val;
// CheckFinite
if (!isfinite(narrow_val)) {
local_found_inf = true;
}
}
if (local_found_inf) {
*found_inf = true;
}
}
template <typename T>
class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* scale = ctx.Input<framework::Tensor>("Scale");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const MPDType* scale_data = scale->data<MPDType>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
framework::Tensor inverse_scale =
ctx.AllocateTmpTensor<MPDType, phi::GPUContext>({1}, dev_ctx);
MPDType* inverse_scale_v = inverse_scale.template data<MPDType>();
InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
scale_data, inverse_scale_v, found_inf_data);
size_t xs_size = xs.size();
if (xs_size == 0) return;
const auto& cpu_place = platform::CPUPlace();
// calculate each tensor's start index and copy to device
auto h_starts_tensor =
memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
auto d_starts_tensor =
memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
// the start index value of each tensor is
// the sum of previous tensor's size. For example:
// xs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
h_starts[0] = 0;
for (int i = 1; i <= xs_size; i++) {
h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
}
int64_t total_num = h_starts[xs_size];
memory::Copy(dev_ctx.GetPlace(),
d_starts,
cpu_place,
h_starts,
(xs_size + 1) * sizeof(int64_t),
dev_ctx.stream());
// copy each tensor's data address to device
auto h_mem = memory::Alloc(cpu_place, 2 * xs_size * sizeof(T*));
const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
auto d_mem = memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*));
const T** d_xs = reinterpret_cast<const T**>(d_mem->ptr());
T** d_outs = reinterpret_cast<T**>(d_mem->ptr()) + xs_size;
for (size_t i = 0; i < xs_size; ++i) {
h_xs[i] = xs[i]->data<T>();
h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
}
memory::Copy(dev_ctx.GetPlace(),
d_xs,
cpu_place,
h_xs,
2 * xs_size * sizeof(T*),
dev_ctx.stream());
// Launch Kernel
int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
int elements_per_block =
threads_per_block * 20; // each thread deal with 20 number
int blocks_per_grid =
(total_num + elements_per_block - 1) / elements_per_block;
VLOG(3) << "launch kernel";
CheckFiniteAndUnscale<T, MPDType><<<blocks_per_grid,
threads_per_block,
(xs_size + 1) * sizeof(int64_t),
dev_ctx.stream()>>>(
d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
VLOG(3) << "finish kernel";
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleGpuKernel<float>,
ops::CheckFiniteAndUnscaleGpuKernel<double>,
ops::CheckFiniteAndUnscaleGpuKernel<plat::float16>);
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
......
...@@ -33,7 +33,7 @@ namespace p = paddle::platform; ...@@ -33,7 +33,7 @@ namespace p = paddle::platform;
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
-USE_OP(check_finite_and_unscale);
+USE_OP_ITSELF(check_finite_and_unscale);
USE_OP_DEVICE_KERNEL(check_finite_and_unscale, NPU); USE_OP_DEVICE_KERNEL(check_finite_and_unscale, NPU);
struct InputVars { struct InputVars {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,55 +28,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel { ...@@ -27,55 +28,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"),
"Input",
"FoundInfinite",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"),
"Input",
"PrevLossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InGoodSteps"),
"Input",
"InGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InBadSteps"),
"Input",
"InBadSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("LossScaling"),
"Output",
"LossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"),
"Output",
"OutGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"),
"Output",
"OutBadSteps",
"update_loss_scaling");
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(),
ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(update_loss_scaling), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(),
ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}
ctx->SetOutputDim("LossScaling", {1});
ctx->SetOutputDim("OutGoodSteps", {1});
ctx->SetOutputDim("OutBadSteps", {1});
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -168,72 +120,19 @@ decr_every_n_nan_or_inf steps and each step some gradients are infinite. ...@@ -168,72 +120,19 @@ decr_every_n_nan_or_inf steps and each step some gradients are infinite.
} }
}; };
template <typename T, bool IsFoundInfOnCPU>
class UpdateLossScalingFunctor<phi::CPUContext, T, IsFoundInfOnCPU> {
public:
void operator()(const phi::CPUContext& ctx,
const bool* found_inf_data,
const T* pre_loss_scaling_data,
const int* good_in_data,
const int* bad_in_data,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf,
const float incr_ratio,
const float decr_ratio,
T* updated_loss_scaling_data,
int* good_out_data,
int* bad_out_data) const {
PADDLE_ENFORCE_EQ(
IsFoundInfOnCPU,
true,
platform::errors::InvalidArgument(
"The Input(FoundInfinite) should be on the CPUPlace."));
Update<T>(found_inf_data,
pre_loss_scaling_data,
good_in_data,
bad_in_data,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
updated_loss_scaling_data,
good_out_data,
bad_out_data);
}
};
template <typename T>
class LazyZeros<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
if (*found_inf_data) {
VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --";
std::memset(out_data, 0, num * sizeof(T));
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = phi::CPUContext; using CPU = phi::CPUContext;
DECLARE_INFER_SHAPE_FUNCTOR(update_loss_scaling,
UpdateLossScalingInferShapeFunctor,
PD_INFER_META(phi::UpdateLossScalingInferMeta));
 REGISTER_OPERATOR(
     update_loss_scaling,
     ops::UpdateLossScalingOp,
     ops::UpdateLossScalingOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    UpdateLossScalingInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<CPU, float>,
ops::UpdateLossScalingKernel<CPU, double>);
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
DECLARE_int32(min_loss_scaling); DECLARE_int32(min_loss_scaling);
...@@ -150,9 +150,7 @@ void Update(const platform::NPUDeviceContext& ctx, ...@@ -150,9 +150,7 @@ void Update(const platform::NPUDeviceContext& ctx,
} }
 template <typename T>
-class UpdateLossScalingFunctor<platform::NPUDeviceContext,
-                               T,
-                               /*IsFoundInfOnCPU=*/true> {
+class UpdateLossScalingFunctor {
  public:
void operator()(const platform::NPUDeviceContext& dev_ctx, void operator()(const platform::NPUDeviceContext& dev_ctx,
const std::vector<bool> found_inf_vec, const std::vector<bool> found_inf_vec,
...@@ -270,8 +268,7 @@ class UpdateLossScalingNPUKernel : public framework::OpKernel<T> { ...@@ -270,8 +268,7 @@ class UpdateLossScalingNPUKernel : public framework::OpKernel<T> {
ctx.Attr<int>("decr_every_n_nan_or_inf"); ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio"); const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio"); const float decr_ratio = ctx.Attr<float>("decr_ratio");
-    UpdateLossScalingFunctor<DeviceContext, MPDType, true>{}(
-        dev_ctx,
+    UpdateLossScalingFunctor<MPDType>{}(dev_ctx,
found_inf_vec, found_inf_vec,
pre_loss_scaling, pre_loss_scaling,
good_in, good_in,
......
...@@ -19,12 +19,13 @@ limitations under the License. */ ...@@ -19,12 +19,13 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T> template <typename T>
class UpdateLossScalingXPUKernel : public framework::OpKernel<T> { class UpdateLossScalingXPUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type; using MPDType = typename details::MPTypeTrait<T>::Type;
......
...@@ -25,15 +25,35 @@ limitations under the License. */ ...@@ -25,15 +25,35 @@ limitations under the License. */
namespace phi { namespace phi {
-int64_t MetaTensor::numel() const { return tensor_->numel(); }
-
-DDim MetaTensor::dims() const { return tensor_->dims(); }
-
-DataType MetaTensor::dtype() const { return tensor_->dtype(); }
-
-DataLayout MetaTensor::layout() const { return tensor_->layout(); }
+static inline void ValidCheck(const MetaTensor& meta_tensor) {
+  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The current MetaTensor is not initialized."));
+}
+
+int64_t MetaTensor::numel() const {
+  ValidCheck(*this);
+  return tensor_->numel();
+}
+
+DDim MetaTensor::dims() const {
+  ValidCheck(*this);
+  return tensor_->dims();
+}
+
+DataType MetaTensor::dtype() const {
+  ValidCheck(*this);
+  return tensor_->dtype();
+}
+
+DataLayout MetaTensor::layout() const {
+  ValidCheck(*this);
+  return tensor_->layout();
+}
void MetaTensor::set_dims(const DDim& dims) { void MetaTensor::set_dims(const DDim& dims) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) { if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims = DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
dims; dims;
...@@ -51,6 +71,7 @@ void MetaTensor::set_dims(const DDim& dims) { ...@@ -51,6 +71,7 @@ void MetaTensor::set_dims(const DDim& dims) {
} }
void MetaTensor::set_dtype(DataType dtype) { void MetaTensor::set_dtype(DataType dtype) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) { if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_)) DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
->dtype = dtype; ->dtype = dtype;
...@@ -67,6 +88,7 @@ void MetaTensor::set_dtype(DataType dtype) { ...@@ -67,6 +88,7 @@ void MetaTensor::set_dtype(DataType dtype) {
} }
void MetaTensor::set_layout(DataLayout layout) { void MetaTensor::set_layout(DataLayout layout) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_)) { if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_)) DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
->layout = layout; ->layout = layout;
...@@ -83,6 +105,8 @@ void MetaTensor::set_layout(DataLayout layout) { ...@@ -83,6 +105,8 @@ void MetaTensor::set_layout(DataLayout layout) {
} }
void MetaTensor::share_lod(const MetaTensor& meta_tensor) { void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(*this);
ValidCheck(meta_tensor);
if (meta_tensor.lod().size() == 0) { if (meta_tensor.lod().size() == 0) {
// no need share // no need share
return; return;
...@@ -101,18 +125,8 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) { ...@@ -101,18 +125,8 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
} }
} }
const LoD& MetaTensor::lod() const {
if (phi::DenseTensor::classof(tensor_)) {
return static_cast<DenseTensor*>(tensor_)->lod();
} else if (phi::SelectedRows::classof(tensor_)) {
return static_cast<SelectedRows*>(tensor_)->value().lod();
} else {
PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
tensor_->type_info().name()));
}
}
void MetaTensor::share_meta(const MetaTensor& meta_tensor) { void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
ValidCheck(*this);
if (phi::DenseTensor::classof(tensor_) || if (phi::DenseTensor::classof(tensor_) ||
phi::SelectedRows::classof(tensor_)) { phi::SelectedRows::classof(tensor_)) {
share_dims(meta_tensor); share_dims(meta_tensor);
...@@ -125,9 +139,8 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) { ...@@ -125,9 +139,8 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
} }
} }
TensorBase* MetaTensor::tensor() const { return tensor_; }
void MetaTensor::share_dims(const MetaTensor& meta_tensor) { void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
ValidCheck(*this);
bool is_dense_tensor = phi::DenseTensor::classof(tensor_); bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
bool is_selected_rows = phi::SelectedRows::classof(tensor_); bool is_selected_rows = phi::SelectedRows::classof(tensor_);
if (is_dense_tensor || is_selected_rows) { if (is_dense_tensor || is_selected_rows) {
...@@ -152,4 +165,19 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) { ...@@ -152,4 +165,19 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
bool MetaTensor::initialized() const { return tensor_ != nullptr; } bool MetaTensor::initialized() const { return tensor_ != nullptr; }
// Private Member Methods
const LoD& MetaTensor::lod() const {
if (phi::DenseTensor::classof(tensor_)) {
return static_cast<DenseTensor*>(tensor_)->lod();
} else if (phi::SelectedRows::classof(tensor_)) {
return static_cast<SelectedRows*>(tensor_)->value().lod();
} else {
PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
tensor_->type_info().name()));
}
}
TensorBase* MetaTensor::tensor() const { return tensor_; }
} // namespace phi } // namespace phi
...@@ -764,6 +764,27 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x, ...@@ -764,6 +764,27 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
} }
} }
void CheckFiniteAndUnscaleInferMeta(const std::vector<const MetaTensor*>& xs,
const MetaTensor& scale,
std::vector<MetaTensor*> outs,
MetaTensor* found_infinite) {
PADDLE_ENFORCE_EQ(
xs.size(),
outs.size(),
phi::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
xs.size(),
outs.size()));
for (size_t i = 0; i < xs.size(); ++i) {
outs[i]->set_dims(xs[i]->dims());
outs[i]->set_dtype(xs[i]->dtype());
}
found_infinite->set_dims({1});
found_infinite->set_dtype(DataType::BOOL);
}
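For reference, this InferMeta can be exercised directly on DenseTensors wrapped in MetaTensor, which is essentially what the DECLARE_INFER_SHAPE_FUNCTOR binding in the op file does. A hedged sketch (assumes MetaTensor's pointer-taking constructor; illustrative only):

// Illustrative sketch: derive output metadata for a single input tensor.
phi::DenseTensor x, scale, out, found_inf;
phi::MetaTensor meta_x(&x), meta_scale(&scale), meta_out(&out), meta_found(&found_inf);
phi::CheckFiniteAndUnscaleInferMeta({&meta_x}, meta_scale, {&meta_out}, &meta_found);
// Afterwards out carries x's dims/dtype and found_inf is a {1} BOOL tensor.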
void ConcatInferMeta(const std::vector<const MetaTensor*>& x, void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar, const Scalar& axis_scalar,
MetaTensor* out, MetaTensor* out,
...@@ -1109,6 +1130,102 @@ void GenerateProposalsV2InferMeta(const MetaTensor& scores, ...@@ -1109,6 +1130,102 @@ void GenerateProposalsV2InferMeta(const MetaTensor& scores,
rpn_roi_probs->set_dims(phi::make_ddim({-1, 1})); rpn_roi_probs->set_dims(phi::make_ddim({-1, 1}));
} }
void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& neighbors,
const MetaTensor& count,
const MetaTensor& hashtable_value,
const MetaTensor& hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes) {
auto GraphReindexShapeCheck = [](const phi::DDim& dims,
std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
dims.size()));
}
};
GraphReindexShapeCheck(x.dims(), "X");
GraphReindexShapeCheck(neighbors.dims(), "Neighbors");
GraphReindexShapeCheck(count.dims(), "Count");
if (flag_buffer_hashtable) {
GraphReindexShapeCheck(hashtable_value.dims(), "HashTable_Value");
GraphReindexShapeCheck(hashtable_index.dims(), "HashTable_Index");
}
reindex_src->set_dims({-1});
reindex_src->set_dtype(neighbors.dtype());
reindex_dst->set_dims({-1});
reindex_dst->set_dtype(neighbors.dtype());
out_nodes->set_dims({-1});
out_nodes->set_dtype(x.dtype());
}
void GraphSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& x,
const MetaTensor& eids,
const MetaTensor& perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids) {
// GSN: GraphSampleNeighbors
auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
dims.size()));
}
};
GSNShapeCheck(row.dims(), "Row");
GSNShapeCheck(col_ptr.dims(), "Col_Ptr");
GSNShapeCheck(x.dims(), "X");
if (return_eids) {
GSNShapeCheck(eids.dims(), "Eids");
out_eids->set_dims({-1});
out_eids->set_dtype(row.dtype());
}
if (flag_perm_buffer) {
GSNShapeCheck(perm_buffer.dims(), "Perm_Buffer");
}
out->set_dims({-1});
out->set_dtype(row.dtype());
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
void HierarchicalSigmoidInferMeta(const MetaTensor& x, void HierarchicalSigmoidInferMeta(const MetaTensor& x,
const MetaTensor& w, const MetaTensor& w,
const MetaTensor& label, const MetaTensor& label,
...@@ -2294,6 +2411,34 @@ void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x, ...@@ -2294,6 +2411,34 @@ void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
} }
} }
void UpdateLossScalingInferMeta(const std::vector<const MetaTensor*>& xs,
const MetaTensor& found_infinite,
const MetaTensor& prev_loss_scaling,
const MetaTensor& in_good_steps,
const MetaTensor& in_bad_steps,
std::vector<MetaTensor*> outs,
MetaTensor* loss_scaling,
MetaTensor* out_good_steps,
MetaTensor* out_bad_steps) {
PADDLE_ENFORCE_EQ(xs.size(),
outs.size(),
phi::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(update_loss_scaling), size of input(X) is %d "
"and size of output(Out) is %d.",
xs.size(),
outs.size()));
for (size_t i = 0; i < xs.size(); ++i) {
outs[i]->set_dims(xs[i]->dims());
outs[i]->set_dtype(xs[i]->dtype());
}
loss_scaling->set_dims({1});
out_good_steps->set_dims({1});
out_good_steps->set_dtype(DataType::INT32);
out_bad_steps->set_dims({1});
out_bad_steps->set_dtype(DataType::INT32);
}
void WarpctcInferMeta(const MetaTensor& logits, void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label, const MetaTensor& label,
const MetaTensor& logits_length, const MetaTensor& logits_length,
...@@ -2356,102 +2501,6 @@ void WhereInferMeta(const MetaTensor& condition, ...@@ -2356,102 +2501,6 @@ void WhereInferMeta(const MetaTensor& condition,
out->share_meta(x); out->share_meta(x);
} }
void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& neighbors,
const MetaTensor& count,
const MetaTensor& hashtable_value,
const MetaTensor& hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes) {
auto GraphReindexShapeCheck = [](const phi::DDim& dims,
std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
dims.size()));
}
};
GraphReindexShapeCheck(x.dims(), "X");
GraphReindexShapeCheck(neighbors.dims(), "Neighbors");
GraphReindexShapeCheck(count.dims(), "Count");
if (flag_buffer_hashtable) {
GraphReindexShapeCheck(hashtable_value.dims(), "HashTable_Value");
GraphReindexShapeCheck(hashtable_index.dims(), "HashTable_Index");
}
reindex_src->set_dims({-1});
reindex_src->set_dtype(neighbors.dtype());
reindex_dst->set_dims({-1});
reindex_dst->set_dtype(neighbors.dtype());
out_nodes->set_dims({-1});
out_nodes->set_dtype(x.dtype());
}
void GraphSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& x,
const MetaTensor& eids,
const MetaTensor& perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids) {
// GSN: GraphSampleNeighbors
auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
dims.size()));
}
};
GSNShapeCheck(row.dims(), "Row");
GSNShapeCheck(col_ptr.dims(), "Col_Ptr");
GSNShapeCheck(x.dims(), "X");
if (return_eids) {
GSNShapeCheck(eids.dims(), "Eids");
out_eids->set_dims({-1});
out_eids->set_dtype(row.dtype());
}
if (flag_perm_buffer) {
GSNShapeCheck(perm_buffer.dims(), "Perm_Buffer");
}
out->set_dims({-1});
out->set_dtype(row.dtype());
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
void Yolov3LossInferMeta(const MetaTensor& x, void Yolov3LossInferMeta(const MetaTensor& x,
const MetaTensor& gt_box, const MetaTensor& gt_box,
const MetaTensor& gt_label, const MetaTensor& gt_label,
......
...@@ -196,6 +196,11 @@ void BilinearTensorProductInferMeta(const MetaTensor& x, ...@@ -196,6 +196,11 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x, void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out); std::vector<MetaTensor*> out);
void CheckFiniteAndUnscaleInferMeta(const std::vector<const MetaTensor*>& xs,
const MetaTensor& scale,
std::vector<MetaTensor*> outs,
MetaTensor* found_infinite);
void ConcatInferMeta(const std::vector<const MetaTensor*>& x, void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar, const Scalar& axis_scalar,
MetaTensor* out, MetaTensor* out,
...@@ -237,6 +242,28 @@ void GenerateProposalsV2InferMeta(const MetaTensor& scores, ...@@ -237,6 +242,28 @@ void GenerateProposalsV2InferMeta(const MetaTensor& scores,
MetaTensor* rpn_roi_probs, MetaTensor* rpn_roi_probs,
MetaTensor* rpn_rois_num); MetaTensor* rpn_rois_num);
void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& neighbors,
const MetaTensor& count,
const MetaTensor& hashtable_value,
const MetaTensor& hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes);
void GraphSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& x,
const MetaTensor& eids,
const MetaTensor& perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids);
void HierarchicalSigmoidInferMeta(const MetaTensor& x, void HierarchicalSigmoidInferMeta(const MetaTensor& x,
const MetaTensor& w, const MetaTensor& w,
const MetaTensor& label, const MetaTensor& label,
...@@ -415,6 +442,16 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x, ...@@ -415,6 +442,16 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x,
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x, void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out); std::vector<MetaTensor*> out);
void UpdateLossScalingInferMeta(const std::vector<const MetaTensor*>& xs,
const MetaTensor& found_infinite,
const MetaTensor& prev_loss_scaling,
const MetaTensor& in_good_steps,
const MetaTensor& in_bad_steps,
std::vector<MetaTensor*> outs,
MetaTensor* loss_scaling,
MetaTensor* out_good_steps,
MetaTensor* out_bad_steps);
void WarpctcInferMeta(const MetaTensor& logits, void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label, const MetaTensor& label,
const MetaTensor& logits_length, const MetaTensor& logits_length,
...@@ -429,28 +466,6 @@ void WhereInferMeta(const MetaTensor& condition, ...@@ -429,28 +466,6 @@ void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& y, const MetaTensor& y,
MetaTensor* out); MetaTensor* out);
void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& neighbors,
const MetaTensor& count,
const MetaTensor& hashtable_value,
const MetaTensor& hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes);
void GraphSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& x,
const MetaTensor& eids,
const MetaTensor& perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids);
void Yolov3LossInferMeta(const MetaTensor& x, void Yolov3LossInferMeta(const MetaTensor& x,
const MetaTensor& gt_box, const MetaTensor& gt_box,
const MetaTensor& gt_label, const MetaTensor& gt_label,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CheckFiniteAndUnscaleKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& xs,
const DenseTensor& scale,
std::vector<DenseTensor*> outs,
DenseTensor* found_infinite);
template <typename T, typename Context>
void UpdateLossScalingKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& xs,
const DenseTensor& found_infinite,
const DenseTensor& prev_loss_scaling,
const DenseTensor& in_good_steps,
const DenseTensor& in_bad_steps,
int incr_every_n_steps,
int decr_every_n_nan_or_inf,
float incr_ratio,
float decr_ratio,
const Scalar& stop_update,
std::vector<DenseTensor*> outs,
DenseTensor* loss_scaling,
DenseTensor* out_good_steps,
DenseTensor* out_bad_steps);
} // namespace phi
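These declarations are backend-agnostic; each backend defines and registers them separately (the CPU and GPU registrations appear later in this commit). A hedged sketch of what a registration for another backend would look like, following the same PD_REGISTER_KERNEL pattern; the backend tag here is a placeholder and is not added by this commit:

// Hypothetical sketch only -- mirrors the CPU registration in amp_kernel.cc.
PD_REGISTER_KERNEL(check_finite_and_unscale,
                   XPU,  // placeholder backend tag
                   ALL_LAYOUT,
                   phi::CheckFiniteAndUnscaleKernel,
                   float) {}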
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/amp_kernel.h"
#include <cmath>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/impl/amp_kernel_impl.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace phi {
// Utils
template <typename T, bool IsFoundInfOnCPU>
class UpdateLossScalingFunctor<phi::CPUContext, T, IsFoundInfOnCPU> {
public:
void operator()(const phi::CPUContext& ctx,
const bool* found_inf_data,
const T* pre_loss_scaling_data,
const int* good_in_data,
const int* bad_in_data,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf,
const float incr_ratio,
const float decr_ratio,
T* updated_loss_scaling_data,
int* good_out_data,
int* bad_out_data) const {
PADDLE_ENFORCE_EQ(
IsFoundInfOnCPU,
true,
phi::errors::InvalidArgument(
"The Input(FoundInfinite) should be on the CPUPlace."));
Update<T>(found_inf_data,
pre_loss_scaling_data,
good_in_data,
bad_in_data,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
updated_loss_scaling_data,
good_out_data,
bad_out_data);
}
};
// Kernels
template <typename T, typename Context>
void CheckFiniteAndUnscaleKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& xs,
const DenseTensor& scale,
std::vector<DenseTensor*> outs,
DenseTensor* found_infinite) {
const T* scale_data = scale.data<T>();
bool* found_inf_data = dev_ctx.template Alloc<bool>(found_infinite);
*found_inf_data = false;
DenseTensor is_finite = Empty<bool>(dev_ctx, {1});
bool* is_finite_data = is_finite.template data<bool>();
auto& dev = *dev_ctx.eigen_device();
T inverse_scale = 1.0 / *scale_data;
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
dev_ctx.template Alloc<T>(out);
if (!(*found_inf_data)) {
paddle::framework::TensorIsfinite(*x, &is_finite);
*found_inf_data = !(*is_finite_data);
}
auto eigen_out = EigenVector<T>::Flatten(*out);
auto eigen_in = EigenVector<T>::Flatten(*x);
if (!(*found_inf_data)) {
eigen_out.device(dev) = eigen_in * inverse_scale;
} else {
eigen_out.device(dev) = eigen_in * static_cast<T>(0);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(check_finite_and_unscale,
CPU,
ALL_LAYOUT,
phi::CheckFiniteAndUnscaleKernel,
float,
double) {}
PD_REGISTER_KERNEL(update_loss_scaling,
CPU,
ALL_LAYOUT,
phi::UpdateLossScalingKernel,
float,
double) {}
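The update_loss_scaling kernels registered above implement the usual dynamic loss-scaling policy driven by the attributes incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio and decr_ratio. A plain C++ sketch of that policy, assumed from the op description earlier in this diff (illustrative, not the exact Update<T> implementation):

// Assumed sketch of the dynamic loss-scaling update driven by the op attributes.
void UpdateLossScalingRef(bool found_inf, int incr_every_n_steps,
                          int decr_every_n_nan_or_inf, float incr_ratio,
                          float decr_ratio, float* loss_scaling,
                          int* good_steps, int* bad_steps) {
  if (found_inf) {
    *good_steps = 0;
    if (++(*bad_steps) == decr_every_n_nan_or_inf) {
      *loss_scaling *= decr_ratio;  // shrink after repeated overflow steps
      *bad_steps = 0;
    }
  } else {
    *bad_steps = 0;
    if (++(*good_steps) == incr_every_n_steps) {
      *loss_scaling *= incr_ratio;  // grow after enough clean steps
      *good_steps = 0;
    }
  }
}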
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <vector>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/float16.h"
-
-namespace paddle {
-namespace operators {
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/amp_kernel.h"
+
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/impl/amp_kernel_impl.h"
+
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/memory/memory.h"
+
+namespace phi {
+
+// Utils
+
+template <typename T>
+__global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
+  *o = 1.0 / *s;
+  *found_inf = false;
+}
template <typename T, typename MT>
__global__ void CheckFiniteAndUnscale(const T** xs,
const MT* scale,
int64_t size,
int64_t* starts,
bool* found_inf,
T** outs) {
const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x;
// copy starts array from global memory to shared memory
extern __shared__ int64_t s_starts[];
for (int i = threadIdx.x; i <= size; i += blockDim.x) {
s_starts[i] = starts[i];
}
__syncthreads();
const int64_t num = s_starts[size];
int xs_index = 0;
bool local_found_inf = false;
const MT local_scale = *scale;
for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
// get the "out" index of "id"
// For example:
// idx = 15, starts = [0, 10, 10, 20, 30]
// because 10 <= idx < 20 ==>
// the idx element locate in the 3rd tensor (notice the 2nd tensor size is
// 0)
int next_xs_index = xs_index;
while (idx >= s_starts[next_xs_index]) next_xs_index++;
xs_index = next_xs_index - 1;
// get in data and out data
const T* in = xs[xs_index];
T* out = outs[xs_index];
int64_t in_idx = idx - s_starts[xs_index];
// Unscale
MT val = static_cast<MT>(in[in_idx]) * local_scale;
T narrow_val = static_cast<T>(val);
out[in_idx] = narrow_val;
// CheckFinite
if (!isfinite(narrow_val)) {
local_found_inf = true;
}
}
if (local_found_inf) {
*found_inf = true;
}
}
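The comment inside the loop above describes how a flat element index is mapped back to its owning tensor through the prefix-sum starts array. A small host-side sketch of that mapping under the same example (illustrative only):

#include <cstdint>
#include <vector>

// starts holds exclusive prefix sums of the tensor sizes (starts[0] == 0,
// starts[n] == total element count); the tensor that owns flat element idx is
// the last position whose start value is <= idx, mirroring the kernel's loop.
int OwningTensorIndex(const std::vector<int64_t>& starts, int64_t idx) {
  int next = 0;
  while (idx >= starts[next]) ++next;  // first start strictly greater than idx
  return next - 1;
}
// e.g. starts = {0, 10, 10, 20, 30}: idx 15 maps to tensor index 2 (the 3rd
// tensor), matching the comment above; the 2nd tensor is empty.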
template <typename T, typename FoundNanInfFlagT> template <typename T, typename FoundNanInfFlagT>
__global__ void GpuUpdateLossScaling(const FoundNanInfFlagT found_inf_data, __global__ void GpuUpdateLossScaling(const FoundNanInfFlagT found_inf_data,
...@@ -86,6 +147,73 @@ __global__ void FusedFillIf(T** outs, ...@@ -86,6 +147,73 @@ __global__ void FusedFillIf(T** outs,
} }
} }
template <typename T>
class LazyZeros<phi::GPUContext, T> {
public:
void operator()(const phi::GPUContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const DenseTensor*>& xs,
const std::vector<DenseTensor*>& outs) {
size_t xs_size = xs.size();
if (xs_size == 0) return;
const auto& cpu_place = phi::CPUPlace();
// alloc each tensor's start index and copy to device
auto h_in_starts_mem =
paddle::memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
int64_t* h_starts = reinterpret_cast<int64_t*>(h_in_starts_mem->ptr());
auto d_in_starts_mem =
paddle::memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
int64_t* d_starts = reinterpret_cast<int64_t*>(d_in_starts_mem->ptr());
// the start index value of each tensor is
// the sum of previous tensor's size. For example:
// outs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
h_starts[0] = 0;
for (int i = 0; i < xs_size; i++) {
h_starts[i + 1] = h_starts[i] + outs[i]->numel();
}
paddle::memory::Copy(dev_ctx.GetPlace(),
d_starts,
cpu_place,
h_starts,
(xs_size + 1) * sizeof(int64_t),
dev_ctx.stream());
// copy each tensor of "outs" data address array to device
auto h_out_addrs_mem =
paddle::memory::Alloc(cpu_place, xs_size * sizeof(T*));
T** h_out_addrs = reinterpret_cast<T**>(h_out_addrs_mem->ptr());
auto d_out_addrs_mem = paddle::memory::Alloc(dev_ctx, xs_size * sizeof(T*));
T** d_out_addrs = reinterpret_cast<T**>(d_out_addrs_mem->ptr());
for (size_t i = 0; i < xs_size; ++i) {
h_out_addrs[i] = dev_ctx.Alloc<T>(outs[i]);
}
paddle::memory::Copy(dev_ctx.GetPlace(),
d_out_addrs,
cpu_place,
h_out_addrs,
xs_size * sizeof(T*),
dev_ctx.stream());
// launch cuda kernel
int64_t total_num = h_starts[xs_size];
int64_t threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
int64_t elements_per_block =
threads_per_block * 50; // each thread deal with 50 data
int64_t blocks_per_grid =
(total_num + elements_per_block - 1) / elements_per_block;
FusedFillIf<T><<<blocks_per_grid,
threads_per_block,
(xs_size + 1) * sizeof(int64_t),
dev_ctx.stream()>>>(
d_out_addrs, xs_size, d_starts, static_cast<T>(0), found_inf_data);
}
};
template <typename T, bool IsFoundInfOnCPU> template <typename T, bool IsFoundInfOnCPU>
class UpdateLossScalingFunctor<phi::GPUContext, T, IsFoundInfOnCPU> { class UpdateLossScalingFunctor<phi::GPUContext, T, IsFoundInfOnCPU> {
public: public:
...@@ -131,80 +259,100 @@ class UpdateLossScalingFunctor<phi::GPUContext, T, IsFoundInfOnCPU> { ...@@ -131,80 +259,100 @@ class UpdateLossScalingFunctor<phi::GPUContext, T, IsFoundInfOnCPU> {
} }
}; };
-template <typename T>
-class LazyZeros<phi::GPUContext, T> {
- public:
-  void operator()(const phi::GPUContext& dev_ctx,
-                  const bool* found_inf_data,
-                  const std::vector<const framework::Tensor*>& xs,
-                  const std::vector<framework::Tensor*>& outs) const {
-    size_t xs_size = xs.size();
-    if (xs_size == 0) return;
-
-    const auto& cpu_place = platform::CPUPlace();
-    // alloc each tensor's start index and copy to device
-    auto h_in_starts_mem =
-        memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
-    int64_t* h_starts = reinterpret_cast<int64_t*>(h_in_starts_mem->ptr());
-    auto d_in_starts_mem =
-        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
-    int64_t* d_starts = reinterpret_cast<int64_t*>(d_in_starts_mem->ptr());
-
-    // the start index value of each tensor is
-    // the sum of previous tensor's size. For example:
-    // outs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
-    h_starts[0] = 0;
-    for (int i = 0; i < xs_size; i++) {
-      h_starts[i + 1] = h_starts[i] + outs[i]->numel();
-    }
-    memory::Copy(dev_ctx.GetPlace(),
-                 d_starts,
-                 cpu_place,
-                 h_starts,
-                 (xs_size + 1) * sizeof(int64_t),
-                 dev_ctx.stream());
-
-    // copy each tensor of "outs" data address array to device
-    auto h_out_addrs_mem = memory::Alloc(cpu_place, xs_size * sizeof(T*));
-    T** h_out_addrs = reinterpret_cast<T**>(h_out_addrs_mem->ptr());
-    auto d_out_addrs_mem = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
-    T** d_out_addrs = reinterpret_cast<T**>(d_out_addrs_mem->ptr());
-    for (size_t i = 0; i < xs_size; ++i) {
-      h_out_addrs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
-    }
-    memory::Copy(dev_ctx.GetPlace(),
-                 d_out_addrs,
-                 cpu_place,
-                 h_out_addrs,
-                 xs_size * sizeof(T*),
-                 dev_ctx.stream());
-
-    // launch cuda kernel
-    int64_t total_num = h_starts[xs_size];
-    int64_t threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
-    int64_t elements_per_block =
+// Kernels
+
+template <typename T, typename Context>
+void CheckFiniteAndUnscaleKernel(const Context& dev_ctx,
+                                 const std::vector<const DenseTensor*>& xs,
+                                 const DenseTensor& scale,
+                                 std::vector<DenseTensor*> outs,
+                                 DenseTensor* found_infinite) {
+  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const MPDType* scale_data = scale.data<MPDType>();
+  bool* found_inf_data = dev_ctx.template Alloc<bool>(found_infinite);
+
+  DenseTensor inverse_scale = Empty<MPDType>(dev_ctx, {1});
+  MPDType* inverse_scale_v = inverse_scale.template data<MPDType>();
+  InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
+      scale_data, inverse_scale_v, found_inf_data);
+
+  size_t xs_size = xs.size();
+  if (xs_size == 0) return;
+
+  const auto& cpu_place = phi::CPUPlace();
+  // calculate each tensor's start index and copy to device
+  auto h_starts_tensor =
+      paddle::memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
+  int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
+  auto d_starts_tensor =
+      paddle::memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+  int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
+
+  // the start index value of each tensor is
+  // the sum of previous tensor's size. For example:
+  // x = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
+  h_starts[0] = 0;
+  for (int i = 1; i <= xs_size; i++) {
+    h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
+  }
+  int64_t total_num = h_starts[xs_size];
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       d_starts,
+                       cpu_place,
+                       h_starts,
+                       (xs_size + 1) * sizeof(int64_t),
+                       dev_ctx.stream());
+
+  // copy each tensor's data address to device
+  auto h_mem = paddle::memory::Alloc(cpu_place, 2 * xs_size * sizeof(T*));
+  const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
+  T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
+  auto d_mem = paddle::memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*));
+  const T** d_xs = reinterpret_cast<const T**>(d_mem->ptr());
+  T** d_outs = reinterpret_cast<T**>(d_mem->ptr()) + xs_size;
+  for (size_t i = 0; i < xs_size; ++i) {
+    h_xs[i] = xs[i]->data<T>();
+    h_outs[i] = dev_ctx.template Alloc<T>(outs[i]);
+  }
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       d_xs,
+                       cpu_place,
+                       h_xs,
+                       2 * xs_size * sizeof(T*),
+                       dev_ctx.stream());
+
+  // Launch Kernel
+  int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+  int elements_per_block =
+      threads_per_block * 20;  // each thread deal with 20 number
threads_per_block * 50; // each thread deal with 50 data int blocks_per_grid =
int64_t blocks_per_grid =
(total_num + elements_per_block - 1) / elements_per_block; (total_num + elements_per_block - 1) / elements_per_block;
FusedFillIf<T><<<blocks_per_grid, CheckFiniteAndUnscale<T, MPDType><<<blocks_per_grid,
threads_per_block, threads_per_block,
(xs_size + 1) * sizeof(int64_t), (xs_size + 1) * sizeof(int64_t),
dev_ctx.stream()>>>( dev_ctx.stream()>>>(
d_out_addrs, xs_size, d_starts, static_cast<T>(0), found_inf_data); d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
} }
};
} // namespace operators } // namespace phi
} // namespace paddle
namespace ops = paddle::operators; PD_REGISTER_KERNEL(check_finite_and_unscale,
namespace plat = paddle::platform; GPU,
using GPU = phi::GPUContext; ALL_LAYOUT,
phi::CheckFiniteAndUnscaleKernel,
float,
double,
phi::dtype::float16) {}
REGISTER_OP_CUDA_KERNEL(update_loss_scaling, PD_REGISTER_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<GPU, float>, GPU,
ops::UpdateLossScalingKernel<GPU, double>, ALL_LAYOUT,
ops::UpdateLossScalingKernel<GPU, plat::float16>); phi::UpdateLossScalingKernel,
float,
double,
phi::dtype::float16) {}
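Both launches in this file size their grids the same way: every thread is given a fixed budget of elements (50 in FusedFillIf, 20 in CheckFiniteAndUnscale) and the block count is the ceiling division of the total element count. A small worked example with an illustrative helper name:

#include <algorithm>
#include <cstdint>

inline int64_t BlocksPerGrid(int64_t total_num, int64_t per_thread) {
  if (total_num <= 0) return 0;  // guard for the all-empty case
  const int64_t threads_per_block =
      std::min(static_cast<int64_t>(1024), total_num);
  const int64_t elements_per_block = threads_per_block * per_thread;
  // ceiling division: every element falls into exactly one block
  return (total_num + elements_per_block - 1) / elements_per_block;
}

// e.g. total_num = 30, per_thread = 20: threads_per_block = 30,
// elements_per_block = 600, so a single block covers everything;
// total_num = 1 << 20, per_thread = 20: 52 blocks of 1024 threads.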
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,27 +14,15 @@ ...@@ -14,27 +14,15 @@
#pragma once #pragma once
#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__) #include "paddle/phi/common/amp_type_traits.h"
#include <cuda.h>
#endif // PADDLE_WITH_CUDA && __NVCC__
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/amp_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor; namespace phi {
template <typename T> template <typename T>
inline HOSTDEVICE bool check_finite(T value) { inline HOSTDEVICE bool CheckFinite(T value) {
#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__) #if defined(PADDLE_WITH_CUDA) && defined(__NVCC__)
return isfinite(value); return isfinite(value);
#else #else
...@@ -77,7 +65,7 @@ inline HOSTDEVICE void Update(const FoundInfFlagT found_inf_data, ...@@ -77,7 +65,7 @@ inline HOSTDEVICE void Update(const FoundInfFlagT found_inf_data,
*good_out_data = *good_in_data + 1; *good_out_data = *good_in_data + 1;
if (*good_out_data == incr_every_n_steps) { if (*good_out_data == incr_every_n_steps) {
T new_loss_scaling = *pre_loss_scaling_data * incr_ratio; T new_loss_scaling = *pre_loss_scaling_data * incr_ratio;
*updated_loss_scaling_data = check_finite(new_loss_scaling) *updated_loss_scaling_data = CheckFinite(new_loss_scaling)
? new_loss_scaling ? new_loss_scaling
: *pre_loss_scaling_data; : *pre_loss_scaling_data;
*good_out_data = 0; *good_out_data = 0;
...@@ -85,7 +73,16 @@ inline HOSTDEVICE void Update(const FoundInfFlagT found_inf_data, ...@@ -85,7 +73,16 @@ inline HOSTDEVICE void Update(const FoundInfFlagT found_inf_data,
} }
} }
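For readers skimming the elided parts of Update() above, a compact host-side restatement of the dynamic loss-scaling policy it implements. This is a sketch under the assumption that the hidden found_inf branch mirrors the visible one: shrink after N consecutive bad steps, floor the scale at 1, and reset the opposite counter on every step.

#include <algorithm>
#include <cmath>

struct LossScaleState {
  float scale;
  int good_steps;
  int bad_steps;
};

inline void UpdateLossScaleSketch(bool found_inf,
                                  int incr_every_n_steps,
                                  int decr_every_n_nan_or_inf,
                                  float incr_ratio,
                                  float decr_ratio,
                                  LossScaleState* s) {
  if (found_inf) {
    s->good_steps = 0;
    if (++s->bad_steps == decr_every_n_nan_or_inf) {
      // shrink the scale, but never below 1 (assumed, matching the floor
      // applied by the production kernel)
      s->scale = std::max(1.0f, s->scale * decr_ratio);
      s->bad_steps = 0;
    }
  } else {
    s->bad_steps = 0;
    if (++s->good_steps == incr_every_n_steps) {
      const float grown = s->scale * incr_ratio;
      // only grow if the result is still finite
      if (std::isfinite(grown)) s->scale = grown;
      s->good_steps = 0;
    }
  }
}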
template <typename DeviceContext, typename T, bool IsFoundInfOnCPU> template <typename Context, typename T>
class LazyZeros {
public:
void operator()(const DeviceContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const DenseTensor*>& xs,
const std::vector<DenseTensor*>& outs) const {}
};
template <typename Context, typename T, bool IsFoundInfOnCPU>
class UpdateLossScalingFunctor { class UpdateLossScalingFunctor {
public: public:
void operator()(const DeviceContext& dev_ctx, void operator()(const DeviceContext& dev_ctx,
...@@ -102,84 +99,58 @@ class UpdateLossScalingFunctor { ...@@ -102,84 +99,58 @@ class UpdateLossScalingFunctor {
int* bad_out_data) const; int* bad_out_data) const;
}; };
template <typename DeviceContext, typename T> template <typename T, typename Context>
class LazyZeros { void UpdateLossScalingKernel(const Context& dev_ctx,
public: const std::vector<const DenseTensor*>& xs,
void operator()(const DeviceContext& dev_ctx, const DenseTensor& found_infinite,
const bool* found_inf_data, const DenseTensor& prev_loss_scaling,
const std::vector<const framework::Tensor*>& xs, const DenseTensor& in_good_steps,
const std::vector<framework::Tensor*>& outs) const; const DenseTensor& in_bad_steps,
}; int incr_every_n_steps,
int decr_every_n_nan_or_inf,
template <typename DeviceContext, typename T> float incr_ratio,
class UpdateLossScalingKernel : public framework::OpKernel<T> { float decr_ratio,
using MPDType = typename details::MPTypeTrait<T>::Type; const Scalar& stop_update,
std::vector<DenseTensor*> outs,
public: DenseTensor* loss_scaling,
void Compute(const framework::ExecutionContext& ctx) const override { DenseTensor* out_good_steps,
auto& dev_ctx = ctx.template device_context<DeviceContext>(); DenseTensor* out_bad_steps) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
const auto xs = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out"); PADDLE_ENFORCE_EQ(
const auto* found_inf = ctx.Input<Tensor>("FoundInfinite"); found_infinite.numel(),
PADDLE_ENFORCE_EQ(found_inf->numel(),
1, 1,
platform::errors::InvalidArgument( phi::errors::InvalidArgument("FoundInfinite must have only one element."));
"FoundInfinite must have only one element.")); const bool* found_inf_data = found_infinite.data<bool>();
const bool* found_inf_data = found_inf->data<bool>(); bool is_found_inf_on_cpu =
bool is_found_inf_on_cpu = platform::is_cpu_place(found_inf->place()); found_infinite.place().GetType() == AllocationType::CPU;
if (is_found_inf_on_cpu) { if (is_found_inf_on_cpu) {
if (*found_inf_data) { if (*found_inf_data) {
phi::funcs::SetConstant<DeviceContext, T> set_constant;
for (auto* out : outs) { for (auto* out : outs) {
out->mutable_data<T>(dev_ctx.GetPlace()); Full<T>(dev_ctx, vectorize(out->dims()), static_cast<T>(0), out);
set_constant(dev_ctx, out, static_cast<T>(0));
} }
} }
} else { } else {
LazyZeros<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs); LazyZeros<Context, T>{}(dev_ctx, found_inf_data, xs, outs);
} }
const auto* stop_update_tensor = ctx.Input<Tensor>("StopUpdate"); auto stop_update_val = stop_update.to<bool>();
bool stop_update = false; if (stop_update_val) {
if (stop_update_tensor && stop_update_tensor->IsInitialized()) {
if (platform::is_cpu_place(stop_update_tensor->place())) {
stop_update = stop_update_tensor->data<bool>()[0];
} else {
framework::Tensor tmp_tensor;
framework::TensorCopySync(
*stop_update_tensor, platform::CPUPlace(), &tmp_tensor);
stop_update = tmp_tensor.data<bool>()[0];
}
}
stop_update |= ctx.Attr<bool>("stop_update");
if (stop_update) {
return; return;
} }
const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling"); const MPDType* pre_loss_scaling_data = prev_loss_scaling.data<MPDType>();
const auto* good_in = ctx.Input<Tensor>("InGoodSteps"); const int* good_in_data = in_good_steps.data<int>();
const auto* bad_in = ctx.Input<Tensor>("InBadSteps"); const int* bad_in_data = in_bad_steps.data<int>();
auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
auto* bad_out = ctx.Output<Tensor>("OutBadSteps");
const MPDType* pre_loss_scaling_data = pre_loss_scaling->data<MPDType>();
const int* good_in_data = good_in->data<int>();
const int* bad_in_data = bad_in->data<int>();
MPDType* updated_loss_scaling_data = MPDType* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace()); dev_ctx.template Alloc<MPDType>(loss_scaling);
int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace()); int* good_out_data = dev_ctx.template Alloc<int>(out_good_steps);
int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace()); int* bad_out_data = dev_ctx.template Alloc<int>(out_bad_steps);
const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
const int decr_every_n_nan_or_inf =
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
if (is_found_inf_on_cpu) { if (is_found_inf_on_cpu) {
UpdateLossScalingFunctor<DeviceContext, MPDType, true>{}( UpdateLossScalingFunctor<Context, MPDType, true>{}(
dev_ctx, dev_ctx,
found_inf_data, found_inf_data,
pre_loss_scaling_data, pre_loss_scaling_data,
...@@ -193,7 +164,7 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> { ...@@ -193,7 +164,7 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> {
good_out_data, good_out_data,
bad_out_data); bad_out_data);
} else { } else {
UpdateLossScalingFunctor<DeviceContext, MPDType, false>{}( UpdateLossScalingFunctor<Context, MPDType, false>{}(
dev_ctx, dev_ctx,
found_inf_data, found_inf_data,
pre_loss_scaling_data, pre_loss_scaling_data,
...@@ -207,8 +178,6 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> { ...@@ -207,8 +178,6 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> {
good_out_data, good_out_data,
bad_out_data); bad_out_data);
} }
} }
};
} // namespace operators } // namespace phi
} // namespace paddle
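The MPDType alias used throughout the new kernel keeps the loss-scaling state in the master precision of T. A quick compile-time illustration of the assumed trait behavior; this is a sanity check, not code from this PR:

#include <type_traits>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"

static_assert(
    std::is_same<phi::dtype::MPTypeTrait<float>::Type, float>::value,
    "float keeps its loss-scaling state in float");
static_assert(
    std::is_same<phi::dtype::MPTypeTrait<phi::dtype::float16>::Type,
                 float>::value,
    "float16 keeps its loss-scaling state in float (assumed trait mapping)");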
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
...@@ -12,21 +12,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #include "paddle/phi/core/compat/op_utils.h"
#include <string> namespace phi {
#include <vector>
KernelSignature UpdateLossScalingOpArgumentMapping(
#include "paddle/fluid/operators/isfinite_op.h" const ArgumentMappingContext& ctx) {
#include "paddle/phi/core/hostdevice.h" if (ctx.HasInput("StopUpdate")) {
return KernelSignature(
namespace paddle { "update_loss_scaling",
namespace operators { {"X", "FoundInfinite", "PrevLossScaling", "InGoodSteps", "InBadSteps"},
{"incr_every_n_steps",
template <typename T> "decr_every_n_nan_or_inf",
inline HOSTDEVICE T Inverse(T s) { "incr_ratio",
return 1.0 / s; "decr_ratio",
"StopUpdate"},
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"});
} else {
return KernelSignature(
"update_loss_scaling",
{"X", "FoundInfinite", "PrevLossScaling", "InGoodSteps", "InBadSteps"},
{"incr_every_n_steps",
"decr_every_n_nan_or_inf",
"incr_ratio",
"decr_ratio",
"stop_update"},
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"});
}
} }
} // namespace operators } // namespace phi
} // namespace paddle
PD_REGISTER_ARG_MAPPING_FN(update_loss_scaling,
phi::UpdateLossScalingOpArgumentMapping);
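The mapping above picks the StopUpdate input when the op provides one and otherwise falls back to the stop_update attribute, since a KernelSignature attribute slot can be fed from either source. A hypothetical mapping for a made-up op, following the same pattern; none of these names exist in Paddle:

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature MyScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
  // Prefer the runtime tensor input when the graph supplies one ...
  if (ctx.HasInput("ScaleTensor")) {
    return KernelSignature("my_scale", {"X"}, {"ScaleTensor"}, {"Out"});
  }
  // ... otherwise read the compile-time attribute.
  return KernelSignature("my_scale", {"X"}, {"scale"}, {"Out"});
}

}  // namespace phi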
...@@ -50,6 +50,7 @@ class TestUpdateLossScalingOp(OpTest): ...@@ -50,6 +50,7 @@ class TestUpdateLossScalingOp(OpTest):
self.num_good_steps = np.array([999], dtype=np.int32) self.num_good_steps = np.array([999], dtype=np.int32)
self.num_bad_steps = np.array([1], dtype=np.int32) self.num_bad_steps = np.array([1], dtype=np.int32)
self.zero_steps = np.array([0], dtype=np.int32) self.zero_steps = np.array([0], dtype=np.int32)
self.stop_update = np.array([False], dtype=np.bool)
self.attrs = { self.attrs = {
'incr_every_n_steps': 1000, 'incr_every_n_steps': 1000,
'decr_every_n_nan_or_inf': 2, 'decr_every_n_nan_or_inf': 2,
...@@ -77,7 +78,8 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): ...@@ -77,7 +78,8 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
'FoundInfinite': found_inf, 'FoundInfinite': found_inf,
'PrevLossScaling': self.prev_loss_scaling, 'PrevLossScaling': self.prev_loss_scaling,
'InGoodSteps': self.num_good_steps, 'InGoodSteps': self.num_good_steps,
'InBadSteps': self.num_bad_steps 'InBadSteps': self.num_bad_steps,
'StopUpdate': self.stop_update
} }
self.outputs = { self.outputs = {
......
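The Bad variant above feeds found_inf = True with one prior bad step and decr_every_n_nan_or_inf = 2, so its expected outputs follow directly from the update rule: both counters reset and the scale shrinks by decr_ratio. A tiny self-contained check with illustrative numbers; the test's actual ratios live in the truncated attrs:

#include <cassert>

int main() {
  const int in_bad_steps = 1;
  const int decr_every_n_nan_or_inf = 2;
  const float prev_loss_scaling = 32770.0f;  // illustrative value only
  const float decr_ratio = 0.5f;             // illustrative value only

  int bad_out = in_bad_steps + 1;  // found_inf == true, so bad steps grow
  int good_out = 0;                // and the good-step streak resets
  float loss_scaling = prev_loss_scaling;
  if (bad_out == decr_every_n_nan_or_inf) {
    loss_scaling = prev_loss_scaling * decr_ratio;
    bad_out = 0;
  }
  assert(good_out == 0 && bad_out == 0 && loss_scaling == 16385.0f);
  return 0;
}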