Unverified commit 6b93ba0a, authored by S Sonder, committed by GitHub

move prune_gate_by_capacity to phi (#55780)

* move prune_gate_by_capacity to phi

* fix

* fix register info

* remove useless codes
Parent 719b1ed3
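At a high level the change is mechanical: the fluid-style OpKernel class (registered per device with PD_REGISTER_STRUCT_KERNEL and driven by an ExecutionContext) is replaced by a free kernel function template in the phi namespace (registered with PD_REGISTER_KERNEL, taking the device context, tensors, and attributes as explicit parameters). A minimal before/after sketch of that pattern, built only from the signatures that appear in the hunks below:

```cpp
// Before (fluid): class-based kernel; inputs/outputs are looked up by name
// from the ExecutionContext inside Compute().
template <typename T, typename DeviceContext>
class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* gate_idx = context.Input<phi::DenseTensor>("GateIdx");
    // ... launch the CUDA kernel ...
  }
};

// After (phi): a free function template; the device context, dense tensors,
// and attributes are explicit parameters, and a registration macro binds the
// function to a backend and a set of dtypes.
template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
                               const DenseTensor& gate_idx,
                               const DenseTensor& expert_count,
                               int64_t n_expert,
                               int64_t n_worker,
                               DenseTensor* new_gate_idx);
```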
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/prune_gate_by_capacity_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
@@ -125,10 +126,3 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity,
ops::PruneGateByCapacityOp,
ops::PruneGateByCapacityOpMaker);
PD_REGISTER_STRUCT_KERNEL(prune_gate_by_capacity,
CPU,
ALL_LAYOUT,
ops::PruneGateByCapacityCPUKernel,
int,
int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prune_gate_by_capacity_kernel.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
const DenseTensor& gate_idx,
const DenseTensor& expert_count,
int64_t n_expert,
int64_t n_worker,
DenseTensor* new_gate_idx) {
PADDLE_THROW(phi::errors::Unimplemented(
"prune_gate_by_capacity is not supported on CPU."));
}
} // namespace phi
PD_REGISTER_KERNEL(prune_gate_by_capacity,
CPU,
ALL_LAYOUT,
phi::PruneGateByCapacityKernel,
int,
int64_t) {}
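The CPU registration above exists mainly so that kernel lookup succeeds on CPU and fails with a clear Unimplemented error rather than a missing-kernel one. If a real CPU path were added later, only this function body would need to change; the header declaration, GPU kernel, and argument mapping below stay as they are. A hypothetical sketch of such a body, illustrative only: the pruning rule (consume the expert's remaining capacity and emit -1 for tokens whose expert is already full) is my reading of the CUDA kernel this op was adapted from, not something this commit defines on CPU.

```cpp
#include <vector>

// Hypothetical CPU body, illustrative only; not part of this commit.
template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
                               const DenseTensor& gate_idx,
                               const DenseTensor& expert_count,
                               int64_t n_expert,
                               int64_t n_worker,
                               DenseTensor* new_gate_idx) {
  const T* gate = gate_idx.data<T>();
  const T* count = expert_count.data<T>();
  T* out = dev_ctx.template Alloc<T>(new_gate_idx);

  // Work on a copy so the input expert_count is left untouched.
  std::vector<T> capacity(count, count + expert_count.numel());
  for (int64_t i = 0; i < gate_idx.numel(); ++i) {
    const T e = gate[i];
    // Keep the token if its expert still has capacity, otherwise prune it.
    out[i] = (capacity[e]-- > 0) ? e : static_cast<T>(-1);
  }
  // n_expert and n_worker are not needed in this simple serial sketch.
}
```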
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -11,25 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The file has been adapted from the two files:
// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cu
// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cuh
// Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
// We retain the following license from the original files:
// Copyright 2021, Jiaao He. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License").
#include "paddle/fluid/operators/prune_gate_by_capacity_op.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
DECLARE_bool(avoid_op_randomness);
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/prune_gate_by_capacity_kernel.h"
namespace paddle {
namespace operators {
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
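The NumBlocks helper called further down in this file is folded out of the hunk; its shape is implied by the two constants above (ceil-divide the element count by the thread count and clamp to the maximum grid size). A standalone sketch under that assumption, keeping the names used in this file:

```cpp
#include <algorithm>

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;  // spelled as in the file above

// Assumed body of the folded-out helper: ceil(N / threads), clamped so the
// grid never exceeds kNumMaxinumNumBlocks.
static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}
```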
@@ -55,14 +45,14 @@ __global__ void prune_gate_by_capacity_kernel(const T1* gate_idx_data,
}
}
template <typename DeviceContext, typename T1>
template <typename Context, typename T1>
class PruneGateByCapacityFunctor {
public:
PruneGateByCapacityFunctor(const framework::ExecutionContext& context,
PruneGateByCapacityFunctor(const Context& dev_ctx,
const phi::DenseTensor* gate_idx,
phi::DenseTensor* expert_count_out,
T1* new_gate_idx_data)
: context_(context),
: dev_ctx_(dev_ctx),
gate_idx_(gate_idx),
expert_count_out_(expert_count_out),
new_gate_idx_data_(new_gate_idx_data) {}
@@ -72,32 +62,31 @@ class PruneGateByCapacityFunctor {
auto batch_size = gate_idx_->numel();
auto* gate_idx_data = gate_idx_->data<T1>();
auto& dev_ctx = context_.template device_context<DeviceContext>();
auto* expert_count_out_data = expert_count_out_->data<T2>();
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
prune_gate_by_capacity_kernel<T1, T2>
<<<blocks, threads, 0, dev_ctx.stream()>>>(gate_idx_data,
<<<blocks, threads, 0, dev_ctx_.stream()>>>(gate_idx_data,
new_gate_idx_data_,
expert_count_out_data,
batch_size);
}
private:
const framework::ExecutionContext context_;
const Context& dev_ctx_;
const phi::DenseTensor* gate_idx_;
phi::DenseTensor* expert_count_out_;
T1* new_gate_idx_data_;
};
template <typename Visitor>
static void VisitDataType(phi::DataType type, Visitor visitor) {
static void VisitType(phi::DataType type, Visitor visitor) {
if (type == phi::DataType::INT64) {
visitor.template apply<int64_t>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(phi::errors::InvalidArgument(
"The received values gate_id type %s can not meet input requirements. "
"Because the given gate_id data type of operators must be "
"int64. Please input appropriate gate_id again! ",
@@ -105,30 +94,30 @@ static void VisitDataType(phi::DataType type, Visitor visitor) {
}
}
template <typename T, typename DeviceContext>
class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* gate_idx = context.Input<phi::DenseTensor>("GateIdx");
auto* expert_count = context.Input<phi::DenseTensor>("ExpertCount");
template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
const DenseTensor& gate_idx,
const DenseTensor& expert_count,
int64_t n_expert,
int64_t n_worker,
DenseTensor* new_gate_idx) {
auto* gate_idx_ptr = &gate_idx;
// auto* expert_count_out =
// context.Output<phi::DenseTensor>("ExpertCountOut");
auto* new_gate_idx = context.Output<phi::DenseTensor>("NewGateIdx");
auto* new_gate_idx_data = new_gate_idx->mutable_data<T>(context.GetPlace());
auto* new_gate_idx_data = dev_ctx.template Alloc<T>(new_gate_idx);
phi::DenseTensor expert_count_out;
framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out);
PruneGateByCapacityFunctor<DeviceContext, T> functor(
context, gate_idx, &expert_count_out, new_gate_idx_data);
::paddle::operators::VisitDataType(expert_count->type(), functor);
}
};
phi::Copy(
dev_ctx, expert_count, dev_ctx.GetPlace(), false, &expert_count_out);
PruneGateByCapacityFunctor<Context, T> functor(
dev_ctx, gate_idx_ptr, &expert_count_out, new_gate_idx_data);
VisitType(expert_count.type(), functor);
}
} // namespace operators
} // namespace paddle
} // namespace phi
PD_REGISTER_STRUCT_KERNEL(prune_gate_by_capacity,
PD_REGISTER_KERNEL(prune_gate_by_capacity,
GPU,
ALL_LAYOUT,
ops::PruneGateByCapacityCUDAKernel,
phi::PruneGateByCapacityKernel,
int64_t) {}
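Because removed and added lines are interleaved in the hunk above, here is the new phi entry point re-assembled from just the added lines (my reconstruction, for readability): the ExecutionContext parameter becomes const Context&, mutable_data becomes dev_ctx.template Alloc<T>, framework::TensorCopy becomes phi::Copy, and the functor and visitor now receive dev_ctx and phi::errors instead of their fluid counterparts.

```cpp
template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
                               const DenseTensor& gate_idx,
                               const DenseTensor& expert_count,
                               int64_t n_expert,
                               int64_t n_worker,
                               DenseTensor* new_gate_idx) {
  auto* gate_idx_ptr = &gate_idx;
  // Allocate the output through the phi context instead of mutable_data().
  auto* new_gate_idx_data = dev_ctx.template Alloc<T>(new_gate_idx);

  // Copy expert_count so the CUDA kernel works on a scratch copy rather than
  // mutating the input tensor.
  phi::DenseTensor expert_count_out;
  phi::Copy(dev_ctx, expert_count, dev_ctx.GetPlace(), false, &expert_count_out);

  PruneGateByCapacityFunctor<Context, T> functor(
      dev_ctx, gate_idx_ptr, &expert_count_out, new_gate_idx_data);
  VisitType(expert_count.type(), functor);
}
```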
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,20 +14,16 @@
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace operators {
namespace phi {
template <typename T, typename DeviceContext>
class PruneGateByCapacityCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"prune_gate_by_capacity is not supported on CPU."));
}
};
template <typename T, typename Context>
void PruneGateByCapacityKernel(const Context& dev_ctx,
const DenseTensor& gate_idx,
const DenseTensor& expert_count,
int64_t n_expert,
int64_t n_worker,
DenseTensor* new_gate_idx);
} // namespace operators
} // namespace paddle
} // namespace phi
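With this declaration in place, the CPU stub and the CUDA file both define the same symbol, and C++ code inside the codebase can call it directly with an explicit instantiation. A hypothetical wrapper, illustrative only: the gpu_context.h path, the use of phi::GPUContext, and the helper name are assumptions on my part, and the Resize stands in for the output-shape step that the operator's InferShape would normally perform before the kernel runs.

```cpp
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/prune_gate_by_capacity_kernel.h"

// Hypothetical wrapper: assumes gate_idx and expert_count are int64 tensors
// already resident on the GPU that dev_ctx manages.
phi::DenseTensor PruneGates(const phi::GPUContext& dev_ctx,
                            const phi::DenseTensor& gate_idx,
                            const phi::DenseTensor& expert_count,
                            int64_t n_expert,
                            int64_t n_worker) {
  phi::DenseTensor new_gate_idx;
  new_gate_idx.Resize(gate_idx.dims());  // one pruned index per input index
  phi::PruneGateByCapacityKernel<int64_t, phi::GPUContext>(
      dev_ctx, gate_idx, expert_count, n_expert, n_worker, &new_gate_idx);
  return new_gate_idx;
}
```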
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PruneGateByCapacityOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("prune_gate_by_capacity",
{"GateIdx", "ExpertCount"},
{"n_expert", "n_worker"},
{"NewGateIdx"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(prune_gate_by_capacity,
phi::PruneGateByCapacityOpArgumentMapping);
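This mapping is what lets the unchanged fluid operator (still registered with REGISTER_OP_WITHOUT_GRADIENT in the first hunk) reach the new phi kernel: the legacy input, attribute, and output names are bound, in order, to the kernel parameters declared in the header. The same signature, annotated against that declaration:

```cpp
return KernelSignature("prune_gate_by_capacity",
                       {"GateIdx", "ExpertCount"},  // inputs  -> gate_idx, expert_count
                       {"n_expert", "n_worker"},    // attrs   -> n_expert, n_worker
                       {"NewGateIdx"});             // output  -> new_gate_idx
```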