未验证 提交 c27bd049 编写于 作者: R Ruibin Cheung 提交者: GitHub

[Fluid] NO.12 Migrate number_count to PHI (#56128)

* [Fluid] Migrate number_count to PHI

* fix out alloc

* fix ut (add python_api)
上级 e1eb52e1
......@@ -61,6 +61,3 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(number_count,
ops::NumberCountOp,
ops::NumberCountOpMaker);
PD_REGISTER_STRUCT_KERNEL(
number_count, CPU, ALL_LAYOUT, ops::NumberCountOpCPUKernel, int, int64_t) {}
......@@ -27,10 +27,7 @@ namespace operators {
template <typename T, typename DeviceContext>
class NumberCountOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support expert count op for cpu kernel now."));
}
void Compute(const framework::ExecutionContext& ctx) const override {}
};
} // namespace operators
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/number_count_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void NumberCountKernel(const Context& dev_ctx,
const DenseTensor& numbers,
int upper_range,
DenseTensor* out) {
PADDLE_THROW(phi::errors::Unavailable(
"Do not support expert count op for cpu kernel now."));
}
} // namespace phi
PD_REGISTER_KERNEL(
number_count, CPU, ALL_LAYOUT, phi::NumberCountKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -11,23 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The file has been adapted from the two files:
// https://github.com/laekov/fastmoe/blob/master/cuda/local_exchange.cu
// https://github.com/laekov/fastmoe/blob/master/cuda/local_exchange.cuh
// Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
// We retain the following license from the original files:
// Copyright 2021, Jiaao He. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License").
#include "paddle/fluid/operators/number_count_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/number_count_kernel.h"
namespace paddle {
namespace operators {
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)
#define PERTHREAD_EXPERTS 256
#define WARP_SIZE 32
......@@ -79,37 +70,28 @@ __global__ void NumberCount(const T* numbers,
}
}
template <typename T, typename DeviceContext>
class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto numbers = context.Input<phi::DenseTensor>("numbers");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<phi::DenseTensor>("Out");
int64_t batch_size = numbers->numel();
auto place = context.GetPlace();
const auto& dev_ctx = context.template device_context<phi::GPUContext>();
template <typename T, typename Context>
void NumberCountKernel(const Context& ctx,
const DenseTensor& numbers,
int upper_range,
DenseTensor* out) {
int64_t batch_size = numbers.numel();
framework::DDim out_dims = phi::make_ddim({upper_range});
auto out_data = number_count->mutable_data<T>(out_dims, place);
const T* gate_data = numbers->data<T>();
DDim out_dims = phi::make_ddim({upper_range});
out->Resize(out_dims);
auto out_data = ctx.template Alloc<T>(out);
const T* gate_data = numbers.data<T>();
initialize_zero_kernel<T>
<<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
<<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, ctx.stream()>>>(
out_data, upper_range);
NumberCount<T>
<<<CEIL(upper_range, PERTHREAD_EXPERTS), 256, 0, dev_ctx.stream()>>>(
<<<CEIL(upper_range, PERTHREAD_EXPERTS), 256, 0, ctx.stream()>>>(
gate_data, out_data, batch_size, upper_range);
}
};
} // namespace operators
} // namespace paddle
}
namespace ops = paddle::operators;
namespace plat = paddle::platform;
} // namespace phi
PD_REGISTER_STRUCT_KERNEL(
number_count, GPU, ALL_LAYOUT, ops::NumberCountOpCUDAKernel, int64_t) {}
PD_REGISTER_KERNEL(
number_count, GPU, ALL_LAYOUT, phi::NumberCountKernel, int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void NumberCountKernel(const Context& ctx,
const DenseTensor& numbers,
int upper_range,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature NumberCountOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("number_count", {"numbers"}, {"upper_range"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(number_count, phi::NumberCountOpArgumentMapping);
......@@ -30,6 +30,10 @@ def count(x, upper_num):
return res
def number_count_wrapper(numbers, upper_num):
return paddle._legacy_C_ops.number_count(numbers, 'upper_range', upper_num)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
......@@ -37,6 +41,7 @@ class TestNumberCountOpInt64(eager_op_test.OpTest):
def setUp(self):
upper_num = 16
self.op_type = "number_count"
self.python_api = number_count_wrapper
x = np.random.randint(-1, upper_num, size=(1000, 2)).astype('int64')
self.inputs = {'numbers': x}
self.outputs = {'Out': count(x, upper_num)}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册