未验证 提交 97a95526 编写于 作者: T Tongxin Bai 提交者: GitHub

Refactor `dot` op's CPU kernel for better performance (#32589)

* OP dot: refactored the CPU kernels to achieve better loop performance.

* Minor fix on code format.

* Fixed minor errors.
上级 ce2bdb0a
...@@ -205,35 +205,25 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> { ...@@ -205,35 +205,25 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
} }
} }
#else #else
const auto* data_dout = tensor_dout->data<T>(); auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
*dz = tensor_dout->data<T>();
auto&& d = tensor_x->dims();
auto const N = tensor_x->numel();
auto const B = d[d.size() - 1];
if (tensor_dx) { if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace()); auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>(); for (auto j = 0; j < N / B; ++j) {
const framework::DDim& dim = tensor_x->dims(); auto const ss = dz[j];
size_t N = static_cast<size_t>(framework::product(dim)); for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
} }
} }
if (tensor_dy) { if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace()); auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>(); for (auto j = 0; j < N / B; ++j) {
const framework::DDim& dim = tensor_y->dims(); auto const ss = dz[j];
size_t N = static_cast<size_t>(framework::product(dim)); for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
} }
} }
#endif #endif
...@@ -266,21 +256,20 @@ class DotKernel : public framework::OpKernel<T> { ...@@ -266,21 +256,20 @@ class DotKernel : public framework::OpKernel<T> {
out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1)); out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
} }
#else #else
const auto* data_x = tensor_x->data<T>(); auto const *x = tensor_x->data<T>(), *x_ = &x[0];
const auto* data_y = tensor_y->data<T>(); auto const *y = tensor_y->data<T>(), *y_ = &y[0];
auto* data_out = tensor_out->data<T>(); auto* z = tensor_out->data<T>();
auto x_dims = tensor_x->dims(); // Loop over the total N elements of both operands while sum-reducing every
auto step = x_dims[x_dims.size() - 1]; // B pairs along the way where B is the dimension of the least ordered axis
int size = static_cast<int>(framework::product(x_dims)); auto&& d = tensor_x->dims();
auto const N = tensor_x->numel();
for (int ind = -1, j = 0; j < size; ++j) { auto const B = d[d.size() - 1];
if (j % step == 0) {
++ind; for (int j = 0; j < N / B; j++) {
data_out[ind] = data_x[j] * data_y[j]; T ss = 0;
} else { for (int i = 0; i < B; i++) ss += (*x_++) * (*y_++);
data_out[ind] += data_x[j] * data_y[j]; z[j] = ss;
}
} }
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册