提交 85b839f0 编写于 作者: E emailweixu 提交者: Abhinav Arora

Fix l1_norm_op and squared_l2_norm_op for debug mode (#5560)

上级 b6c262e1
......@@ -29,7 +29,7 @@ class L1NormKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto out = framework::EigenScalar<T>::From(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.abs().sum();
......
......@@ -29,7 +29,7 @@ class SquaredL2NormKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto out = framework::EigenScalar<T>::From(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.square().sum();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册