未验证 提交 2891ce64 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix pull&push dim (#34217)

上级 2c945737
......@@ -51,6 +51,19 @@ struct FeaturePushValue {
int slot;
float lr_g;
float mf_g[MF_DIM];
__device__ __forceinline__ FeaturePushValue
operator+(const FeaturePushValue& a) const {
FeaturePushValue out;
out.slot = a.slot;
out.show = a.show + show;
out.clk = a.clk + clk;
out.lr_g = a.lr_g + lr_g;
for (int i = 0; i < MF_DIM; ++i) {
out.mf_g[i] = a.mf_g[i] + mf_g[i];
}
return out;
}
};
} // end namespace framework
......
......@@ -50,11 +50,11 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
*(dest[x] + y * hidden + 2) = (src + i)->lr;
}
if ((src + i)->mf_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < 8; j++) {
for (int j = 0; j < hidden - 3; j++) {
*(dest[x] + y * hidden + 3 + j) = 0;
}
} else {
for (int j = 0; j < 8; j++) {
for (int j = 0; j < hidden - 3; j++) {
*(dest[x] + y * hidden + 3 + j) = (src + i)->mf[1 + j];
}
}
......@@ -99,7 +99,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
(dest + i)->show = *(src[x] + y * hidden);
(dest + i)->clk = *(src[x] + y * hidden + 1);
(dest + i)->lr_g = *(src[x] + y * hidden + 2) * -1. * bs;
for (int j = 0; j < 8; j++) {
for (int j = 0; j < hidden - 3; j++) {
(dest + i)->mf_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册