提交 4a11fdb4 编写于 作者: C chengduoZH

Follow review comments: rename `offset` to `row_id`

上级 8bd75900
......@@ -155,11 +155,11 @@ struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t offset = 0; offset < rows; ++offset) {
auto xy_norm_prod = x_norm[offset] * y_norm[0];
auto dz_data = dz[offset];
auto z_data = z[offset];
auto* x_data = x + cols * offset;
for (size_t row_id = 0; row_id < rows; ++row_id) {
auto xy_norm_prod = x_norm[row_id] * y_norm[0];
auto dz_data = dz[row_id];
auto z_data = z[row_id];
auto* x_data = x + cols * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm[0] * y_norm[0];
......
......@@ -25,12 +25,12 @@ __global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int offset = blockIdx.x * blockDim.x + threadIdx.x; offset < rows;
offset += grid_size) {
T xy_norm_prod = x_norm[offset] * y_norm_data;
T dz_data = dz[offset];
T z_data = z[offset];
const T* x_data = x + cols * offset;
for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows;
row_id += grid_size) {
T xy_norm_prod = x_norm[row_id] * y_norm_data;
T dz_data = dz[row_id];
T z_data = z[row_id];
const T* x_data = x + cols * row_id;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
T y_norm_square = y_norm_data * y_norm_data;
......
......@@ -32,11 +32,11 @@ struct CosSimFunctor {
z_(z),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t offset) const {
auto* x = x_ + cols_ * offset;
inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0;
if (same_row) {
auto* y = y_ + cols_ * offset;
auto* y = y_ + cols_ * row_id;
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
......@@ -47,9 +47,9 @@ struct CosSimFunctor {
}
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[offset] = yy;
x_norm_[offset] = xx;
z_[offset] = xy / (xx * yy);
y_norm_[row_id] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
} else { // This can be written in a better way.
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
......@@ -61,9 +61,9 @@ struct CosSimFunctor {
}
xx = sqrt(xx);
yy = sqrt(yy);
if (offset == 0) y_norm_[0] = yy;
x_norm_[offset] = xx;
z_[offset] = xy / (xx * yy);
if (row_id == 0) y_norm_[0] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
}
}
......@@ -125,15 +125,15 @@ struct CosSimGradFunctor {
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t offset) const {
auto x_norm_square = x_norm_[offset] * x_norm_[offset];
auto xy_norm_prod = x_norm_[offset] * y_norm_[offset];
auto dz = dz_[offset];
auto z = z_[offset];
inline HOSTDEVICE void operator()(size_t row_id) const {
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* dx = dx_ + cols_ * offset;
auto* x = x_ + cols_ * offset;
auto* y = y_ + cols_ * offset;
auto* dx = dx_ + cols_ * row_id;
auto* x = x_ + cols_ * row_id;
auto* y = y_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto reciprocal_x_norm_square = 1 / x_norm_square;
......@@ -166,14 +166,14 @@ struct CosSimDxFunctor {
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t offset) const {
auto xy_norm_prod = x_norm_[offset] * y_norm_[0];
auto dz = dz_[offset];
auto z = z_[offset];
auto* x = x_ + cols_ * offset;
inline HOSTDEVICE void operator()(size_t row_id) const {
auto xy_norm_prod = x_norm_[row_id] * y_norm_[0];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* x = x_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto x_norm_square = x_norm_[offset] * x_norm_[offset];
auto* dx = dx_ + cols_ * offset;
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册