未验证 提交 fecbc958 编写于 作者: Q QingshuChen 提交者: GitHub

add some fp16 op for kunlun resnet50 model (#44672)

* add some fp16 op for kunlun resnet50 model
*test=kunlun

* tmp
*test=kunlun
上级 a9919903
...@@ -23,6 +23,8 @@ using Tensor = framework::Tensor; ...@@ -23,6 +23,8 @@ using Tensor = framework::Tensor;
template <typename T> template <typename T>
class ResNetUnitXPUKernel : public framework::OpKernel<T> { class ResNetUnitXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -63,9 +65,12 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -63,9 +65,12 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
std::string act_type = ctx.Attr<std::string>("act_type"); std::string act_type = ctx.Attr<std::string>("act_type");
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
std::vector<const T *> x_list = {input_x->data<T>()}; std::vector<const XPUType *> x_list = {
std::vector<const T *> w_list = {filter_x->data<T>()}; reinterpret_cast<const XPUType *>(input_x->data<T>())};
std::vector<T *> conv_y_list = {conv_out_x->mutable_data<T>(place)}; std::vector<const XPUType *> w_list = {
reinterpret_cast<const XPUType *>(filter_x->data<T>())};
std::vector<XPUType *> conv_y_list = {
reinterpret_cast<XPUType *>(conv_out_x->mutable_data<T>(place))};
std::vector<std::vector<int>> x_shape_list = { std::vector<std::vector<int>> x_shape_list = {
phi::vectorize<int>(input_x->dims())}; phi::vectorize<int>(input_x->dims())};
...@@ -107,9 +112,10 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -107,9 +112,10 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ"); Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ"); Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
x_list.push_back(input_z->data<T>()); x_list.push_back(reinterpret_cast<const XPUType *>(input_z->data<T>()));
w_list.push_back(filter_z->data<T>()); w_list.push_back(reinterpret_cast<const XPUType *>(filter_z->data<T>()));
conv_y_list.push_back(conv_out_z->mutable_data<T>(place)); conv_y_list.push_back(
reinterpret_cast<XPUType *>(conv_out_z->mutable_data<T>(place)));
x_shape_list.push_back(phi::vectorize<int>(input_z->dims())); x_shape_list.push_back(phi::vectorize<int>(input_z->dims()));
...@@ -133,17 +139,17 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -133,17 +139,17 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
if (fuse_add) { if (fuse_add) {
const Tensor *input_z = ctx.Input<Tensor>("Z"); const Tensor *input_z = ctx.Input<Tensor>("Z");
auto input_z_shape = phi::vectorize<int>(input_z->dims()); auto input_z_shape = phi::vectorize<int>(input_z->dims());
x_list.push_back(input_z->data<T>()); x_list.push_back(reinterpret_cast<const XPUType *>(input_z->data<T>()));
x_shape_list.push_back(input_z_shape); x_shape_list.push_back(input_z_shape);
x_maxlist.push_back(nullptr); x_maxlist.push_back(nullptr);
} }
} }
int r = xpu::resnet_unit_fusion<T, T, T, int16_t>( int r = xpu::resnet_unit_fusion<XPUType, XPUType, XPUType, int16_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
x_list, x_list,
w_list, w_list,
conv_y_list, conv_y_list,
output->mutable_data<T>(place), reinterpret_cast<XPUType *>(output->mutable_data<T>(place)),
x_shape_list, x_shape_list,
filter_x_shape[0], filter_x_shape[0],
ksize_list, ksize_list,
...@@ -172,6 +178,8 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> { ...@@ -172,6 +178,8 @@ class ResNetUnitXPUKernel : public framework::OpKernel<T> {
template <typename T> template <typename T>
class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -208,11 +216,16 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -208,11 +216,16 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
std::vector<const T *> x_list = {x->data<T>()}; std::vector<const XPUType *> x_list = {
std::vector<const T *> w_list = {filter_x->data<T>()}; reinterpret_cast<const XPUType *>(x->data<T>())};
std::vector<const T *> conv_y_list = {conv_out_x->data<T>()}; std::vector<const XPUType *> w_list = {
std::vector<T *> dx_list = {x_grad->mutable_data<T>(place)}; reinterpret_cast<const XPUType *>(filter_x->data<T>())};
std::vector<T *> dw_list = {filter_x_grad->mutable_data<T>(place)}; std::vector<const XPUType *> conv_y_list = {
reinterpret_cast<const XPUType *>(conv_out_x->data<T>())};
std::vector<XPUType *> dx_list = {
reinterpret_cast<XPUType *>(x_grad->mutable_data<T>(place))};
std::vector<XPUType *> dw_list = {
reinterpret_cast<XPUType *>(filter_x_grad->mutable_data<T>(place))};
std::vector<std::vector<int>> x_shape_list = { std::vector<std::vector<int>> x_shape_list = {
phi::vectorize<int>(x->dims())}; phi::vectorize<int>(x->dims())};
...@@ -262,11 +275,14 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -262,11 +275,14 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
Tensor *scale_z_grad = Tensor *scale_z_grad =
ctx.Output<Tensor>(framework::GradVarName("ScaleZ")); ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ")); Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
x_list.push_back(z->data<T>()); x_list.push_back(reinterpret_cast<const XPUType *>(z->data<T>()));
w_list.push_back(filter_z->data<T>()); w_list.push_back(reinterpret_cast<const XPUType *>(filter_z->data<T>()));
conv_y_list.push_back(conv_out_z->data<T>()); conv_y_list.push_back(
dx_list.push_back(z_grad->mutable_data<T>(place)); reinterpret_cast<const XPUType *>(conv_out_z->data<T>()));
dw_list.push_back(filter_z_grad->mutable_data<T>(place)); dx_list.push_back(
reinterpret_cast<XPUType *>(z_grad->mutable_data<T>(place)));
dw_list.push_back(
reinterpret_cast<XPUType *>(filter_z_grad->mutable_data<T>(place)));
x_shape_list.push_back(phi::vectorize<int>(z->dims())); x_shape_list.push_back(phi::vectorize<int>(z->dims()));
auto filter_z_shape = phi::vectorize<int>(filter_z->dims()); auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
...@@ -288,16 +304,17 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -288,16 +304,17 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
} else { } else {
if (fuse_add) { if (fuse_add) {
auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z")); auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
dx_list.push_back(z_grad->mutable_data<T>(place)); dx_list.push_back(
reinterpret_cast<XPUType *>(z_grad->mutable_data<T>(place)));
} }
} }
int r = int r = xpu::resnet_unit_grad_fusion<XPUType, XPUType, XPUType, int16_t>(
xpu::resnet_unit_grad_fusion<T, T, T, int16_t>(dev_ctx.x_context(), dev_ctx.x_context(),
x_list, x_list,
w_list, w_list,
y_grad->data<T>(), reinterpret_cast<const XPUType *>(y_grad->data<T>()),
output->data<T>(), reinterpret_cast<const XPUType *>(output->data<T>()),
conv_y_list, conv_y_list,
dx_list, dx_list,
dw_list, dw_list,
...@@ -329,5 +346,9 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> { ...@@ -329,5 +346,9 @@ class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(resnet_unit, ops::ResNetUnitXPUKernel<float>); REGISTER_OP_XPU_KERNEL(resnet_unit,
REGISTER_OP_XPU_KERNEL(resnet_unit_grad, ops::ResNetUnitGradXPUKernel<float>); ops::ResNetUnitXPUKernel<plat::float16>,
ops::ResNetUnitXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_unit_grad,
ops::ResNetUnitGradXPUKernel<plat::float16>,
ops::ResNetUnitGradXPUKernel<float>);
...@@ -22,6 +22,8 @@ namespace operators { ...@@ -22,6 +22,8 @@ namespace operators {
template <typename T> template <typename T>
class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
bool multi_precision = ctx.Attr<bool>("multi_precision"); bool multi_precision = ctx.Attr<bool>("multi_precision");
...@@ -35,14 +37,14 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -35,14 +37,14 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam"); auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
auto master_param_out = auto master_param_out =
ctx.MultiOutput<framework::LoDTensor>("MasterParamOut"); ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
T mu = static_cast<T>(ctx.Attr<float>("mu")); float mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff"); float lars_coeff = ctx.Attr<float>("lars_coeff");
T epsilon = ctx.Attr<float>("epsilon"); float epsilon = ctx.Attr<float>("epsilon");
T rescale_grad = ctx.Attr<float>("rescale_grad"); float rescale_grad = ctx.Attr<float>("rescale_grad");
std::vector<T*> param_list; std::vector<XPUType*> param_list;
std::vector<T*> grad_list; std::vector<XPUType*> grad_list;
std::vector<T*> param_out_list; std::vector<XPUType*> param_out_list;
std::vector<float*> velocity_list; std::vector<float*> velocity_list;
std::vector<float*> velocity_out_list; std::vector<float*> velocity_out_list;
std::vector<float*> lrs; std::vector<float*> lrs;
...@@ -52,9 +54,12 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -52,9 +54,12 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
std::vector<float*> master_param_out_list; std::vector<float*> master_param_out_list;
int op_num = param.size(); int op_num = param.size();
for (int i = 0; i < op_num; ++i) { for (int i = 0; i < op_num; ++i) {
param_list.push_back(const_cast<T*>(param[i]->data<T>())); param_list.push_back(
grad_list.push_back(const_cast<T*>(grad[i]->data<T>())); reinterpret_cast<XPUType*>(const_cast<T*>((param[i]->data<T>()))));
param_out_list.push_back(param_out[i]->mutable_data<T>(ctx.GetPlace())); grad_list.push_back(
reinterpret_cast<XPUType*>(const_cast<T*>(grad[i]->data<T>())));
param_out_list.push_back(reinterpret_cast<XPUType*>(
param_out[i]->mutable_data<T>(ctx.GetPlace())));
velocity_list.push_back(const_cast<float*>(velocity[i]->data<float>())); velocity_list.push_back(const_cast<float*>(velocity[i]->data<float>()));
velocity_out_list.push_back( velocity_out_list.push_back(
velocity_out[i]->mutable_data<float>(ctx.GetPlace())); velocity_out[i]->mutable_data<float>(ctx.GetPlace()));
...@@ -111,5 +116,7 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> { ...@@ -111,5 +116,7 @@ class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(lars_momentum, ops::LarsMomentumOpXPUKernel<float>); REGISTER_OP_XPU_KERNEL(lars_momentum,
ops::LarsMomentumOpXPUKernel<paddle::platform::float16>,
ops::LarsMomentumOpXPUKernel<float>);
#endif #endif
...@@ -231,7 +231,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -231,7 +231,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"generate_proposals_v2", {"generate_proposals_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"grad_add", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"grad_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"greater_equal", {"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
...@@ -254,9 +256,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -254,9 +256,8 @@ XPUOpMap& get_kl2_ops() {
{"label_smooth", {"label_smooth",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lars_momentum", {"lars_momentum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"layer_norm_grad", pOpKernelType(vartype::FP16, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad", {"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
...@@ -380,9 +381,12 @@ XPUOpMap& get_kl2_ops() { ...@@ -380,9 +381,12 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_unit", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"resnet_unit",
XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_unit_grad", {"resnet_unit_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -502,6 +506,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -502,6 +506,9 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"update_loss_scaling",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"unsqueeze2_grad", {"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
......
...@@ -24,13 +24,15 @@ void GradAddXPUKernel(const Context& dev_ctx, ...@@ -24,13 +24,15 @@ void GradAddXPUKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
auto x_shape = phi::vectorize<int>(x.dims()); auto x_shape = phi::vectorize<int>(x.dims());
auto y_shape = phi::vectorize<int>(y.dims()); auto y_shape = phi::vectorize<int>(y.dims());
int r = xpu::broadcast_add(dev_ctx.x_context(), int r = xpu::broadcast_add(dev_ctx.x_context(),
x.data<T>(), reinterpret_cast<const XPUType*>(x.data<T>()),
y.data<T>(), reinterpret_cast<const XPUType*>(y.data<T>()),
out->data<T>(), reinterpret_cast<XPUType*>(out->data<T>()),
x_shape, x_shape,
y_shape); y_shape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
...@@ -38,4 +40,9 @@ void GradAddXPUKernel(const Context& dev_ctx, ...@@ -38,4 +40,9 @@ void GradAddXPUKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(grad_add, XPU, ALL_LAYOUT, phi::GradAddXPUKernel, float) {} PD_REGISTER_KERNEL(grad_add,
XPU,
ALL_LAYOUT,
phi::GradAddXPUKernel,
phi::dtype::float16,
float) {}
...@@ -26,6 +26,7 @@ void LogSoftmaxGradKernel(const Context& dev_ctx, ...@@ -26,6 +26,7 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int axis, int axis,
DenseTensor* x_grad) { DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
const int rank = out.dims().size(); const int rank = out.dims().size();
axis = funcs::CanonicalAxis(axis, rank); axis = funcs::CanonicalAxis(axis, rank);
...@@ -40,22 +41,27 @@ void LogSoftmaxGradKernel(const Context& dev_ctx, ...@@ -40,22 +41,27 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
tmp2_ptr, nullptr, phi::errors::External("no enough memory in xpu")); tmp2_ptr, nullptr, phi::errors::External("no enough memory in xpu"));
int r = int r = xpu::exp<XPUType>(dev_ctx.x_context(),
xpu::exp(dev_ctx.x_context(), out.data<T>(), tmp_ptr, out_grad.numel()); reinterpret_cast<const XPUType*>(out.data<T>()),
reinterpret_cast<XPUType*>(tmp_ptr),
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
r = xpu::reciprocal( r = xpu::reciprocal<XPUType>(dev_ctx.x_context(),
dev_ctx.x_context(), tmp_ptr, tmp2_ptr, out_grad.numel()); reinterpret_cast<const XPUType*>(tmp_ptr),
reinterpret_cast<XPUType*>(tmp2_ptr),
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal");
r = xpu::mul(dev_ctx.x_context(), r = xpu::mul<XPUType>(dev_ctx.x_context(),
tmp2_ptr, reinterpret_cast<const XPUType*>(tmp2_ptr),
out_grad.data<T>(), reinterpret_cast<const XPUType*>(out_grad.data<T>()),
tmp2_ptr, reinterpret_cast<XPUType*>(tmp2_ptr),
out_grad.numel()); out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
r = xpu::softmax_grad(dev_ctx.x_context(), r = xpu::softmax_grad<XPUType>(
tmp_ptr, dev_ctx.x_context(),
tmp2_ptr, reinterpret_cast<const XPUType*>(tmp_ptr),
x_grad->data<T>(), reinterpret_cast<const XPUType*>(tmp2_ptr),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
out_shape, out_shape,
axis); axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
......
...@@ -25,6 +25,7 @@ void LogSoftmaxKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int axis, int axis,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const int rank = x.dims().size(); const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank); axis = funcs::CanonicalAxis(axis, rank);
...@@ -32,11 +33,16 @@ void LogSoftmaxKernel(const Context& dev_ctx, ...@@ -32,11 +33,16 @@ void LogSoftmaxKernel(const Context& dev_ctx,
auto x_shape = phi::vectorize<int>(x.dims()); auto x_shape = phi::vectorize<int>(x.dims());
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
if (axis < 0) axis += rank; if (axis < 0) axis += rank;
int r = xpu::softmax<T>( int r = xpu::softmax<XPUType>(dev_ctx.x_context(),
dev_ctx.x_context(), x.data<T>(), out->data<T>(), x_shape, axis); reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x_shape,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
r = xpu::log<T>( r = xpu::log<XPUType>(dev_ctx.x_context(),
dev_ctx.x_context(), out->data<T>(), out->data<T>(), out->numel()); reinterpret_cast<const XPUType*>(out->data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "log"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "log");
} }
} }
......
...@@ -23,10 +23,19 @@ import paddle ...@@ -23,10 +23,19 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static() paddle.enable_static()
class TestUpdateLossScalingOp(XPUOpTest): class XPUTestUpdateLossScalingOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "update_loss_scaling"
self.use_dynamic_create_class = False
class TestUpdateLossScalingOp(XPUOpTest):
def setUp(self): def setUp(self):
self.op_type = "update_loss_scaling" self.op_type = "update_loss_scaling"
...@@ -69,8 +78,7 @@ class TestUpdateLossScalingOp(XPUOpTest): ...@@ -69,8 +78,7 @@ class TestUpdateLossScalingOp(XPUOpTest):
place = paddle.XPUPlace(0) place = paddle.XPUPlace(0)
self.check_output_with_place(place, no_check_set=['Out']) self.check_output_with_place(place, no_check_set=['Out'])
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
def setUp(self): def setUp(self):
self.op_type = "update_loss_scaling" self.op_type = "update_loss_scaling"
...@@ -102,8 +110,7 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): ...@@ -102,8 +110,7 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
self.check_output_with_place(place) self.check_output_with_place(place)
#self.check_output() #self.check_output()
class TestUpdateLossScalingLayer(unittest.TestCase):
class TestUpdateLossScalingLayer(unittest.TestCase):
def loss_scaling_check(self, scope=fluid.Scope()): def loss_scaling_check(self, scope=fluid.Scope()):
a = fluid.data(name="a", shape=[1024, 1024], dtype='float32') a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
...@@ -249,5 +256,9 @@ class TestUpdateLossScalingLayer(unittest.TestCase): ...@@ -249,5 +256,9 @@ class TestUpdateLossScalingLayer(unittest.TestCase):
self.loss_scaling_check_inf() self.loss_scaling_check_inf()
support_types = get_xpu_op_support_types('update_loss_scaling')
for stype in support_types:
create_test_class(globals(), XPUTestUpdateLossScalingOp, stype)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册