Unverified commit 628ff34b, authored by whs, committed by GitHub

Fix FPE of label smooth op (#35861)

Parent 7ff226f0
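For context: the forward kernel in the hunks below computes label smoothing as out = (1 − ε)·in + ε·dist when a PriorDist is given, and out = (1 − ε)·in + ε / label_dim otherwise. When the input's last dimension is 0, both epsilon / label_dim and the broadcast size in_t->numel() / label_dim divide by zero, which appears to be the floating-point exception (FPE) this commit guards against; the backward kernel receives the same empty-dimension guard. A sketch of the formula, writing K for label_dim:

```latex
% Label smoothing as computed by the forward kernel (K = label_dim):
\[
\mathrm{out} =
\begin{cases}
(1-\epsilon)\,\mathrm{in} + \epsilon\,\mathrm{dist}, & \text{if PriorDist is given} \\[4pt]
(1-\epsilon)\,\mathrm{in} + \dfrac{\epsilon}{K}, & \text{otherwise}
\end{cases}
\]
% With K = 0 (an empty last dimension), both epsilon / K and the broadcast
% size numel / K divide by zero, hence the new `if (label_dim != 0)` guard.
```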
@@ -29,22 +29,23 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
     auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
     auto label_dim = in_t->dims()[in_t->dims().size() - 1];
     out_t->mutable_data<T>(ctx.GetPlace());
-    auto epsilon = ctx.Attr<float>("epsilon");
-    auto out = framework::EigenVector<T>::Flatten(*out_t);
-    auto in = framework::EigenVector<T>::Flatten(*in_t);
-    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
-    if (dist_t) {
-      auto dist = framework::EigenVector<T>::Flatten(*dist_t);
-      out.device(dev) =
-          static_cast<T>(1 - epsilon) * in +
-          static_cast<T>(epsilon) *
-              dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
-    } else {
-      out.device(dev) = static_cast<T>(1 - epsilon) * in +
-                        static_cast<T>(epsilon / label_dim);
-    }
+    if (label_dim != 0) {
+      auto epsilon = ctx.Attr<float>("epsilon");
+      auto out = framework::EigenVector<T>::Flatten(*out_t);
+      auto in = framework::EigenVector<T>::Flatten(*in_t);
+      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
+      if (dist_t) {
+        auto dist = framework::EigenVector<T>::Flatten(*dist_t);
+        out.device(dev) = static_cast<T>(1 - epsilon) * in +
+                          static_cast<T>(epsilon) *
+                              dist.broadcast(Eigen::DSizes<int, 1>(
+                                  in_t->numel() / label_dim));
+      } else {
+        out.device(dev) = static_cast<T>(1 - epsilon) * in +
+                          static_cast<T>(epsilon / label_dim);
+      }
+    }
   }
 };
 
 template <typename DeviceContext, typename T>
@@ -54,7 +55,8 @@ class LabelSmoothGradKernel : public framework::OpKernel<T> {
     auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     d_in_t->mutable_data<T>(ctx.GetPlace());
-    auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
-    auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
+    auto d_out_dim = d_out_t->dims()[d_out_t->dims().size() - 1];
+    if (d_out_dim != 0) {
+      auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
+      auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
@@ -62,6 +64,7 @@ class LabelSmoothGradKernel : public framework::OpKernel<T> {
-    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
-    d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
+      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
+      d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
+    }
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
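For illustration only, a minimal self-contained sketch (not Paddle code; the shape {4, 0} is a hypothetical example) of the failure mode the new guard prevents:

```cpp
// An input shaped {N, 0} has an empty last dimension, so label_dim is 0 and
// the kernel's integer division numel / label_dim (and epsilon / label_dim)
// would divide by zero and raise SIGFPE. The commit skips the computation in
// that case, as the guard below does.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> dims = {4, 0};     // hypothetical label tensor shape
  int64_t numel = dims[0] * dims[1];      // 0 elements
  int64_t label_dim = dims.back();        // last dimension == 0
  if (label_dim != 0) {                   // same guard the commit adds
    int64_t repeats = numel / label_dim;  // would be a division by zero without the guard
    std::printf("broadcast repeats: %lld\n", static_cast<long long>(repeats));
  } else {
    std::printf("empty label dimension, nothing to smooth\n");
  }
  return 0;
}
```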