未验证 提交 84e813e3 编写于 作者: T taixiurong 提交者: GitHub

[xpu] add dropout & amp ops in xpu place (#33891)

上级 d128c286
......@@ -35,7 +35,7 @@ ELSE ()
ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210625")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210701")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
using XPUTyp = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* scale = ctx.Input<framework::Tensor>("Scale");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const MPDType* scale_data = scale->data<MPDType>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
// cpy to cpu
bool cpu_found_inf_data = false;
MPDType cpu_scale_data;
if (platform::is_xpu_place(scale->place())) {
xpu_memcpy(&cpu_scale_data, scale_data, sizeof(MPDType),
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
} else {
cpu_scale_data = (*scale_data);
}
MPDType inverse_scale = 1.0 / cpu_scale_data;
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(dev_ctx.GetPlace());
framework::Tensor is_finite =
ctx.AllocateTmpTensor<bool, platform::XPUDeviceContext>(x->dims(),
dev_ctx);
framework::Tensor is_nan =
ctx.AllocateTmpTensor<bool, platform::XPUDeviceContext>(x->dims(),
dev_ctx);
framework::Tensor is_finite_and_nan =
ctx.AllocateTmpTensor<bool, platform::XPUDeviceContext>(x->dims(),
dev_ctx);
if (cpu_found_inf_data == false) {
int r = xpu::isfinite(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()),
is_finite.data<bool>(), x->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(isfinite) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::logical_not(dev_ctx.x_context(), reinterpret_cast<const bool*>(
is_finite.data<bool>()),
is_finite.data<bool>(), x->numel());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(logical_not) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::isnan(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()),
is_nan.data<bool>(), x->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(isnan) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::logical_or(dev_ctx.x_context(), is_finite.data<bool>(),
is_nan.data<bool>(), is_finite.data<bool>(),
x->numel());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(logical_or) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::any(dev_ctx.x_context(), is_finite.data<bool>(),
found_inf_data, x->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(any) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
memory::Copy(platform::CPUPlace(), &cpu_found_inf_data,
BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
found_inf_data, sizeof(bool));
}
if (cpu_found_inf_data) {
inverse_scale = 0.0;
}
auto dev_env = XPUEnv::getenv("XPUSIM_DEVICE_MODEL");
if (std::is_same<T, paddle::platform::float16>::value &&
(dev_env == nullptr || std::strcmp(dev_env, "KUNLUN1"))) {
framework::Tensor float_x;
framework::Tensor float_out;
float_x.mutable_data<MPDType>(dev_ctx.GetPlace(),
x->numel() * sizeof(MPDType));
float_out.mutable_data<MPDType>(dev_ctx.GetPlace(),
out->numel() * sizeof(MPDType));
int r = xpu::cast_v2(dev_ctx.x_context(),
reinterpret_cast<const float16*>(x->data<T>()),
float_x.data<MPDType>(), x->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(cast_v2) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(), float_x.data<MPDType>(),
float_out.data<MPDType>(), x->numel(), false,
inverse_scale, 0.0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(scale) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::cast_v2(dev_ctx.x_context(), float_out.data<MPDType>(),
reinterpret_cast<float16*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(cast_v2) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
} else {
int r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()),
reinterpret_cast<XPUTyp*>(out->data<T>()),
x->numel(), false, inverse_scale, 0.0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(scale) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
found_inf_data, platform::CPUPlace(), &cpu_found_inf_data,
sizeof(bool));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleXPUKernel<float>,
ops::CheckFiniteAndUnscaleXPUKernel<plat::float16>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
class UpdateLossScalingXPUKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
using XPUTyp = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must has only one element."));
const bool* found_inf_data = found_inf->data<bool>();
bool cpu_found_inf_data = false;
if (platform::is_xpu_place(found_inf->place())) {
xpu_memcpy(&cpu_found_inf_data, found_inf_data, sizeof(bool),
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
} else {
cpu_found_inf_data = (*found_inf_data);
}
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
if (cpu_found_inf_data) {
VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --";
int r = 0;
r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(out_data), num,
XPUTyp(0.0));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(constant) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
}
const bool stop_update = ctx.Attr<bool>("stop_update");
if (stop_update) {
return;
}
const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
auto* bad_out = ctx.Output<Tensor>("OutBadSteps");
const MPDType* pre_loss_scaling_data = pre_loss_scaling->data<MPDType>();
const int* good_in_data = good_in->data<int>();
const int* bad_in_data = bad_in->data<int>();
MPDType* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace());
int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());
const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
const int decr_every_n_nan_or_inf =
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
int cpu_bad_in_data;
int cpu_good_in_data;
MPDType cpu_pre_loss_scaling_data;
if (platform::is_xpu_place(bad_in->place())) {
xpu_memcpy(&cpu_bad_in_data, bad_in_data, sizeof(int),
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
} else {
cpu_bad_in_data = (*bad_in_data);
}
if (platform::is_xpu_place(good_in->place())) {
xpu_memcpy(&cpu_good_in_data, good_in_data, sizeof(int),
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
} else {
cpu_good_in_data = (*good_in_data);
}
if (platform::is_xpu_place(pre_loss_scaling->place())) {
xpu_memcpy(&cpu_pre_loss_scaling_data, pre_loss_scaling_data,
sizeof(MPDType), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
} else {
cpu_pre_loss_scaling_data = (*pre_loss_scaling_data);
}
int cpu_good_out_data = 0;
int cpu_bad_out_data = 0;
MPDType cpu_updated_loss_scaling_data;
if (cpu_found_inf_data) {
cpu_good_out_data = 0;
cpu_bad_out_data = cpu_bad_in_data + 1;
if (cpu_bad_out_data == decr_every_n_nan_or_inf) {
MPDType new_loss_scaling = cpu_pre_loss_scaling_data * decr_ratio;
cpu_updated_loss_scaling_data =
(new_loss_scaling < static_cast<MPDType>(1))
? (static_cast<MPDType>(1))
: (new_loss_scaling);
cpu_bad_out_data = 0;
}
} else {
cpu_bad_out_data = 0;
cpu_good_out_data = cpu_good_in_data + 1;
if (cpu_good_out_data == incr_every_n_steps) {
MPDType new_loss_scaling = cpu_pre_loss_scaling_data * incr_ratio;
cpu_updated_loss_scaling_data = (std::isfinite(new_loss_scaling))
? new_loss_scaling
: cpu_pre_loss_scaling_data;
cpu_good_out_data = 0;
}
}
// copy to host
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
bad_out_data, platform::CPUPlace(), &cpu_bad_out_data,
sizeof(int));
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
good_out_data, platform::CPUPlace(), &cpu_good_out_data,
sizeof(int));
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
updated_loss_scaling_data, platform::CPUPlace(),
&cpu_updated_loss_scaling_data, sizeof(MPDType));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(update_loss_scaling,
ops::UpdateLossScalingXPUKernel<float>,
ops::UpdateLossScalingXPUKernel<plat::float16>);
#endif
......@@ -16,11 +16,11 @@ namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_XPU
static std::map<int, float*> mask_data_tables;
static const int max_data_size = 32 * 1024 * 1024;
static std::mutex s_mask_data_table_lock;
template <typename DeviceContext, typename T>
class DropoutXPUKernel : public framework::OpKernel<T> {
using XPUTyp = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
......@@ -30,93 +30,70 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
float* mask_data_table = nullptr;
auto& dev_ctx = context.template device_context<DeviceContext>();
PADDLE_ENFORCE_EQ(!context.HasInput("Seed"), true,
platform::errors::InvalidArgument(
("Input(Seed) not supported on XPU")));
int is_upscale = (dropout_implementation == "upscale_in_train");
if (!context.Attr<bool>("is_test")) {
int dev_id =
BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId();
int prop = static_cast<int>(dropout_prob * 100);
int is_upscale = (dropout_implementation == "upscale_in_train");
/* mask_data_tables key contains 3 part:
* | 31-16 | 15-8 | 7-0 |
* | dev_id | prob | is_upscale |
*/
int index = (dev_id << 16) + (prop << 8) + is_upscale;
std::lock_guard<std::mutex> lock(s_mask_data_table_lock);
if (mask_data_tables.find(index) == mask_data_tables.end()) {
float* mask_data_host = new float[max_data_size];
std::random_device rnd;
std::minstd_rand engine;
int seed =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < max_data_size; ++i) {
if (dist(engine) < dropout_prob) {
mask_data_host[i] = 0.0f;
} else {
if (is_upscale) {
mask_data_host[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
} else {
mask_data_host[i] = 1.0;
}
}
}
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&mask_data_table),
max_data_size * sizeof(float)),
XPU_SUCCESS,
platform::errors::ResourceExhausted(
"\n\nOut of memory error on XPU, Cannot"
"allocate %s memory on XPU. \n\nPlease "
"check whether there is any other process "
"using XPU.\n",
string::HumanReadableSize(max_data_size * sizeof(void*))));
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()),
mask_data_table, platform::CPUPlace(), mask_data_host,
max_data_size * sizeof(float));
mask_data_tables[index] = mask_data_table;
free(mask_data_host);
std::random_device rnd;
// int seed = (context.Attr<bool>("fix_seed")) ?
// int(context.Attr<int>("seed")) : (rnd());
int seed = 0;
if (context.Attr<bool>("fix_seed") == true) {
seed = static_cast<int>(context.Attr<int>("seed"));
} else {
mask_data_table = mask_data_tables[index];
seed = rnd();
}
}
if (!context.Attr<bool>("is_test")) { // Train
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
size_t size = framework::product(mask->dims());
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::dropout(dev_ctx.x_context(), mask_data_table, x_data,
mask_data, y_data, max_data_size, size);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"XPU dropout return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
} else { // Infer
float scale = 0.0f;
if (dropout_implementation == "upscale_in_train") {
scale = 1.0f;
} else {
scale = static_cast<T>(1.0f - dropout_prob);
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
int r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(y_data), y->numel(),
XPUTyp(0));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(constant) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(mask_data), mask->numel(),
XPUTyp(0));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(constant) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
return;
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::scale(dev_ctx.x_context(), x->numel(), scale, 0.0f, 0,
x_data, y_data);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"XPU dropout return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
int r = xpu::dropout(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()),
reinterpret_cast<XPUTyp*>(y->data<T>()),
reinterpret_cast<XPUTyp*>(mask_data), seed,
mask->numel(), is_upscale, dropout_prob);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(dropout) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else {
float scale =
(is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
int r = xpu::scale(
dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data),
reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(scale) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
}
};
template <typename DeviceContext, typename T>
class DropoutGradXPUKernel : public framework::OpKernel<T> {
using XPUTyp = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
......@@ -127,23 +104,47 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::elementwise_mul(dev_ctx.x_context(), grad_y->data<T>(),
mask->data<T>(), grad_x->data<T>(),
grad_y->numel());
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"XPU dropout return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
float dropout_prob = context.Attr<float>("dropout_prob");
const T* mask_data = mask->data<T>();
framework::Tensor mask_new;
if (dropout_implementation == "upscale_in_train") {
mask_new = context.AllocateTmpTensor<T, platform::XPUDeviceContext>(
mask->dims(), dev_ctx);
float scale =
(dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob));
int r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(mask->data<T>()),
reinterpret_cast<XPUTyp*>(mask_new.data<T>()),
mask->numel(), false, scale, 0.0f);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(scale) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
mask_data = mask_new.data<T>();
}
int r = xpu::mul(
dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(grad_y->data<T>()),
reinterpret_cast<const XPUTyp*>(mask_data),
reinterpret_cast<XPUTyp*>(grad_x->data<T>()), grad_y->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(mul) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
dropout, ops::DropoutXPUKernel<paddle::platform::XPUDeviceContext, float>);
dropout, ops::DropoutXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::DropoutXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
REGISTER_OP_XPU_KERNEL(
dropout_grad,
ops::DropoutGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::DropoutGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::DropoutGradXPUKernel<paddle::platform::XPUDeviceContext,
plat::float16>);
#endif
......@@ -122,33 +122,50 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
int x_len = 1;
int y_len = 1;
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
x_len *= x_dims_vec[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
x_len *= x_dims_vec[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
y_len *= y_dims_vec[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
y_len *= y_dims_vec[i];
}
}
const T* dz_data = dz->data<T>();
framework::Tensor dx_local_tensor;
framework::Tensor dy_local_tensor;
bool need_wait = false;
T* dx_data = nullptr;
T* dy_data = nullptr;
if (dx) {
dx_data = dx->mutable_data<T>(ctx.GetPlace());
} else {
dx_data =
dx_local_tensor.mutable_data<T>(ctx.GetPlace(), x_len * sizeof(T));
need_wait = true;
}
if (dy) {
dy_data = dy->mutable_data<T>(ctx.GetPlace());
} else {
dy_data =
dy_local_tensor.mutable_data<T>(ctx.GetPlace(), y_len * sizeof(T));
need_wait = true;
}
auto& dev_ctx =
......@@ -161,6 +178,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
if (need_wait && dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
};
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import paddle
import unittest
import numpy as np
from op_test_xpu import XPUOpTest
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
paddle.enable_static()
class TestCheckFiniteAndUnscaleOp(XPUOpTest):
def setUp(self):
self.op_type = "check_finite_and_unscale"
self.init_dtype()
x = np.random.random((1024, 1024)).astype(self.dtype)
scale = np.random.random((1)).astype(self.dtype)
# self.attrs = {'stop_gradient': True}
self.inputs = {'X': [('x0', x)], 'Scale': scale}
self.outputs = {
'FoundInfinite': np.array([0]),
'Out': [('out0', x / scale)],
}
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
# class TestCheckFiniteAndUnscaleOpWithNan(XPUOpTest):
# def setUp(self):
# self.op_type = "check_finite_and_unscale"
# self.init_dtype()
# x = np.random.random((1024, 1024)).astype(self.dtype)
# x[128][128] = np.nan
# print("x shape = ", x.shape)
# print(x)
# scale = np.random.random((1)).astype(self.dtype)
# self.inputs = {'X': [('x0', x)], 'Scale': scale}
# self.outputs = {
# 'FoundInfinite': np.array([1]),
# 'Out': [('out0', x)],
# }
# def init_dtype(self):
# self.dtype = np.float32
# def test_check_output(self):
# # When input contains nan, do not check the output,
# # since the output may be nondeterministic and will be discarded.
# if paddle.is_compiled_with_xpu():
# place = paddle.XPUPlace(0)
# self.check_output_with_place(place, no_check_set=['Out'])
# class TestCheckFiniteAndUnscaleOpWithInf(XPUOpTest):
# def setUp(self):
# self.op_type = "check_finite_and_unscale"
# self.init_dtype()
# x = np.random.random((1024, 1024)).astype(self.dtype)
# x[128][128] = np.inf
# scale = np.random.random((1)).astype(self.dtype)
# self.inputs = {'X': [('x0', x)], 'Scale': scale}
# self.outputs = {
# 'FoundInfinite': np.array([1]),
# 'Out': [('out0', x)],
# }
# def init_dtype(self):
# self.dtype = np.float32
# def test_check_output(self):
# # When input contains inf, do not check the output,
# # since the output may be nondeterministic and will be discarded.
# if paddle.is_compiled_with_xpu():
# place = paddle.XPUPlace(0)
# self.check_output_with_place(place, no_check_set=['Out'])
if __name__ == '__main__':
unittest.main()
......@@ -22,9 +22,11 @@ from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test_xpu import XPUOpTest
paddle.enable_static()
class TestDropoutOp(OpTest):
class TestDropoutOp(XPUOpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
......@@ -47,7 +49,7 @@ class TestDropoutOp(OpTest):
self.check_grad_with_place(place, ['X'], 'Out')
class TestDropoutOpInput1d(OpTest):
class TestDropoutOpInput1d(XPUOpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((2000, )).astype("float32")}
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
sys.path.append("..")
import numpy as np
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
paddle.enable_static()
class TestUpdateLossScalingOp(XPUOpTest):
def setUp(self):
self.op_type = "update_loss_scaling"
self.init()
found_inf = np.array([False], dtype=np.bool)
x = np.random.random((1024, 1024)).astype(self.dtype)
self.inputs = {
'X': [('x0', x)],
'FoundInfinite': found_inf,
'PrevLossScaling': self.prev_loss_scaling,
'InGoodSteps': self.num_good_steps,
'InBadSteps': self.num_bad_steps
}
self.outputs = {
'Out': [('out0', x)],
'LossScaling': self.prev_loss_scaling * self.incr_ratio,
'OutGoodSteps': self.zero_steps,
'OutBadSteps': self.zero_steps
}
def init(self):
self.incr_ratio = 2.0
self.decr_ratio = 0.8
self.dtype = np.float32
self.prev_loss_scaling = np.array([2048]).astype(self.dtype)
self.num_good_steps = np.array([999], dtype=np.int32)
self.num_bad_steps = np.array([1], dtype=np.int32)
self.zero_steps = np.array([0], dtype=np.int32)
self.attrs = {
'incr_every_n_steps': 1000,
'decr_every_n_nan_or_inf': 2,
'incr_ratio': self.incr_ratio,
'decr_ratio': self.decr_ratio,
}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place, no_check_set=['Out'])
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
def setUp(self):
self.op_type = "update_loss_scaling"
self.init()
found_inf = np.array([True], dtype=np.bool)
x = np.random.random((1024, 1024)).astype(self.dtype)
i = np.random.randint(0, 1024, 1)
j = np.random.randint(0, 1024, 1)
x[i[0]][j[0]] = np.inf
self.inputs = {
'X': [('x0', x)],
'FoundInfinite': found_inf,
'PrevLossScaling': self.prev_loss_scaling,
'InGoodSteps': self.num_good_steps,
'InBadSteps': self.num_bad_steps
}
self.outputs = {
'Out': [('out0', np.zeros_like(x))],
'LossScaling': self.prev_loss_scaling * self.decr_ratio,
'OutGoodSteps': self.zero_steps,
'OutBadSteps': self.zero_steps
}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
#self.check_output()
class TestUpdateLossScalingLayer(unittest.TestCase):
def loss_scaling_check(self, scope=fluid.Scope()):
a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
b = fluid.data(name="b", shape=[512, 128], dtype='float32')
x = [a, b]
found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool')
prev_loss_scaling = fluid.data(
name="prev_loss_scaling", shape=[1], dtype='float32')
num_good_steps = fluid.data(
name="num_good_steps", shape=[1], dtype='int32')
num_bad_steps = fluid.data(
name="num_bad_steps", shape=[1], dtype='int32')
a_v = np.random.random([1024, 1024]).astype('float32')
b_v = np.random.random([512, 128]).astype('float32')
found_inf_v = np.array([False]).astype('bool')
prev_loss_scaling_v = np.array([2048]).astype('float32')
num_good_steps_v = np.array([999], dtype=np.int32)
num_bad_steps_v = np.array([1], dtype=np.int32)
incr_every_n_steps = 1000
decr_every_n_nan_or_inf = 2
incr_ratio = 2
decr_ratio = 0.8
result = amp_nn.update_loss_scaling(
x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
name="update_loss_scaling")
place = fluid.XPUPlace(0)
exe = fluid.Executor(place)
with fluid.scope_guard(scope):
exe.run(fluid.default_startup_program())
result_v = exe.run(feed={
'a': a_v,
'b': b_v,
'found_inf': found_inf_v,
'prev_loss_scaling': prev_loss_scaling_v,
'num_good_steps': num_good_steps_v,
'num_bad_steps': num_bad_steps_v
},
fetch_list=[
result, x, found_inf, prev_loss_scaling,
num_good_steps, num_bad_steps
])
assert np.array_equal(result_v[0], a_v)
assert np.array_equal(result_v[1], b_v)
assert np.array_equal(result_v[0], result_v[2])
assert np.array_equal(result_v[1], result_v[3])
assert np.array_equal(result_v[4], found_inf_v)
assert np.array_equal(result_v[5], prev_loss_scaling_v * incr_ratio)
assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v))
assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v))
def loss_scaling_check_inf(self, use_cuda=True, scope=fluid.Scope()):
a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
b = fluid.data(name="b", shape=[512, 128], dtype='float32')
x = [a, b]
found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool')
prev_loss_scaling = fluid.data(
name="prev_loss_scaling", shape=[1], dtype='float32')
num_good_steps = fluid.data(
name="num_good_steps", shape=[1], dtype='int32')
num_bad_steps = fluid.data(
name="num_bad_steps", shape=[1], dtype='int32')
a_v = np.random.random([1024, 1024]).astype('float32')
b_v = np.random.random([512, 128]).astype('float32')
i = np.random.randint(0, 1024, 1)
j = np.random.randint(0, 1024, 1)
a_v[i[0]][j[0]] = np.inf
found_inf_v = np.array([True]).astype('bool')
prev_loss_scaling_v = np.array([2048]).astype('float32')
num_good_steps_v = np.array([999], dtype=np.int32)
num_bad_steps_v = np.array([1], dtype=np.int32)
incr_every_n_steps = 1000
decr_every_n_nan_or_inf = 2
incr_ratio = 2
decr_ratio = 0.8
result = amp_nn.update_loss_scaling(
x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
name="update_loss_scaling")
place = fluid.XPUPlace(0)
exe = fluid.Executor(place)
with fluid.scope_guard(scope):
exe.run(fluid.default_startup_program())
result_v = exe.run(feed={
'a': a_v,
'b': b_v,
'found_inf': found_inf_v,
'prev_loss_scaling': prev_loss_scaling_v,
'num_good_steps': num_good_steps_v,
'num_bad_steps': num_bad_steps_v
},
fetch_list=[
result, x, found_inf, prev_loss_scaling,
num_good_steps, num_bad_steps
])
assert np.array_equal(result_v[0], np.zeros_like(a_v))
assert np.array_equal(result_v[1], np.zeros_like(b_v))
assert np.array_equal(result_v[2], np.zeros_like(a_v))
assert np.array_equal(result_v[3], np.zeros_like(b_v))
assert np.array_equal(result_v[4], found_inf_v)
assert np.array_equal(result_v[5], prev_loss_scaling_v * decr_ratio)
assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v))
assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v))
def test_loss_scaling(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.loss_scaling_check()
def test_loss_scaling_inf(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.loss_scaling_check_inf()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册