diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps index f73757902fef61584177b7f83134848c7589e04c..b44ea1807fd6595d11ad4070c57dbd042d158547 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps @@ -18,6 +18,7 @@ limitations under the License. */ #if defined(PADDLE_WITH_XPU_KP) #include #include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/debug.h" // NOLINT #include "xpu/kernel/math.h" #include "xpu/kernel/simd.h" #endif @@ -91,7 +92,7 @@ __global__ void calc_shard_offset_kernel(T* idx, T* left, T* right, // read batch from GM will boost performance int read_len = min(len_per_loop, len - i); GM2LM(idx + i, local_idx, read_len * sizeof(T)); - for (int k = 0; k < read_len; k++) { + for (int k = 0; k < read_len - 1; k++) { if (local_idx[k] != local_idx[k + 1]) { int real_idx = i + k; local_right[local_idx[k]] = real_idx; @@ -102,7 +103,7 @@ __global__ void calc_shard_offset_kernel(T* idx, T* left, T* right, local_left[local_idx[i]] = i; } if (i + read_len == len) { - local_right[local_idx[len - 1]] = len - 1; + local_right[local_idx[read_len - 1]] = len - 1; } } // to be optimized: call LM2GM too frequently @@ -150,7 +151,7 @@ __global__ void fill_shard_key_kernel(KeyType* d_shard_keys, KeyType* d_keys, int thread_id = ncores * cluster_id() + cid; int nthreads = ncores * cluster_num(); const int buf_size = 400; - __local__ KeyType local_keys[buf_size]; + // __local__ KeyType local_keys[buf_size]; __local__ KeyType local_shard_keys[buf_size]; __local__ T local_idx[buf_size]; int len_per_loop = min(buf_size, roundup_div(len, nthreads)); @@ -158,10 +159,11 @@ __global__ void fill_shard_key_kernel(KeyType* d_shard_keys, KeyType* d_keys, i += nthreads * len_per_loop) { // read batch from GM will boost performance int read_len = min(len_per_loop, len - i); - GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); + // GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); GM2LM(idx + i, local_idx, read_len * sizeof(T)); for (int k = 0; k < read_len; k++) { - local_shard_keys[k] = local_keys[local_idx[k]]; + GM2LM(d_keys + local_idx[k], &local_shard_keys[k], 1 * sizeof(KeyType)); + // local_shard_keys[k] = local_keys[local_idx[k]]; } LM2GM(local_shard_keys, d_shard_keys + i, read_len * sizeof(KeyType)); } @@ -181,9 +183,9 @@ __global__ void fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys, int thread_id = ncores * cluster_id() + cid; int nthreads = ncores * cluster_num(); - const int buf_size = 100; - __local__ KeyType local_keys[buf_size]; - __local__ GradType local_grads[buf_size]; + const int buf_size = 50; + // __local__ KeyType local_keys[buf_size]; + // __local__ GradType local_grads[buf_size]; __local__ KeyType local_shard_keys[buf_size]; __local__ GradType local_shard_grads[buf_size]; __local__ T local_idx[buf_size]; @@ -193,12 +195,15 @@ __global__ void fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys, i += nthreads * len_per_loop) { // read batch from GM will boost performance int read_len = min(len_per_loop, len - i); - GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); - GM2LM(d_grads + i, local_grads, read_len * sizeof(GradType)); + // GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); + // GM2LM(d_grads + i, local_grads, read_len * sizeof(GradType)); GM2LM(idx + i, local_idx, read_len * sizeof(T)); for (int k = 0; k < read_len; k++) { - local_shard_keys[k] = local_keys[local_idx[k]]; - local_shard_grads[k] = local_grads[local_idx[k]]; + GM2LM(d_keys + local_idx[k], &local_shard_keys[k], 1 * sizeof(KeyType)); + GM2LM(d_grads + local_idx[k], &local_shard_grads[k], + 1 * sizeof(GradType)); + // local_shard_keys[k] = local_keys[local_idx[k]]; + // local_shard_grads[k] = local_grads[local_idx[k]]; } LM2GM(local_shard_keys, d_shard_keys + i, read_len * sizeof(KeyType)); LM2GM(local_shard_grads, d_shard_grads + i, read_len * sizeof(GradType)); @@ -227,9 +232,10 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, GM2LM(idx + i, local_idx, read_len * sizeof(T)); GM2LM(d_shard_vals + i, local_shard_vals, read_len * sizeof(ValType)); for (int k = 0; k < read_len; k++) { - local_vals[local_idx[k]] = local_shard_vals[k]; + LM2GM(&local_shard_vals[k], d_vals + local_idx[k], 1 * sizeof(ValType)); + // local_vals[local_idx[k]] = local_shard_vals[k]; } - LM2GM(local_vals, d_vals + i, read_len * sizeof(ValType)); + // LM2GM(local_vals, d_vals + i, read_len * sizeof(ValType)); } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps index 58b9f0f722f8cd2e0232821b43b73a92119fb611..ef6c70e624d4cf551a7bd15affd01e4c29d40587 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps @@ -28,9 +28,9 @@ limitations under the License. */ namespace paddle { namespace framework { -__global__ void PullCopy(float** dest, const FeatureValue* src, +__global__ void PullCopy(float* dest, const FeatureValue* src, const long long* len, int hidden, int slot_num, - int total_len, unsigned long long** keys) { + int total_len, unsigned long long* keys) { int cid = core_id(); int ncores = core_num(); if (cid >= ncores) { @@ -41,11 +41,21 @@ __global__ void PullCopy(float** dest, const FeatureValue* src, __local__ int64_t local_len[slot_num]; GM2LM(len, local_len, slot_num * sizeof(int64_t)); + __global_ptr__ unsigned long long* local_keys[slot_num]; + GM2LM(keys, local_keys, + slot_num * sizeof(__global_ptr__ unsigned long long*)); + + __global_ptr__ float* local_dest[slot_num]; + GM2LM(dest, local_dest, slot_num * sizeof(__global_ptr__ float*)); + + int read_len = 30; + for (int i = thread_id; i < slot_num; i += nthreads) { // max core local memory = 8KB // slot's max memory size = slot_len * sizeof(FeatureValue) int slot_len = i ? local_len[i] - local_len[i - 1] : local_len[0]; - int read_len = min(roundup_div(1024 * 8, sizeof(FeatureValue)), slot_len); + // int read_len = min(roundup_div(1024 * 8, sizeof(FeatureValue)), + // slot_len); int dest_len = i ? local_len[i - 1] : 0; __local__ FeatureValue local_slot_vals[read_len]; __local__ float local_dest_vals[read_len * hidden]; @@ -56,7 +66,8 @@ __global__ void PullCopy(float** dest, const FeatureValue* src, int real_read_len = min(read_len, slot_len - k); GM2LM(src + dest_len + k, local_slot_vals, real_read_len * sizeof(FeatureValue)); - GM2LM(keys[i] + k, local_slot_keys, real_read_len * sizeof(uint64_t)); + GM2LM(local_keys[i] + k, local_slot_keys, + real_read_len * sizeof(uint64_t)); for (int j = 0; j < real_read_len; j++) { if (local_slot_keys[j] == 0) { local_dest_vals[j * hidden] = 0; @@ -78,7 +89,7 @@ __global__ void PullCopy(float** dest, const FeatureValue* src, } } } - LM2GM(local_dest_vals, dest[i] + k * hidden, + LM2GM(local_dest_vals, local_dest[i] + k * hidden, real_read_len * hidden * sizeof(float)); } } @@ -120,7 +131,7 @@ __global__ void CopyKeysKernel(unsigned long long* src_keys, } } -__global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len, +__global__ void PushCopy(FeaturePushValue* dest, float* src, long long* len, int hidden, int slot_num, int total_len, int bs, int* slot_vector) { int cid = core_id(); @@ -135,12 +146,16 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len, GM2LM(len, local_len, slot_num * sizeof(int64_t)); GM2LM(slot_vector, local_slot, slot_num * sizeof(int)); + __global_ptr__ float* local_src[slot_num]; + GM2LM(src, local_src, slot_num * sizeof(__global_ptr__ float*)); + for (int i = thread_id; i < slot_num; i += nthreads) { int slot_len = i ? local_len[i] - local_len[i - 1] : local_len[0]; // max core local memory = 8KB // slot's max memory size = slot_len * hidden * 8 - int read_len = min(roundup_div(1024, hidden), slot_len); + // int read_len = min(roundup_div(1024, hidden), slot_len); + int read_len = 40; int dest_len = i ? local_len[i - 1] : 0; __local__ float local_slot_grads[read_len * hidden]; __local__ FeaturePushValue local_dest_grads[read_len]; @@ -148,7 +163,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len, // copy read_len(length) of slots' grad to LM for (int k = 0; k < slot_len; k += read_len) { int real_read_len = min(read_len, slot_len - k); - GM2LM(src[i] + k * hidden, local_slot_grads, + GM2LM(local_src[i] + k * hidden, local_slot_grads, real_read_len * hidden * sizeof(float)); // copy from slots' grad to total grad for (int j = 0; j < real_read_len; j++) { @@ -181,14 +196,18 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, stream = static_cast(dev_ctx) ->x_context() ->xpu_stream; - float* buf_value = nullptr; - xpu_malloc(reinterpret_cast(&buf_value), + // float* buf_value = nullptr; + // xpu_malloc(reinterpret_cast(&buf_value), + // values.size() * sizeof(float*)); + // float** gpu_values = reinterpret_cast(&buf_value); + float* gpu_values = nullptr; + xpu_malloc(reinterpret_cast(&gpu_values), values.size() * sizeof(float*)); - float** gpu_values = reinterpret_cast(&buf_value); xpu_memcpy(gpu_values, values.data(), values.size() * sizeof(float*), XPU_HOST_TO_DEVICE); - unsigned long long** c_keys = (unsigned long long**)gpu_keys; + // unsigned long long** c_keys = (unsigned long long**)gpu_keys; + unsigned long long* c_keys = reinterpret_cast(gpu_keys); const long long* c_len = (const long long*)gpu_len; PullCopy<<<2, 64, stream>>>(gpu_values, total_values_gpu, c_len, hidden_size, slot_num, total_length, c_keys); @@ -230,20 +249,17 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, slot_lengths_lod[i] += slot_lengths_lod[i - 1]; } - float* buf_grad_value = nullptr; - int64_t* buf_length = nullptr; - int* buf_slot_vector = nullptr; + float* gpu_values = nullptr; + int64_t* gpu_len = nullptr; + int* d_slot_vector = nullptr; - xpu_malloc(reinterpret_cast(&buf_grad_value), + xpu_malloc(reinterpret_cast(&gpu_values), grad_values.size() * sizeof(float*)); - xpu_malloc(reinterpret_cast(&buf_length), + xpu_malloc(reinterpret_cast(&gpu_len), slot_lengths.size() * sizeof(int64_t)); - xpu_malloc(reinterpret_cast(&buf_slot_vector), + xpu_malloc(reinterpret_cast(&d_slot_vector), slot_lengths_lod.size() * sizeof(int)); - float** gpu_values = reinterpret_cast(&buf_grad_value); - int64_t* gpu_len = reinterpret_cast(buf_length); - int* d_slot_vector = reinterpret_cast(buf_slot_vector); xpu_memcpy(gpu_values, grad_values.data(), grad_values.size() * sizeof(float*), XPU_HOST_TO_DEVICE); xpu_memcpy(gpu_len, slot_lengths_lod.data(), diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index 088366dbc8f6995cc90bf8c4d8334a7dcdb6a8a6..6ad22ff8b19eb58918d5728fd985551058c15338 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -11,27 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include - -#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/collective/c_sync_calc_stream_op.h" namespace paddle { namespace operators { -class CSyncCalcStreamOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override {} - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); - } -}; - class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { @@ -45,53 +29,6 @@ Call calculation stream synchronization. } }; -template -class CSyncCalcStreamKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32) - - auto place = ctx.GetPlace(); - auto dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - - platform::GpuStreamSync(dev_ctx->stream()); - -#elif defined(PADDLE_WITH_ASCEND_CL) && !defined(_WIN32) - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(platform::is_npu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on npu place only for now.")); - - auto dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - platform::NPUStreamSync(dev_ctx->stream()); - -#elif defined(PADDLE_WITH_CNCL) - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(platform::is_mlu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on mlu place only for now.")); - - auto dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - platform::MLUStreamSync(dev_ctx->stream()); -#elif defined(PADDLE_WITH_XPU_BKCL) - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on xpu place only for now.")); - - auto dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - dev_ctx->Wait(); -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU.")); -#endif - } -}; - } // namespace operators } // namespace paddle @@ -105,5 +42,3 @@ REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); REGISTER_OP_MLU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); - -REGISTER_OP_XPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.h b/paddle/fluid/operators/collective/c_sync_calc_stream_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b07367f801fa3fc147a36fc1c78b6b33609b99e6 --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.h @@ -0,0 +1,83 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class CSyncCalcStreamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } +}; + +template +class CSyncCalcStreamKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32) + + auto place = ctx.GetPlace(); + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + + platform::GpuStreamSync(dev_ctx->stream()); + +#elif defined(PADDLE_WITH_ASCEND_CL) && !defined(_WIN32) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_npu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on npu place only for now.")); + + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + platform::NPUStreamSync(dev_ctx->stream()); + +#elif defined(PADDLE_WITH_CNCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_mlu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on mlu place only for now.")); + + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + platform::MLUStreamSync(dev_ctx->stream()); +#elif defined(PADDLE_WITH_XPU_BKCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on xpu place only for now.")); + + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + dev_ctx->Wait(); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..04a83ea64f07637477556be135a09a02f7345d80 --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_sync_calc_stream_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel) diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h index ab68ebf3a5448c429d1466bce30f0f1f55af5121..778c18146d64d014d8384ed3c4b35148268b6e56 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h @@ -109,6 +109,8 @@ XPUOpMap& get_kp_ops() { {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})}, {"pull_box_sparse", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"push_box_sparse", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, };