提交 43f3c380 编写于 作者: D dolphin8

remove multiplication from transpose

上级 cc4f55c8
......@@ -47,6 +47,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
size_t ndim = axis.size();
std::vector<int> xdim(ndim);
std::vector<int> xstride(ndim);
std::vector<int> xout(ndim);
for (int i = 0; i < ndim; i++) {
int j = ndim - 1 - i;
xdim[j] = input_x_dims[axis[i]];
......@@ -54,6 +55,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
for (int k = axis[i] + 1; k < ndim; k++) {
xstride[j] *= input_x_dims[k];
}
xout[j] = xstride[j] * xdim[j];
}
auto numel = input_x->numel();
......@@ -68,7 +70,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
ind[j + 1]++;
ind[j] = 0;
pind += xstride[j + 1];
pind -= xdim[j] * xstride[j];
pind -= xout[j];
} else {
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册