提交 8fc1dd76 作者: dolphin8

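Remove the repeated multiplication from the transpose index rollover: precompute xout[j] = xstride[j] * xdim[j] once per axis, then subtract xout[j] from pind instead of recomputing xdim[j] * xstride[j] each time.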

上级 4cfd981a
...@@ -47,6 +47,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const { ...@@ -47,6 +47,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
size_t ndim = axis.size(); size_t ndim = axis.size();
std::vector<int> xdim(ndim); std::vector<int> xdim(ndim);
std::vector<int> xstride(ndim); std::vector<int> xstride(ndim);
std::vector<int> xout(ndim);
for (int i = 0; i < ndim; i++) { for (int i = 0; i < ndim; i++) {
int j = ndim - 1 - i; int j = ndim - 1 - i;
xdim[j] = input_x_dims[axis[i]]; xdim[j] = input_x_dims[axis[i]];
...@@ -54,6 +55,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const { ...@@ -54,6 +55,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
for (int k = axis[i] + 1; k < ndim; k++) { for (int k = axis[i] + 1; k < ndim; k++) {
xstride[j] *= input_x_dims[k]; xstride[j] *= input_x_dims[k];
} }
xout[j] = xstride[j] * xdim[j];
} }
auto numel = input_x->numel(); auto numel = input_x->numel();
...@@ -68,7 +70,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const { ...@@ -68,7 +70,7 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
ind[j + 1]++; ind[j + 1]++;
ind[j] = 0; ind[j] = 0;
pind += xstride[j + 1]; pind += xstride[j + 1];
pind -= xdim[j] * xstride[j]; pind -= xout[j];
} else { } else {
break; break;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册