From 2891ce6412b867bf94de59ddaa6a69afdd96512c Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Mon, 19 Jul 2021 11:20:55 +0800 Subject: [PATCH] fix pull&push dim (#34217) --- .../fluid/framework/fleet/heter_ps/feature_value.h | 13 +++++++++++++ paddle/fluid/framework/fleet/ps_gpu_wrapper.cu | 6 +++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index f6c4d47ce2d..db11fca109b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -51,6 +51,19 @@ struct FeaturePushValue { int slot; float lr_g; float mf_g[MF_DIM]; + + __device__ __forceinline__ FeaturePushValue + operator+(const FeaturePushValue& a) const { + FeaturePushValue out; + out.slot = a.slot; + out.show = a.show + show; + out.clk = a.clk + clk; + out.lr_g = a.lr_g + lr_g; + for (int i = 0; i < MF_DIM; ++i) { + out.mf_g[i] = a.mf_g[i] + mf_g[i]; + } + return out; + } }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 2bf564d3f76..5ff41d81801 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -50,11 +50,11 @@ __global__ void PullCopy(float** dest, const FeatureValue* src, *(dest[x] + y * hidden + 2) = (src + i)->lr; } if ((src + i)->mf_size == 0 || *(keys[x] + y) == 0) { - for (int j = 0; j < 8; j++) { + for (int j = 0; j < hidden - 3; j++) { *(dest[x] + y * hidden + 3 + j) = 0; } } else { - for (int j = 0; j < 8; j++) { + for (int j = 0; j < hidden - 3; j++) { *(dest[x] + y * hidden + 3 + j) = (src + i)->mf[1 + j]; } } @@ -99,7 +99,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len, (dest + i)->show = *(src[x] + y * hidden); (dest + i)->clk = *(src[x] + y * hidden + 1); (dest + i)->lr_g = *(src[x] + y * hidden + 2) * -1. * bs; - for (int j = 0; j < 8; j++) { + for (int j = 0; j < hidden - 3; j++) { (dest + i)->mf_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs; } } -- GitLab