未验证 提交 fecbc958 编写于 作者: Q QingshuChen 提交者: GitHub

add some fp16 op for kunlun resnet50 model (#44672)

* add some fp16 op for kunlun resnet50 model
*test=kunlun

* tmp
*test=kunlun
上级 a9919903
...@@ -23,6 +23,8 @@ using Tensor = framework::Tensor; ...@@ -23,6 +23,8 @@ using Tensor = framework::Tensor;
template <typename T> template <typename T>
class ResNetUnitXPUKernel : public framework::OpKernel<T> { class ResNetUnitXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -63,9 +65,12 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -63,9 +65,12 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
std::string act_type = ctx.Attr<std::string>("act_type"); std::string act_type = ctx.Attr<std::string>("act_type");
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
std::vector<const T *> x_list = {input_x->data<T>()}; std::vector<const XPUType *> x_list = {
std::vector<const T *> w_list = {filter_x->data<T>()}; reinterpret_cast<const XPUType *>(input_x->data<T>())};
std::vector<T *> conv_y_list = {conv_out_x->mutable_data<T>(place)}; std::vector<const XPUType *> w_list = {
reinterpret_cast<const XPUType *>(filter_x->data<T>())};
std::vector<XPUType *> conv_y_list = {
reinterpret_cast<XPUType *>(conv_out_x->mutable_data<T>(place))};
std::vector<std::vector<int>> x_shape_list = { std::vector<std::vector<int>> x_shape_list = {
phi::vectorize<int>(input_x->dims())}; phi::vectorize<int>(input_x->dims())};
...@@ -107,9 +112,10 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -107,9 +112,10 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ"); Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ"); Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
x_list.push_back(input_z->data<T>()); x_list.push_back(reinterpret_cast<const XPUType *>(input_z->data<T>()));
w_list.push_back(filter_z->data<T>()); w_list.push_back(reinterpret_cast<const XPUType *>(filter_z->data<T>()));
conv_y_list.push_back(conv_out_z->mutable_data<T>(place)); conv_y_list.push_back(
reinterpret_cast<XPUType *>(conv_out_z->mutable_data<T>(place)));
x_shape_list.push_back(phi::vectorize<int>(input_z->dims())); x_shape_list.push_back(phi::vectorize<int>(input_z->dims()));
...@@ -133,17 +139,17 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -133,17 +139,17 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
if (fuse_add) { if (fuse_add) {
const Tensor *input_z = ctx.Input<Tensor>("Z"); const Tensor *input_z = ctx.Input<Tensor>("Z");
auto input_z_shape = phi::vectorize<int>(input_z->dims()); auto input_z_shape = phi::vectorize<int>(input_z->dims());
x_list.push_back(input_z->data<T>()); x_list.push_back(reinterpret_cast<const XPUType *>(input_z->data<T>()));
x_shape_list.push_back(input_z_shape); x_shape_list.push_back(input_z_shape);
x_maxlist.push_back(nullptr); x_maxlist.push_back(nullptr);
} }
} }
int r = xpu::resnet_unit_fusion<T, T, T, int16_t>( int r = xpu::resnet_unit_fusion<XPUType, XPUType, XPUType, int16_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
x_list, x_list,
w_list, w_list,
conv_y_list, conv_y_list,
output->mutable_data<T>(place), reinterpret_cast<XPUType *>(output->mutable_data<T>(place)),
x_shape_list, x_shape_list,
filter_x_shape[0], filter_x_shape[0],
ksize_list, ksize_list,
...@@ -172,6 +178,8 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -172,6 +178,8 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
template <typename T> template <typename T>
class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -208,11 +216,16 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -208,11 +216,16 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
std::vector<const T *> x_list = {x->data<T>()}; std::vector<const XPUType *> x_list = {
std::vector<const T *> w_list = {filter_x->data<T>()}; reinterpret_cast<const XPUType *>(x->data<T>())};
std::vector<const T *> conv_y_list = {conv_out_x->data<T>()}; std::vector<const XPUType *> w_list = {
std::vector<T *> dx_list = {x_grad->mutable_data<T>(place)}; reinterpret_cast<const XPUType *>(filter_x->data<T>())};
std::vector<T *> dw_list = {filter_x_grad->mutable_data<T>(place)}; std::vector<const XPUType *> conv_y_list = {
reinterpret_cast<const XPUType *>(conv_out_x->data<T>())};
std::vector<XPUType *> dx_list = {
reinterpret_cast<XPUType *>(x_grad->mutable_data<T>(place))};
std::vector<XPUType *> dw_list = {
reinterpret_cast<XPUType *>(filter_x_grad->mutable_data<T>(place))};
std::vector<std::vector<int>> x_shape_list = { std::vector<std::vector<int>> x_shape_list = {
phi::vectorize<int>(x->dims())}; phi::vectorize<int>(x->dims())};
...@@ -262,11 +275,14 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -262,11 +275,14 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
Tensor *scale_z_grad = Tensor *scale_z_grad =
ctx.Output<Tensor>(framework::GradVarName("ScaleZ")); ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ")); Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
x_list.push_back(z->data<T>()); x_list.push_back(reinterpret_cast<const XPUType *>(z->data<T>()));
w_list.push_back(filter_z->data<T>()); w_list.push_back(reinterpret_cast<const XPUType *>(filter_z->data<T>()));
conv_y_list.push_back(conv_out_z->data<T>()); conv_y_list.push_back(
dx_list.push_back(z_grad->mutable_data<T>(place)); reinterpret_cast<const XPUType *>(conv_out_z->data<T>()));
dw_list.push_back(filter_z_grad->mutable_data<T>(place)); dx_list.push_back(
reinterpret_cast<XPUType *>(z_grad->mutable_data<T>(place)));
dw_list.push_back(
reinterpret_cast<XPUType *>(filter_z_grad->mutable_data<T>(place)));
x_shape_list.push_back(phi::vectorize<int>(z->dims())); x_shape_list.push_back(phi::vectorize<int>(z->dims()));
auto filter_z_shape = phi::vectorize<int>(filter_z->dims()); auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
...@@ -288,38 +304,39 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -288,38 +304,39 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
} else { } else {
if (fuse_add) { if (fuse_add) {
auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z")); auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
dx_list.push_back(z_grad->mutable_data<T>(place)); dx_list.push_back(
reinterpret_cast<XPUType *>(z_grad->mutable_data<T>(place)));
} }
} }
int r = int r = xpu::resnet_unit_grad_fusion<XPUType, XPUType, XPUType, int16_t>(
xpu::resnet_unit_grad_fusion<T, T, T, int16_t>(dev_ctx.x_context(), dev_ctx.x_context(),
x_list, x_list,
w_list, w_list,
y_grad->data<T>(), reinterpret_cast<const XPUType *>(y_grad->data<T>()),
output->data<T>(), reinterpret_cast<const XPUType *>(output->data<T>()),
conv_y_list, conv_y_list,
dx_list, dx_list,
dw_list, dw_list,
x_shape_list, x_shape_list,
filter_x_shape[0], filter_x_shape[0],
ksize_list, ksize_list,
stride_list, stride_list,
paddings, paddings,
dilations, dilations,
group, group,
x_maxlist, x_maxlist,
w_maxlist, w_maxlist,
scale_list, scale_list,
batch_mean_list, batch_mean_list,
batch_invstd_list, batch_invstd_list,
dscale_list, dscale_list,
dbias_list, dbias_list,
xpu::Activation_t::RELU, xpu::Activation_t::RELU,
eps, eps,
is_nchw, is_nchw,
has_shortcut, has_shortcut,
fuse_add); fuse_add);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion");
} }
}; };
...@@ -329,5 +346,9 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -329,5 +346,9 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_unit, ops::ResNetUnitXPUKernel<float>); REGISTER_OP_XPU_KERNEL(resnet_unit,
REGISTER_OP_XPU_KERNEL(resnet_unit_grad, ops::ResNetUnitGradXPUKernel<float>); ops::ResNetUnitXPUKernel<plat::float16>,
ops::ResNetUnitXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_unit_grad,
ops::ResNetUnitGradXPUKernel<plat::float16>,
ops::ResNetUnitGradXPUKernel<float>);
...@@ -22,6 +22,8 @@ namespace operators { ...@@ -22,6 +22,8 @@ namespace operators {
template <typename T> template <typename T>
class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
bool multi_precision = ctx.Attr<bool>("multi_precision"); bool multi_precision = ctx.Attr<bool>("multi_precision");
...@@ -35,14 +37,14 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -35,14 +37,14 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam"); auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
auto master_param_out = auto master_param_out =
ctx.MultiOutput<framework::LoDTensor>("MasterParamOut"); ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
T mu = static_cast<T>(ctx.Attr<float>("mu")); float mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff"); float lars_coeff = ctx.Attr<float>("lars_coeff");
T epsilon = ctx.Attr<float>("epsilon"); float epsilon = ctx.Attr<float>("epsilon");
T rescale_grad = ctx.Attr<float>("rescale_grad"); float rescale_grad = ctx.Attr<float>("rescale_grad");
std::vector<T*> param_list; std::vector<XPUType*> param_list;
std::vector<T*> grad_list; std::vector<XPUType*> grad_list;
std::vector<T*> param_out_list; std::vector<XPUType*> param_out_list;
std::vector<float*> velocity_list; std::vector<float*> velocity_list;
std::vector<float*> velocity_out_list; std::vector<float*> velocity_out_list;
std::vector<float*> lrs; std::vector<float*> lrs;
...@@ -52,9 +54,12 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -52,9 +54,12 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
std::vector<float*> master_param_out_list; std::vector<float*> master_param_out_list;
int op_num = param.size(); int op_num = param.size();
for (int i = 0; i < op_num; ++i) { for (int i = 0; i < op_num; ++i) {
param_list.push_back(const_cast<T*>(param[i]->data<T>())); param_list.push_back(
grad_list.push_back(const_cast<T*>(grad[i]->data<T>())); reinterpret_cast<XPUType*>(const_cast<T*>((param[i]->data<T>()))));
param_out_list.push_back(param_out[i]->mutable_data<T>(ctx.GetPlace())); grad_list.push_back(
reinterpret_cast<XPUType*>(const_cast<T*>(grad[i]->data<T>())));
param_out_list.push_back(reinterpret_cast<XPUType*>(
param_out[i]->mutable_data<T>(ctx.GetPlace())));
velocity_list.push_back(const_cast<float*>(velocity[i]->data<float>())); velocity_list.push_back(const_cast<float*>(velocity[i]->data<float>()));
velocity_out_list.push_back( velocity_out_list.push_back(
velocity_out[i]->mutable_data<float>(ctx.GetPlace())); velocity_out[i]->mutable_data<float>(ctx.GetPlace()));
...@@ -111,5 +116,7 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -111,5 +116,7 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(lars_momentum, ops::LarsMomentumOpXPUKernel<float>); REGISTER_OP_XPU_KERNEL(lars_momentum,
ops::LarsMomentumOpXPUKernel<paddle::platform::float16>,
ops::LarsMomentumOpXPUKernel<float>);
#endif #endif
...@@ -231,7 +231,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -231,7 +231,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"generate_proposals_v2", {"generate_proposals_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"grad_add", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"grad_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"greater_equal", {"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
...@@ -254,9 +256,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -254,9 +256,8 @@ XPUOpMap& get_kl2_ops() {
{"label_smooth", {"label_smooth",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lars_momentum", {"lars_momentum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"layer_norm_grad", pOpKernelType(vartype::FP16, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad", {"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
...@@ -380,9 +381,12 @@ XPUOpMap& get_kl2_ops() { ...@@ -380,9 +381,12 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_unit", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"resnet_unit",
XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_unit_grad", {"resnet_unit_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -502,6 +506,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -502,6 +506,9 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"update_loss_scaling",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"unsqueeze2_grad", {"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
......
...@@ -24,13 +24,15 @@ void GradAddXPUKernel(const Context& dev_ctx, ...@@ -24,13 +24,15 @@ void GradAddXPUKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
auto x_shape = phi::vectorize<int>(x.dims()); auto x_shape = phi::vectorize<int>(x.dims());
auto y_shape = phi::vectorize<int>(y.dims()); auto y_shape = phi::vectorize<int>(y.dims());
int r = xpu::broadcast_add(dev_ctx.x_context(), int r = xpu::broadcast_add(dev_ctx.x_context(),
x.data<T>(), reinterpret_cast<const XPUType*>(x.data<T>()),
y.data<T>(), reinterpret_cast<const XPUType*>(y.data<T>()),
out->data<T>(), reinterpret_cast<XPUType*>(out->data<T>()),
x_shape, x_shape,
y_shape); y_shape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
...@@ -38,4 +40,9 @@ void GradAddXPUKernel(const Context& dev_ctx, ...@@ -38,4 +40,9 @@ void GradAddXPUKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(grad_add, XPU, ALL_LAYOUT, phi::GradAddXPUKernel, float) {} PD_REGISTER_KERNEL(grad_add,
XPU,
ALL_LAYOUT,
phi::GradAddXPUKernel,
phi::dtype::float16,
float) {}
...@@ -26,6 +26,7 @@ void LogSoftmaxGradKernel(const Context& dev_ctx, ...@@ -26,6 +26,7 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int axis, int axis,
DenseTensor* x_grad) { DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
const int rank = out.dims().size(); const int rank = out.dims().size();
axis = funcs::CanonicalAxis(axis, rank); axis = funcs::CanonicalAxis(axis, rank);
...@@ -40,24 +41,29 @@ void LogSoftmaxGradKernel(const Context& dev_ctx, ...@@ -40,24 +41,29 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
tmp2_ptr, nullptr, phi::errors::External("no enough memory in xpu")); tmp2_ptr, nullptr, phi::errors::External("no enough memory in xpu"));
int r = int r = xpu::exp<XPUType>(dev_ctx.x_context(),
xpu::exp(dev_ctx.x_context(), out.data<T>(), tmp_ptr, out_grad.numel()); reinterpret_cast<const XPUType*>(out.data<T>()),
reinterpret_cast<XPUType*>(tmp_ptr),
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
r = xpu::reciprocal( r = xpu::reciprocal<XPUType>(dev_ctx.x_context(),
dev_ctx.x_context(), tmp_ptr, tmp2_ptr, out_grad.numel()); reinterpret_cast<const XPUType*>(tmp_ptr),
reinterpret_cast<XPUType*>(tmp2_ptr),
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal");
r = xpu::mul(dev_ctx.x_context(), r = xpu::mul<XPUType>(dev_ctx.x_context(),
tmp2_ptr, reinterpret_cast<const XPUType*>(tmp2_ptr),
out_grad.data<T>(), reinterpret_cast<const XPUType*>(out_grad.data<T>()),
tmp2_ptr, reinterpret_cast<XPUType*>(tmp2_ptr),
out_grad.numel()); out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
r = xpu::softmax_grad(dev_ctx.x_context(), r = xpu::softmax_grad<XPUType>(
tmp_ptr, dev_ctx.x_context(),
tmp2_ptr, reinterpret_cast<const XPUType*>(tmp_ptr),
x_grad->data<T>(), reinterpret_cast<const XPUType*>(tmp2_ptr),
out_shape, reinterpret_cast<XPUType*>(x_grad->data<T>()),
axis); out_shape,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
} }
} }
......
...@@ -25,6 +25,7 @@ void LogSoftmaxKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int axis, int axis,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const int rank = x.dims().size(); const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank); axis = funcs::CanonicalAxis(axis, rank);
...@@ -32,11 +33,16 @@ void LogSoftmaxKernel(const Context& dev_ctx, ...@@ -32,11 +33,16 @@ void LogSoftmaxKernel(const Context& dev_ctx,
auto x_shape = phi::vectorize<int>(x.dims()); auto x_shape = phi::vectorize<int>(x.dims());
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
if (axis < 0) axis += rank; if (axis < 0) axis += rank;
int r = xpu::softmax<T>( int r = xpu::softmax<XPUType>(dev_ctx.x_context(),
dev_ctx.x_context(), x.data<T>(), out->data<T>(), x_shape, axis); reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x_shape,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
r = xpu::log<T>( r = xpu::log<XPUType>(dev_ctx.x_context(),
dev_ctx.x_context(), out->data<T>(), out->data<T>(), out->numel()); reinterpret_cast<const XPUType*>(out->data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "log"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "log");
} }
} }
......
...@@ -23,231 +23,242 @@ import paddle ...@@ -23,231 +23,242 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static() paddle.enable_static()
class TestUpdateLossScalingOp(XPUOpTest): class XPUTestUpdateLossScalingOp(XPUOpTestWrapper):
def setUp(self): def __init__(self):
self.op_type = "update_loss_scaling" self.op_name = "update_loss_scaling"
self.init() self.use_dynamic_create_class = False
found_inf = np.array([False], dtype=np.bool_)
x = np.random.random((1024, 1024)).astype(self.dtype) class TestUpdateLossScalingOp(XPUOpTest):
self.inputs = { def setUp(self):
'X': [('x0', x)], self.op_type = "update_loss_scaling"
'FoundInfinite': found_inf, self.init()
'PrevLossScaling': self.prev_loss_scaling, found_inf = np.array([False], dtype=np.bool_)
'InGoodSteps': self.num_good_steps, x = np.random.random((1024, 1024)).astype(self.dtype)
'InBadSteps': self.num_bad_steps
} self.inputs = {
'X': [('x0', x)],
self.outputs = { 'FoundInfinite': found_inf,
'Out': [('out0', x)], 'PrevLossScaling': self.prev_loss_scaling,
'LossScaling': self.prev_loss_scaling * self.incr_ratio, 'InGoodSteps': self.num_good_steps,
'OutGoodSteps': self.zero_steps, 'InBadSteps': self.num_bad_steps
'OutBadSteps': self.zero_steps }
}
self.outputs = {
def init(self): 'Out': [('out0', x)],
self.incr_ratio = 2.0 'LossScaling': self.prev_loss_scaling * self.incr_ratio,
self.decr_ratio = 0.8 'OutGoodSteps': self.zero_steps,
self.dtype = np.float32 'OutBadSteps': self.zero_steps
self.prev_loss_scaling = np.array([2048]).astype(self.dtype) }
self.num_good_steps = np.array([999], dtype=np.int32)
self.num_bad_steps = np.array([1], dtype=np.int32) def init(self):
self.zero_steps = np.array([0], dtype=np.int32) self.incr_ratio = 2.0
self.attrs = { self.decr_ratio = 0.8
'incr_every_n_steps': 1000, self.dtype = np.float32
'decr_every_n_nan_or_inf': 2, self.prev_loss_scaling = np.array([2048]).astype(self.dtype)
'incr_ratio': self.incr_ratio, self.num_good_steps = np.array([999], dtype=np.int32)
'decr_ratio': self.decr_ratio, self.num_bad_steps = np.array([1], dtype=np.int32)
} self.zero_steps = np.array([0], dtype=np.int32)
self.attrs = {
def test_check_output(self): 'incr_every_n_steps': 1000,
if paddle.is_compiled_with_xpu(): 'decr_every_n_nan_or_inf': 2,
place = paddle.XPUPlace(0) 'incr_ratio': self.incr_ratio,
self.check_output_with_place(place, no_check_set=['Out']) 'decr_ratio': self.decr_ratio,
}
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): def test_check_output(self):
if paddle.is_compiled_with_xpu():
def setUp(self): place = paddle.XPUPlace(0)
self.op_type = "update_loss_scaling" self.check_output_with_place(place, no_check_set=['Out'])
self.init()
found_inf = np.array([True], dtype=np.bool_) class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
x = np.random.random((1024, 1024)).astype(self.dtype)
i = np.random.randint(0, 1024, 1) def setUp(self):
j = np.random.randint(0, 1024, 1) self.op_type = "update_loss_scaling"
x[i[0]][j[0]] = np.inf self.init()
found_inf = np.array([True], dtype=np.bool_)
self.inputs = { x = np.random.random((1024, 1024)).astype(self.dtype)
'X': [('x0', x)], i = np.random.randint(0, 1024, 1)
'FoundInfinite': found_inf, j = np.random.randint(0, 1024, 1)
'PrevLossScaling': self.prev_loss_scaling, x[i[0]][j[0]] = np.inf
'InGoodSteps': self.num_good_steps,
'InBadSteps': self.num_bad_steps self.inputs = {
} 'X': [('x0', x)],
'FoundInfinite': found_inf,
self.outputs = { 'PrevLossScaling': self.prev_loss_scaling,
'Out': [('out0', np.zeros_like(x))], 'InGoodSteps': self.num_good_steps,
'LossScaling': self.prev_loss_scaling * self.decr_ratio, 'InBadSteps': self.num_bad_steps
'OutGoodSteps': self.zero_steps, }
'OutBadSteps': self.zero_steps
} self.outputs = {
'Out': [('out0', np.zeros_like(x))],
def test_check_output(self): 'LossScaling': self.prev_loss_scaling * self.decr_ratio,
if paddle.is_compiled_with_xpu(): 'OutGoodSteps': self.zero_steps,
place = paddle.XPUPlace(0) 'OutBadSteps': self.zero_steps
self.check_output_with_place(place) }
#self.check_output()
def test_check_output(self):
if paddle.is_compiled_with_xpu():
class TestUpdateLossScalingLayer(unittest.TestCase): place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def loss_scaling_check(self, scope=fluid.Scope()): #self.check_output()
a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
b = fluid.data(name="b", shape=[512, 128], dtype='float32') class TestUpdateLossScalingLayer(unittest.TestCase):
x = [a, b]
found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool') def loss_scaling_check(self, scope=fluid.Scope()):
prev_loss_scaling = fluid.data(name="prev_loss_scaling", a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
b = fluid.data(name="b", shape=[512, 128], dtype='float32')
x = [a, b]
found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool')
prev_loss_scaling = fluid.data(name="prev_loss_scaling",
shape=[1],
dtype='float32')
num_good_steps = fluid.data(name="num_good_steps",
shape=[1],
dtype='int32')
num_bad_steps = fluid.data(name="num_bad_steps",
shape=[1], shape=[1],
dtype='float32') dtype='int32')
num_good_steps = fluid.data(name="num_good_steps",
shape=[1], a_v = np.random.random([1024, 1024]).astype('float32')
dtype='int32') b_v = np.random.random([512, 128]).astype('float32')
num_bad_steps = fluid.data(name="num_bad_steps", found_inf_v = np.array([False]).astype('bool')
shape=[1], prev_loss_scaling_v = np.array([2048]).astype('float32')
dtype='int32') num_good_steps_v = np.array([999], dtype=np.int32)
num_bad_steps_v = np.array([1], dtype=np.int32)
a_v = np.random.random([1024, 1024]).astype('float32')
b_v = np.random.random([512, 128]).astype('float32') incr_every_n_steps = 1000
found_inf_v = np.array([False]).astype('bool') decr_every_n_nan_or_inf = 2
prev_loss_scaling_v = np.array([2048]).astype('float32') incr_ratio = 2
num_good_steps_v = np.array([999], dtype=np.int32) decr_ratio = 0.8
num_bad_steps_v = np.array([1], dtype=np.int32)
result = amp_nn.update_loss_scaling(x,
incr_every_n_steps = 1000 found_inf,
decr_every_n_nan_or_inf = 2 prev_loss_scaling,
incr_ratio = 2 num_good_steps,
decr_ratio = 0.8 num_bad_steps,
incr_every_n_steps,
result = amp_nn.update_loss_scaling(x, decr_every_n_nan_or_inf,
found_inf, incr_ratio,
prev_loss_scaling, decr_ratio,
num_good_steps, name="update_loss_scaling")
num_bad_steps,
incr_every_n_steps, place = fluid.XPUPlace(0)
decr_every_n_nan_or_inf, exe = fluid.Executor(place)
incr_ratio, with fluid.scope_guard(scope):
decr_ratio, exe.run(fluid.default_startup_program())
name="update_loss_scaling") result_v = exe.run(feed={
'a': a_v,
place = fluid.XPUPlace(0) 'b': b_v,
exe = fluid.Executor(place) 'found_inf': found_inf_v,
with fluid.scope_guard(scope): 'prev_loss_scaling': prev_loss_scaling_v,
exe.run(fluid.default_startup_program()) 'num_good_steps': num_good_steps_v,
result_v = exe.run(feed={ 'num_bad_steps': num_bad_steps_v
'a': a_v, },
'b': b_v, fetch_list=[
'found_inf': found_inf_v, result, x, found_inf, prev_loss_scaling,
'prev_loss_scaling': prev_loss_scaling_v, num_good_steps, num_bad_steps
'num_good_steps': num_good_steps_v, ])
'num_bad_steps': num_bad_steps_v assert np.array_equal(result_v[0], a_v)
}, assert np.array_equal(result_v[1], b_v)
fetch_list=[ assert np.array_equal(result_v[0], result_v[2])
result, x, found_inf, prev_loss_scaling, assert np.array_equal(result_v[1], result_v[3])
num_good_steps, num_bad_steps assert np.array_equal(result_v[4], found_inf_v)
]) assert np.array_equal(result_v[5], prev_loss_scaling_v * incr_ratio)
assert np.array_equal(result_v[0], a_v) assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v))
assert np.array_equal(result_v[1], b_v) assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v))
assert np.array_equal(result_v[0], result_v[2])
assert np.array_equal(result_v[1], result_v[3]) def loss_scaling_check_inf(self, use_cuda=True, scope=fluid.Scope()):
assert np.array_equal(result_v[4], found_inf_v) a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
assert np.array_equal(result_v[5], prev_loss_scaling_v * incr_ratio) b = fluid.data(name="b", shape=[512, 128], dtype='float32')
assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v)) x = [a, b]
assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v)) found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool')
prev_loss_scaling = fluid.data(name="prev_loss_scaling",
def loss_scaling_check_inf(self, use_cuda=True, scope=fluid.Scope()): shape=[1],
a = fluid.data(name="a", shape=[1024, 1024], dtype='float32') dtype='float32')
b = fluid.data(name="b", shape=[512, 128], dtype='float32') num_good_steps = fluid.data(name="num_good_steps",
x = [a, b] shape=[1],
found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool') dtype='int32')
prev_loss_scaling = fluid.data(name="prev_loss_scaling", num_bad_steps = fluid.data(name="num_bad_steps",
shape=[1], shape=[1],
dtype='float32') dtype='int32')
num_good_steps = fluid.data(name="num_good_steps",
shape=[1], a_v = np.random.random([1024, 1024]).astype('float32')
dtype='int32') b_v = np.random.random([512, 128]).astype('float32')
num_bad_steps = fluid.data(name="num_bad_steps", i = np.random.randint(0, 1024, 1)
shape=[1], j = np.random.randint(0, 1024, 1)
dtype='int32') a_v[i[0]][j[0]] = np.inf
found_inf_v = np.array([True]).astype('bool')
a_v = np.random.random([1024, 1024]).astype('float32') prev_loss_scaling_v = np.array([2048]).astype('float32')
b_v = np.random.random([512, 128]).astype('float32') num_good_steps_v = np.array([999], dtype=np.int32)
i = np.random.randint(0, 1024, 1) num_bad_steps_v = np.array([1], dtype=np.int32)
j = np.random.randint(0, 1024, 1)
a_v[i[0]][j[0]] = np.inf incr_every_n_steps = 1000
found_inf_v = np.array([True]).astype('bool') decr_every_n_nan_or_inf = 2
prev_loss_scaling_v = np.array([2048]).astype('float32') incr_ratio = 2
num_good_steps_v = np.array([999], dtype=np.int32) decr_ratio = 0.8
num_bad_steps_v = np.array([1], dtype=np.int32)
result = amp_nn.update_loss_scaling(x,
incr_every_n_steps = 1000 found_inf,
decr_every_n_nan_or_inf = 2 prev_loss_scaling,
incr_ratio = 2 num_good_steps,
decr_ratio = 0.8 num_bad_steps,
incr_every_n_steps,
result = amp_nn.update_loss_scaling(x, decr_every_n_nan_or_inf,
found_inf, incr_ratio,
prev_loss_scaling, decr_ratio,
num_good_steps, name="update_loss_scaling")
num_bad_steps,
incr_every_n_steps, place = fluid.XPUPlace(0)
decr_every_n_nan_or_inf, exe = fluid.Executor(place)
incr_ratio, with fluid.scope_guard(scope):
decr_ratio, exe.run(fluid.default_startup_program())
name="update_loss_scaling") result_v = exe.run(feed={
'a': a_v,
place = fluid.XPUPlace(0) 'b': b_v,
exe = fluid.Executor(place) 'found_inf': found_inf_v,
with fluid.scope_guard(scope): 'prev_loss_scaling': prev_loss_scaling_v,
exe.run(fluid.default_startup_program()) 'num_good_steps': num_good_steps_v,
result_v = exe.run(feed={ 'num_bad_steps': num_bad_steps_v
'a': a_v, },
'b': b_v, fetch_list=[
'found_inf': found_inf_v, result, x, found_inf, prev_loss_scaling,
'prev_loss_scaling': prev_loss_scaling_v, num_good_steps, num_bad_steps
'num_good_steps': num_good_steps_v, ])
'num_bad_steps': num_bad_steps_v assert np.array_equal(result_v[0], np.zeros_like(a_v))
}, assert np.array_equal(result_v[1], np.zeros_like(b_v))
fetch_list=[ assert np.array_equal(result_v[2], np.zeros_like(a_v))
result, x, found_inf, prev_loss_scaling, assert np.array_equal(result_v[3], np.zeros_like(b_v))
num_good_steps, num_bad_steps assert np.array_equal(result_v[4], found_inf_v)
]) assert np.array_equal(result_v[5], prev_loss_scaling_v * decr_ratio)
assert np.array_equal(result_v[0], np.zeros_like(a_v)) assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v))
assert np.array_equal(result_v[1], np.zeros_like(b_v)) assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v))
assert np.array_equal(result_v[2], np.zeros_like(a_v))
assert np.array_equal(result_v[3], np.zeros_like(b_v)) def test_loss_scaling(self):
assert np.array_equal(result_v[4], found_inf_v) main = fluid.Program()
assert np.array_equal(result_v[5], prev_loss_scaling_v * decr_ratio) startup = fluid.Program()
assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v)) with fluid.unique_name.guard():
assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v)) with fluid.program_guard(main, startup):
self.loss_scaling_check()
def test_loss_scaling(self):
main = fluid.Program() def test_loss_scaling_inf(self):
startup = fluid.Program() main = fluid.Program()
with fluid.unique_name.guard(): startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.unique_name.guard():
self.loss_scaling_check() with fluid.program_guard(main, startup):
self.loss_scaling_check_inf()
def test_loss_scaling_inf(self):
main = fluid.Program()
startup = fluid.Program() support_types = get_xpu_op_support_types('update_loss_scaling')
with fluid.unique_name.guard(): for stype in support_types:
with fluid.program_guard(main, startup): create_test_class(globals(), XPUTestUpdateLossScalingOp, stype)
self.loss_scaling_check_inf()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册