未验证 提交 628ff34b 编写于 作者: W whs 提交者: GitHub

Fix FPE of label smooth op (#35861)

上级 7ff226f0
......@@ -29,20 +29,21 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
out_t->mutable_data<T>(ctx.GetPlace());
auto epsilon = ctx.Attr<float>("epsilon");
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto in = framework::EigenVector<T>::Flatten(*in_t);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
if (dist_t) {
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
out.device(dev) =
static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
if (label_dim != 0) {
auto epsilon = ctx.Attr<float>("epsilon");
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto in = framework::EigenVector<T>::Flatten(*in_t);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
if (dist_t) {
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(
in_t->numel() / label_dim));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
}
}
}
};
......@@ -54,13 +55,15 @@ class LabelSmoothGradKernel : public framework::OpKernel<T> {
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_in_t->mutable_data<T>(ctx.GetPlace());
auto d_out_dim = d_out_t->dims()[d_out_t->dims().size() - 1];
if (d_out_dim != 0) {
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
}
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册