未验证 提交 1c29196e 编写于 作者: 0 0x45f 提交者: GitHub

[Phi]Move bincount OP to phi (#39947)

* move bincount OP to phi

* fix dtype

* set_dtype by weights or x

* fix conflicts
上级 2a3d9eca
......@@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bincount_op.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -28,51 +31,6 @@ class BincountOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BincountOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of BincountOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto minlength = ctx->Attrs().Get<int>("minlength");
PADDLE_ENFORCE_GE(minlength, 0,
platform::errors::InvalidArgument(
"The minlength should be greater than or equal to 0."
"But received minlength is %d",
minlength));
PADDLE_ENFORCE_EQ(input_dim.size(), 1,
platform::errors::InvalidArgument(
"The 'shape' of Input(X) must be 1-D tensor."
"But the dimension of Input(X) is [%d]",
input_dim.size()));
if (ctx->HasInput("Weights")) {
auto weights_dim = ctx->GetInputDim("Weights");
PADDLE_ENFORCE_EQ(weights_dim.size(), 1,
platform::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be 1-D tensor."
"But the dimension of Input(Weights) is [%d]",
weights_dim.size()));
PADDLE_ENFORCE_EQ(
weights_dim[0], input_dim[0],
platform::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
"Input(X)."
"But received: the 'shape' of Input(Weights) is [%s],"
"the 'shape' of Input(X) is [%s]",
weights_dim, input_dim));
}
ctx->SetOutputDim("Out", phi::make_ddim({-1}));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type =
......@@ -105,12 +63,10 @@ class BincountOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(bincount, BincountInferShapeFunctor,
PD_INFER_META(phi::BincountInferMeta));
REGISTER_OPERATOR(
bincount, ops::BincountOp, ops::BincountOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
bincount, ops::BincountKernel<paddle::platform::CPUDeviceContext, float>,
ops::BincountKernel<paddle::platform::CPUDeviceContext, double>,
ops::BincountKernel<paddle::platform::CPUDeviceContext, int>,
ops::BincountKernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
BincountInferShapeFunctor);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T, typename InputT>
void BincountInner(const framework::ExecutionContext& context) {
const Tensor* input = context.Input<framework::Tensor>("X");
const Tensor* weights = context.Input<framework::Tensor>("Weights");
Tensor* output = context.Output<framework::Tensor>("Out");
auto& minlength = context.Attr<int>("minlength");
const InputT* input_data = input->data<InputT>();
auto input_numel = input->numel();
if (input_data == nullptr) {
framework::DDim out_dim{0};
output->Resize(out_dim);
output->mutable_data<InputT>(context.GetPlace());
return;
}
PADDLE_ENFORCE_GE(
*std::min_element(input_data, input_data + input_numel),
static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
int64_t output_size = static_cast<int64_t>(*std::max_element(
input_data, input_data + input_numel)) +
1L;
output_size = std::max(output_size, static_cast<int64_t>(minlength));
framework::DDim out_dim{output_size};
output->Resize(out_dim);
bool has_weights = (weights != nullptr);
if (has_weights) {
const T* weights_data = weights->data<T>();
const auto& weights_type = framework::TransToProtoVarType(weights->dtype());
if (weights_type == framework::proto::VarType::FP32) {
float* output_data = output->mutable_data<float>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, float>()(
context.template device_context<DeviceContext>(), output,
static_cast<float>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<float>(weights_data[i]);
}
} else {
double* output_data = output->mutable_data<double>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, double>()(
context.template device_context<DeviceContext>(), output,
static_cast<double>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<double>(weights_data[i]);
}
}
} else {
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output, 0L);
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += 1L;
}
}
}
template <typename DeviceContext, typename T>
class BincountKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<framework::Tensor>("X");
const auto& input_type = framework::TransToProtoVarType(input->dtype());
if (input_type == framework::proto::VarType::INT32) {
BincountInner<DeviceContext, T, int>(context);
} else if (input_type == framework::proto::VarType::INT64) {
BincountInner<DeviceContext, T, int64_t>(context);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -456,6 +456,56 @@ void BCELossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void BincountInferMeta(const MetaTensor& x,
const paddle::optional<const MetaTensor&> weights,
int minlength,
MetaTensor* out) {
auto input_dim = x.dims();
PADDLE_ENFORCE_GE(minlength,
0,
phi::errors::InvalidArgument(
"The minlength should be greater than or equal to 0."
"But received minlength is %d",
minlength));
PADDLE_ENFORCE_EQ(
input_dim.size(),
1,
phi::errors::InvalidArgument("The 'shape' of Input(X) must be 1-D tensor."
"But the dimension of Input(X) is [%d]",
input_dim.size()));
if (weights.is_initialized()) {
auto weights_dim = weights->dims();
PADDLE_ENFORCE_EQ(weights_dim.size(),
1,
phi::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be 1-D tensor."
"But the dimension of Input(Weights) is [%d]",
weights_dim.size()));
PADDLE_ENFORCE_EQ(
weights_dim[0],
input_dim[0],
phi::errors::InvalidArgument(
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
"Input(X)."
"But received: the 'shape' of Input(Weights) is [%s],"
"the 'shape' of Input(X) is [%s]",
weights_dim,
input_dim));
}
out->set_dims(phi::make_ddim({-1}));
if (weights.is_initialized()) {
out->set_dtype(weights->dtype());
} else {
out->set_dtype(x.dtype());
}
out->share_lod(x);
}
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
......
......@@ -85,6 +85,10 @@ void BCELossInferMeta(const MetaTensor& input,
MetaTensor* out,
MetaConfig config = MetaConfig());
void BincountInferMeta(const MetaTensor& x,
const paddle::optional<const MetaTensor&> weights,
int minlength,
MetaTensor* out);
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<const DenseTensor&> weights,
int minlength,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/bincount_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename Context, typename T, typename InputT>
void BincountInner(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<const DenseTensor&> weights,
int minlength,
DenseTensor* out) {
const DenseTensor* input = &x;
DenseTensor* output = out;
const InputT* input_data = input->data<InputT>();
auto input_numel = input->numel();
if (input_data == nullptr) {
phi::DDim out_dim{0};
output->Resize(out_dim);
dev_ctx.template Alloc<InputT>(output);
return;
}
PADDLE_ENFORCE_GE(
*std::min_element(input_data, input_data + input_numel),
static_cast<InputT>(0),
phi::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
int64_t output_size = static_cast<int64_t>(*std::max_element(
input_data, input_data + input_numel)) +
1L;
output_size = std::max(output_size, static_cast<int64_t>(minlength));
phi::DDim out_dim{output_size};
output->Resize(out_dim);
bool has_weights = weights.is_initialized();
if (has_weights) {
const T* weights_data = weights->data<T>();
if (weights->dtype() == DataType::FLOAT32) {
float* output_data = dev_ctx.template Alloc<float>(output);
phi::funcs::SetConstant<Context, float>()(
dev_ctx, output, static_cast<float>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<float>(weights_data[i]);
}
} else {
double* output_data = dev_ctx.template Alloc<double>(output);
phi::funcs::SetConstant<Context, double>()(
dev_ctx, output, static_cast<double>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<double>(weights_data[i]);
}
}
} else {
int64_t* output_data = dev_ctx.template Alloc<int64_t>(output);
phi::funcs::SetConstant<Context, int64_t>()(dev_ctx, output, 0L);
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += 1L;
}
}
}
template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<const DenseTensor&> weights,
int minlength,
DenseTensor* out) {
if (x.dtype() == DataType::INT32) {
BincountInner<Context, T, int>(dev_ctx, x, weights, minlength, out);
} else if (x.dtype() == DataType::INT64) {
BincountInner<Context, T, int64_t>(dev_ctx, x, weights, minlength, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(bincount,
CPU,
ALL_LAYOUT,
phi::BincountKernel,
float,
double,
int,
int64_t) {}
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/bincount_kernel.h"
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/bincount_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
namespace phi {
using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
template <typename T, typename InputT, typename OutT>
__global__ void KernelBincount(const InputT* input, const int total_elements,
const bool has_weights, const T* weights,
__global__ void KernelBincount(const InputT* input,
const int total_elements,
const bool has_weights,
const T* weights,
OutT* output) {
if (!has_weights) {
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
......@@ -44,119 +46,119 @@ __global__ void KernelBincount(const InputT* input, const int total_elements,
}
}
template <typename DeviceContext, typename T, typename InputT>
void BincountCUDAInner(const framework::ExecutionContext& context) {
const Tensor* input = context.Input<framework::Tensor>("X");
const Tensor* weights = context.Input<framework::Tensor>("Weights");
Tensor* output = context.Output<framework::Tensor>("Out");
auto& minlength = context.Attr<int>("minlength");
template <typename Context, typename T, typename InputT>
void BincountCUDAInner(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<const DenseTensor&> weights,
int minlength,
DenseTensor* out) {
const DenseTensor* input = &x;
DenseTensor* output = out;
const InputT* input_data = input->data<InputT>();
const int input_numel = input->numel();
if (input_data == nullptr) {
framework::DDim out_dim{0};
phi::DDim out_dim{0};
output->Resize(out_dim);
output->mutable_data<T>(context.GetPlace());
dev_ctx.template Alloc<T>(output);
return;
}
auto input_x = framework::EigenVector<InputT>::Flatten(*input);
framework::Tensor input_min_t, input_max_t;
auto* input_max_data =
input_max_t.mutable_data<InputT>({1}, context.GetPlace());
auto* input_min_data =
input_min_t.mutable_data<InputT>({1}, context.GetPlace());
auto input_x = EigenVector<InputT>::Flatten(*input);
DenseTensor input_min_t, input_max_t;
input_max_t.Resize({1});
auto* input_max_data = dev_ctx.template Alloc<InputT>(&input_max_t);
input_min_t.Resize({1});
auto* input_min_data = dev_ctx.template Alloc<InputT>(&input_min_t);
auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t);
auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t);
auto input_max_scala = EigenScalar<InputT>::From(input_max_t);
auto input_min_scala = EigenScalar<InputT>::From(input_min_t);
auto* place = context.template device_context<DeviceContext>().eigen_device();
auto* place = dev_ctx.eigen_device();
input_max_scala.device(*place) = input_x.maximum();
input_min_scala.device(*place) = input_x.minimum();
Tensor input_min_cpu, input_max_cpu;
paddle::framework::TensorCopySync(input_max_t, platform::CPUPlace(),
&input_max_cpu);
paddle::framework::TensorCopySync(input_min_t, platform::CPUPlace(),
&input_min_cpu);
DenseTensor input_min_cpu, input_max_cpu;
paddle::framework::TensorCopySync(
input_max_t, phi::CPUPlace(), &input_max_cpu);
paddle::framework::TensorCopySync(
input_min_t, phi::CPUPlace(), &input_min_cpu);
InputT input_min = input_min_cpu.data<InputT>()[0];
PADDLE_ENFORCE_GE(
input_min, static_cast<InputT>(0),
platform::errors::InvalidArgument(
input_min,
static_cast<InputT>(0),
phi::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
int64_t output_size =
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
output_size = std::max(output_size, static_cast<int64_t>(minlength));
framework::DDim out_dim{output_size};
phi::DDim out_dim{output_size};
output->Resize(out_dim);
bool has_weights = (weights != nullptr);
bool has_weights = weights.is_initialized();
const T* weights_data = has_weights ? weights->data<T>() : nullptr;
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
auto stream = dev_ctx.stream();
if (!has_weights) {
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output, 0L);
int64_t* output_data = dev_ctx.template Alloc<int64_t>(output);
phi::funcs::SetConstant<Context, int64_t>()(dev_ctx, output, 0L);
KernelBincount<T, InputT, int64_t><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
const auto& weights_type = framework::TransToProtoVarType(weights->dtype());
const auto& weights_type =
paddle::framework::TransToProtoVarType(weights->dtype());
if (weights_type == framework::proto::VarType::FP32) {
float* output_data = output->mutable_data<float>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, float>()(
context.template device_context<DeviceContext>(), output,
static_cast<float>(0));
if (weights->dtype() == DataType::FLOAT32) {
float* output_data = dev_ctx.template Alloc<float>(output);
phi::funcs::SetConstant<Context, float>()(
dev_ctx, output, static_cast<float>(0));
KernelBincount<T, InputT, float><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
double* output_data = output->mutable_data<double>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, double>()(
context.template device_context<DeviceContext>(), output,
static_cast<double>(0));
double* output_data = dev_ctx.template Alloc<double>(output);
phi::funcs::SetConstant<Context, double>()(
dev_ctx, output, static_cast<double>(0));
KernelBincount<T, InputT, double><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
}
}
}
template <typename DeviceContext, typename T>
class BincountCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<framework::Tensor>("X");
const auto& input_type = framework::TransToProtoVarType(input->dtype());
if (input_type == framework::proto::VarType::INT32) {
BincountCUDAInner<DeviceContext, T, int>(context);
} else if (input_type == framework::proto::VarType::INT64) {
BincountCUDAInner<DeviceContext, T, int64_t>(context);
}
template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<const DenseTensor&> weights,
int minlength,
DenseTensor* out) {
if (x.dtype() == DataType::INT32) {
BincountCUDAInner<Context, T, int>(dev_ctx, x, weights, minlength, out);
} else if (x.dtype() == DataType::INT64) {
BincountCUDAInner<Context, T, int64_t>(dev_ctx, x, weights, minlength, out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
bincount, ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, double>);
}
} // namespace phi
PD_REGISTER_KERNEL(bincount,
GPU,
ALL_LAYOUT,
phi::BincountKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature BincountOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("bincount", {"X", "Weights"}, {"minlength"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(bincount, phi::BincountOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册