提交 0ef9dc61 编写于 作者: X xuwei06

Fix comment for norm_op

上级 6ecbf083
...@@ -39,7 +39,7 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -39,7 +39,7 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
"M = C * H * W"); "M = C * H * W");
AddComment(R"DOC( AddComment(R"DOC(
"Input shape: $(N, C, H, W)$ "Input shape: $(N, C, H, W)$
Sclae shape: $(C, 1)$ Scale shape: $(C, 1)$
Output shape: $(N, C, H, W)$ Output shape: $(N, C, H, W)$
Where Where
forward forward
......
...@@ -66,7 +66,7 @@ class NormKernel : public framework::OpKernel<T> { ...@@ -66,7 +66,7 @@ class NormKernel : public framework::OpKernel<T> {
context.GetPlace()); context.GetPlace());
auto tmp = framework::EigenVector<T, Eigen::RowMajor, auto tmp = framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(tmp_tensor); Eigen::DenseIndex>::Flatten(tmp_tensor);
// get colsum and sqrt , inverse // get colsum and sqrt , inverse
auto dim = Eigen::array<int, 1>({{0}}); auto dim = Eigen::array<int, 1>({{0}});
tmp.device(*place) = x_square_batch_eigen.sum(dim); tmp.device(*place) = x_square_batch_eigen.sum(dim);
tmp.device(*place) = (tmp + epsilon).sqrt().inverse(); tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册