Unverified commit 6b93ba0a, authored by Sonder, committed by GitHub

move prune_gate_by_capacity to phi (#55780)

* move prune_gate_by_capacity to phi

* fix

* fix register info

* remove useless codes

Parent: 719b1ed3
paddle/fluid/operators/prune_gate_by_capacity_op.cc

@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/operators/prune_gate_by_capacity_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"

 namespace paddle {
 namespace operators {
@@ -125,10 +126,3 @@
 namespace ops = paddle::operators;

 REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity,
                              ops::PruneGateByCapacityOp,
                              ops::PruneGateByCapacityOpMaker);
-
-PD_REGISTER_STRUCT_KERNEL(prune_gate_by_capacity,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::PruneGateByCapacityCPUKernel,
-                          int,
-                          int64_t) {}
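In short: the operator definition (REGISTER_OP_WITHOUT_GRADIENT) stays in fluid, while device-kernel registration moves next to the phi kernels. Schematically (a composite of the hunk above and the new files below, not a literal hunk from this commit):

// Stays in paddle/fluid/operators/prune_gate_by_capacity_op.cc:
REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity,
                             ops::PruneGateByCapacityOp,
                             ops::PruneGateByCapacityOpMaker);

// Moves beside the phi implementations, e.g. the new CPU file below:
PD_REGISTER_KERNEL(prune_gate_by_capacity,
                   CPU,
                   ALL_LAYOUT,
                   phi::PruneGateByCapacityKernel,
                   int,
                   int64_t) {}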
paddle/phi/kernels/cpu/prune_gate_by_capacity_kernel.cc (new file)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/prune_gate_by_capacity_kernel.h"

#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
                               const DenseTensor& gate_idx,
                               const DenseTensor& expert_count,
                               int64_t n_expert,
                               int64_t n_worker,
                               DenseTensor* new_gate_idx) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "prune_gate_by_capacity is not supported on CPU."));
}

}  // namespace phi

PD_REGISTER_KERNEL(prune_gate_by_capacity,
                   CPU,
                   ALL_LAYOUT,
                   phi::PruneGateByCapacityKernel,
                   int,
                   int64_t) {}
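For readers new to phi's function-style kernels: the PD_REGISTER_KERNEL call above instantiates the template once per listed dtype (int, int64_t) against phi::CPUContext. A minimal usage sketch, hypothetical and not part of this commit (the function name and the n_expert/n_worker values are illustrative placeholders):

// Hypothetical C++ usage; on CPU the kernel is expected to throw
// phi::errors::Unimplemented, matching the PADDLE_THROW above.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/prune_gate_by_capacity_kernel.h"

void DemoPruneGateOnCpu(const phi::CPUContext& dev_ctx,
                        const phi::DenseTensor& gate_idx,
                        const phi::DenseTensor& expert_count) {
  phi::DenseTensor new_gate_idx;
  phi::PruneGateByCapacityKernel<int64_t>(dev_ctx,
                                          gate_idx,
                                          expert_count,
                                          /*n_expert=*/4,
                                          /*n_worker=*/1,
                                          &new_gate_idx);  // throws on CPU
}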
paddle/fluid/operators/prune_gate_by_capacity_op.cu → paddle/phi/kernels/gpu/prune_gate_by_capacity_kernel.cu

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -11,25 +11,15 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-//
-// The file has been adapted from the two files:
-// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cu
-// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cuh
-// Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
-// We retain the following license from the original files:
-// Copyright 2021, Jiaao He. All rights reserved.
-// Licensed under the Apache License, Version 2.0 (the "License").

-#include "paddle/fluid/operators/prune_gate_by_capacity_op.h"
-#include "paddle/phi/backends/gpu/gpu_primitives.h"
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-DECLARE_bool(avoid_op_randomness);
-
-namespace paddle {
-namespace operators {
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+
+#include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/kernels/prune_gate_by_capacity_kernel.h"
+
+namespace phi {

 static constexpr int kNumCUDAThreads = 512;
 static constexpr int kNumMaxinumNumBlocks = 4096;
@@ -55,14 +45,14 @@ __global__ void prune_gate_by_capacity_kernel(const T1* gate_idx_data,
   }
 }

-template <typename DeviceContext, typename T1>
+template <typename Context, typename T1>
 class PruneGateByCapacityFunctor {
  public:
-  PruneGateByCapacityFunctor(const framework::ExecutionContext& context,
+  PruneGateByCapacityFunctor(const Context& dev_ctx,
                              const phi::DenseTensor* gate_idx,
                              phi::DenseTensor* expert_count_out,
                              T1* new_gate_idx_data)
-      : context_(context),
+      : dev_ctx_(dev_ctx),
         gate_idx_(gate_idx),
         expert_count_out_(expert_count_out),
         new_gate_idx_data_(new_gate_idx_data) {}
@@ -72,32 +62,31 @@ class PruneGateByCapacityFunctor {
     auto batch_size = gate_idx_->numel();
     auto* gate_idx_data = gate_idx_->data<T1>();
-    auto& dev_ctx = context_.template device_context<DeviceContext>();
     auto* expert_count_out_data = expert_count_out_->data<T2>();

     int blocks = NumBlocks(batch_size);
     int threads = kNumCUDAThreads;
     prune_gate_by_capacity_kernel<T1, T2>
-        <<<blocks, threads, 0, dev_ctx.stream()>>>(gate_idx_data,
-                                                   new_gate_idx_data_,
-                                                   expert_count_out_data,
-                                                   batch_size);
+        <<<blocks, threads, 0, dev_ctx_.stream()>>>(gate_idx_data,
+                                                    new_gate_idx_data_,
+                                                    expert_count_out_data,
+                                                    batch_size);
   }

  private:
-  const framework::ExecutionContext context_;
+  const Context& dev_ctx_;
   const phi::DenseTensor* gate_idx_;
   phi::DenseTensor* expert_count_out_;
   T1* new_gate_idx_data_;
 };

 template <typename Visitor>
-static void VisitDataType(phi::DataType type, Visitor visitor) {
+static void VisitType(phi::DataType type, Visitor visitor) {
   if (type == phi::DataType::INT64) {
     visitor.template apply<int64_t>();
   } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(phi::errors::InvalidArgument(
         "The received values gate_id type %s can not meet input requirements. "
         "Because the given gate_id data type of operators must be "
         "int64. Please input appropriate gate_id again! ",
@@ -105,30 +94,30 @@ static void VisitDataType(phi::DataType type, Visitor visitor) {
   }
 }

-template <typename T, typename DeviceContext>
-class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* gate_idx = context.Input<phi::DenseTensor>("GateIdx");
-    auto* expert_count = context.Input<phi::DenseTensor>("ExpertCount");
-    // auto* expert_count_out =
-    //     context.Output<phi::DenseTensor>("ExpertCountOut");
-    auto* new_gate_idx = context.Output<phi::DenseTensor>("NewGateIdx");
-    auto* new_gate_idx_data =
-        new_gate_idx->mutable_data<T>(context.GetPlace());
-
-    phi::DenseTensor expert_count_out;
-    framework::TensorCopy(
-        *expert_count, context.GetPlace(), &expert_count_out);
-    PruneGateByCapacityFunctor<DeviceContext, T> functor(
-        context, gate_idx, &expert_count_out, new_gate_idx_data);
-    ::paddle::operators::VisitDataType(expert_count->type(), functor);
-  }
-};
+template <typename T, typename Context>
+void PruneGateByCapacityKernel(const Context& dev_ctx,
+                               const DenseTensor& gate_idx,
+                               const DenseTensor& expert_count,
+                               int64_t n_expert,
+                               int64_t n_worker,
+                               DenseTensor* new_gate_idx) {
+  auto* gate_idx_ptr = &gate_idx;
+  // auto* expert_count_out =
+  //     context.Output<phi::DenseTensor>("ExpertCountOut");
+  auto* new_gate_idx_data = dev_ctx.template Alloc<T>(new_gate_idx);
+
+  phi::DenseTensor expert_count_out;
+  phi::Copy(
+      dev_ctx, expert_count, dev_ctx.GetPlace(), false, &expert_count_out);
+  PruneGateByCapacityFunctor<Context, T> functor(
+      dev_ctx, gate_idx_ptr, &expert_count_out, new_gate_idx_data);
+  VisitType(expert_count.type(), functor);
+}

-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi

-PD_REGISTER_STRUCT_KERNEL(prune_gate_by_capacity,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::PruneGateByCapacityCUDAKernel,
-                          int64_t) {}
+PD_REGISTER_KERNEL(prune_gate_by_capacity,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PruneGateByCapacityKernel,
+                   int64_t) {}
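The hunks above elide the NumBlocks helper and the body of prune_gate_by_capacity_kernel, which this commit leaves unchanged. For orientation only, a reconstructed sketch of that elided logic — approximate, not the verbatim source; CUDA_KERNEL_LOOP and phi::CudaAtomicAdd are assumed to come from Paddle's GPU primitives headers:

// Sketch (assumption): grid size is one thread per token, capped at
// kNumMaxinumNumBlocks blocks of kNumCUDAThreads threads each.
static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}

// Sketch (assumption): each thread atomically decrements the remaining
// capacity of its token's expert; once capacity is exhausted, the token is
// pruned by writing gate index -1.
template <typename T1, typename T2>
__global__ void prune_gate_by_capacity_kernel(const T1* gate_idx_data,
                                              T1* new_gate_idx_data,
                                              T2* expert_count_data,
                                              const int64_t batch_size) {
  CUDA_KERNEL_LOOP(i, batch_size) {
    auto orig_cap = phi::CudaAtomicAdd(expert_count_data + gate_idx_data[i],
                                       static_cast<T2>(-1));
    new_gate_idx_data[i] = orig_cap > 0 ? gate_idx_data[i] : -1;
  }
}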
paddle/fluid/operators/prune_gate_by_capacity_op.h → paddle/phi/kernels/prune_gate_by_capacity_kernel.h

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,20 +14,16 @@
 #pragma once

-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/core/dense_tensor.h"

-namespace paddle {
-namespace operators {
+namespace phi {

-template <typename T, typename DeviceContext>
-class PruneGateByCapacityCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "prune_gate_by_capacity is not supported on CPU."));
-  }
-};
+template <typename T, typename Context>
+void PruneGateByCapacityKernel(const Context& dev_ctx,
+                               const DenseTensor& gate_idx,
+                               const DenseTensor& expert_count,
+                               int64_t n_expert,
+                               int64_t n_worker,
+                               DenseTensor* new_gate_idx);

-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
paddle/phi/ops/compat/prune_gate_by_capacity_sig.cc (new file)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature PruneGateByCapacityOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("prune_gate_by_capacity",
                         {"GateIdx", "ExpertCount"},
                         {"n_expert", "n_worker"},
                         {"NewGateIdx"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(prune_gate_by_capacity,
                           phi::PruneGateByCapacityOpArgumentMapping);
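The three brace-delimited lists map, in order, the old operator's input names, attribute names, and output names onto the positional arguments of phi::PruneGateByCapacityKernel (gate_idx, expert_count, n_expert, n_worker, new_gate_idx). As a pattern illustration only — the op, names, and mapping below are made up, not part of this commit:

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Hypothetical mapping for an imaginary op "my_gate_op" with one input X,
// two attributes k and threshold, and one output Out -- same structure as
// the real mapping above.
KernelSignature MyGateOpArgumentMapping(const ArgumentMappingContext& /*ctx*/) {
  return KernelSignature("my_gate_op", {"X"}, {"k", "threshold"}, {"Out"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(my_gate_op, phi::MyGateOpArgumentMapping);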