Unverified Commit 97a95526 authored by Tongxin Bai, committed by GitHub

Refactor `dot` op's CPU kernel for better performance (#32589)

* OP dot: refactor CPU kernels and get better loop performance.

* Minor fix on code format.

* Fixed minor errors.
Parent ce2bdb0a
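For reference, both the old and new kernels compute the same batched dot product; only the loop structure changes. Viewing the operands as N / B contiguous rows of length B (B being the size of the last axis, per the comment in the new code), the forward result and the gradients it implies are:

    z[j] = sum over i of x[j][i] * y[j][i],   for j = 0 .. N/B - 1
    dL/dx[j][i] = y[j][i] * dL/dz[j]
    dL/dy[j][i] = x[j][i] * dL/dz[j]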
@@ -205,35 +205,25 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
       }
     }
 #else
-    const auto* data_dout = tensor_dout->data<T>();
+    auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
+               *dz = tensor_dout->data<T>();
+    auto&& d = tensor_x->dims();
+    auto const N = tensor_x->numel();
+    auto const B = d[d.size() - 1];
     if (tensor_dx) {
-      auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
-      const auto* data_y = tensor_y->data<T>();
-      const framework::DDim& dim = tensor_x->dims();
-      size_t N = static_cast<size_t>(framework::product(dim));
-      auto step = dim[dim.size() - 1];
-      int s = -1;
-      for (size_t i = 0; i < N; ++i) {
-        if (0 == i % step) ++s;
-        data_dx[i] = data_y[i] * data_dout[s];
+      auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
+      for (auto j = 0; j < N / B; ++j) {
+        auto const ss = dz[j];
+        for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
       }
     }
     if (tensor_dy) {
-      auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
-      const auto* data_x = tensor_x->data<T>();
-      const framework::DDim& dim = tensor_y->dims();
-      size_t N = static_cast<size_t>(framework::product(dim));
-      auto step = dim[dim.size() - 1];
-      int s = -1;
-      for (size_t i = 0; i < N; ++i) {
-        if (0 == i % step) ++s;
-        data_dy[i] = data_x[i] * data_dout[s];
+      auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
+      for (auto j = 0; j < N / B; ++j) {
+        auto const ss = dz[j];
+        for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
       }
     }
 #endif
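The refactored gradient loop above replaces the old per-element `i % step` bookkeeping with two nested loops, hoisting the per-row upstream gradient `dz[j]` out of the inner loop and walking raw pointers. A minimal standalone sketch of the same pattern, on plain float arrays with the made-up name `dot_grad_cpu` in place of Paddle's tensor API:

    #include <cstddef>

    // Backward pass for a batched dot: given one upstream gradient dz[j] per
    // row, dx = y * dz (broadcast per row) and dy = x * dz (broadcast per row).
    // N is the total element count, B the length of the reduced last axis.
    void dot_grad_cpu(const float* x, const float* y, const float* dz,
                      float* dx, float* dy, std::ptrdiff_t N, std::ptrdiff_t B) {
      for (std::ptrdiff_t j = 0; j < N / B; ++j) {
        const float ss = dz[j];  // hoisted: one scalar per reduced row
        for (std::ptrdiff_t i = 0; i < B; ++i) {
          dx[j * B + i] = y[j * B + i] * ss;
          dy[j * B + i] = x[j * B + i] * ss;
        }
      }
    }

The inner loop is branch-free, which gives the compiler room to vectorize it; the old loop's per-iteration modulo and conditional increment stood in the way of that.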
@@ -266,21 +256,20 @@ class DotKernel : public framework::OpKernel<T> {
       out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
     }
 #else
-    const auto* data_x = tensor_x->data<T>();
-    const auto* data_y = tensor_y->data<T>();
-    auto* data_out = tensor_out->data<T>();
-    auto x_dims = tensor_x->dims();
-    auto step = x_dims[x_dims.size() - 1];
-    int size = static_cast<int>(framework::product(x_dims));
-    for (int ind = -1, j = 0; j < size; ++j) {
-      if (j % step == 0) {
-        ++ind;
-        data_out[ind] = data_x[j] * data_y[j];
-      } else {
-        data_out[ind] += data_x[j] * data_y[j];
-      }
+    auto const *x = tensor_x->data<T>(), *x_ = &x[0];
+    auto const *y = tensor_y->data<T>(), *y_ = &y[0];
+    auto* z = tensor_out->data<T>();
+    // Loop over the total N elements of both operands while sum-reducing every
+    // B pairs along the way where B is the dimension of the least ordered axis
+    auto&& d = tensor_x->dims();
+    auto const N = tensor_x->numel();
+    auto const B = d[d.size() - 1];
+    for (int j = 0; j < N / B; j++) {
+      T ss = 0;
+      for (int i = 0; i < B; i++) ss += (*x_++) * (*y_++);
+      z[j] = ss;
     }
 #endif
   }
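The comment in the new forward kernel states the scheme: one pass over all N elements, sum-reducing every B consecutive products into one output. A self-contained sketch with a small usage example (plain float arrays and the made-up name `dot_cpu`; the real kernel is templated over T and reads Paddle tensors):

    #include <cstdio>

    // Forward pass for a batched dot: z[j] is the sum of B products for row j.
    void dot_cpu(const float* x, const float* y, float* z, long N, long B) {
      for (long j = 0; j < N / B; ++j) {
        float ss = 0;  // per-row accumulator, written to z exactly once
        for (long i = 0; i < B; ++i) ss += x[j * B + i] * y[j * B + i];
        z[j] = ss;
      }
    }

    int main() {
      const float x[] = {1, 2, 3, 1, 1, 1};  // two rows, B = 3
      const float y[] = {4, 5, 6, 2, 3, 4};
      float z[2];
      dot_cpu(x, y, z, 6, 3);
      std::printf("%g %g\n", z[0], z[1]);  // 1*4+2*5+3*6 = 32, 1*2+1*3+1*4 = 9
      return 0;
    }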