未验证 提交 ad037caa 编写于 作者: Jeffrey Chen's avatar Jeffrey Chen 提交者: GitHub

[PHI] Migrate shard_index op (#40254)

上级 8cabb9f3
......@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -20,27 +23,6 @@ namespace operators {
class ShardIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShardIndex");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShardIndex");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Rank of Input(X) should be at least 2, "
"but the value given is %d.",
x_dims.size()));
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
platform::errors::InvalidArgument(
"The last dimension of Input(X) should be 1, "
"but the value given is %d.",
x_dims[x_dims.size() - 1]));
}
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -114,7 +96,10 @@ Examples:
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(shard_index, ops::ShardIndexOp,
ops::ShardIndexOpMaker);
REGISTER_OP_CPU_KERNEL(shard_index, ops::ShardIndexCPUKernel<int>,
ops::ShardIndexCPUKernel<int64_t>);
DECLARE_INFER_SHAPE_FUNCTOR(shard_index, ShardIndexInferShapeFunctor,
PD_INFER_META(phi::ShardIndexInferMeta));
REGISTER_OPERATOR(
shard_index, ops::ShardIndexOp, ops::ShardIndexOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ShardIndexInferShapeFunctor);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void ShardIndexInner(const T* in_data, T* out_data,
const int64_t numel, const int index_num,
const int nshards, const int shard_id,
const int ignore_value) {
int shard_size = (index_num + nshards - 1) / nshards;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
assert(in_data[idx] >= 0 && in_data[idx] < index_num);
if (in_data[idx] / shard_size == shard_id) {
out_data[idx] = in_data[idx] % shard_size;
} else {
out_data[idx] = ignore_value;
}
}
}
using LoDTensor = framework::LoDTensor;
template <typename T>
class ShardIndexCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int index_num = context.Attr<int>("index_num");
int nshards = context.Attr<int>("nshards");
int shard_id = context.Attr<int>("shard_id");
int ignore_value = context.Attr<int>("ignore_value");
PADDLE_ENFORCE_GT(
index_num, 0,
platform::errors::InvalidArgument(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d.",
index_num));
PADDLE_ENFORCE_GT(nshards, 0,
platform::errors::InvalidArgument(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d.",
nshards));
PADDLE_ENFORCE_GE(
shard_id, 0,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d.",
shard_id));
PADDLE_ENFORCE_LT(
shard_id, nshards,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d.",
nshards, shard_id));
out->Resize(in->dims());
out->set_lod(in->lod());
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
ShardIndexInner<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, out_data, numel, index_num, nshards, shard_id, ignore_value);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(shard_index, ops::ShardIndexCUDAKernel<int>,
ops::ShardIndexCUDAKernel<int64_t>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T>
class ShardIndexCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int index_num = context.Attr<int>("index_num");
int nshards = context.Attr<int>("nshards");
int shard_id = context.Attr<int>("shard_id");
int ignore_value = context.Attr<int>("ignore_value");
PADDLE_ENFORCE_GT(
index_num, 0,
platform::errors::InvalidArgument(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d.",
index_num));
PADDLE_ENFORCE_GT(nshards, 0,
platform::errors::InvalidArgument(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d.",
nshards));
PADDLE_ENFORCE_GE(
shard_id, 0,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d.",
shard_id));
PADDLE_ENFORCE_LT(
shard_id, nshards,
platform::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d.",
nshards, shard_id));
int shard_size = (index_num + nshards - 1) / nshards;
out->Resize(in->dims());
out->set_lod(in->lod());
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel();
for (int64_t i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(in_data[i], 0,
platform::errors::InvalidArgument(
"The input_index for Op(shard_index) must be "
"greater or equal to 0, but the value given is %d.",
in_data[i]));
PADDLE_ENFORCE_LT(in_data[i], index_num,
platform::errors::InvalidArgument(
"The input_index for Op(shard_index) must be less "
"than index_num (%d), but the value given is %d.",
index_num, in_data[i]));
if (in_data[i] / shard_size == shard_id) {
out_data[i] = in_data[i] % shard_size;
} else {
out_data[i] = ignore_value;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -1312,6 +1312,34 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
out->set_dtype(DataType::INT64);
}
void ShardIndexInferMeta(const MetaTensor& in,
int index_num,
int nshards,
int shard_id,
int ignore_value,
MetaTensor* out,
MetaConfig config) {
auto x_dims = in.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 2, "
"but the value given is %d.",
x_dims.size()));
if (config.is_runtime || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1],
1U,
phi::errors::InvalidArgument(
"The last dimension of Input(X) should be 1, "
"but the value given is %d.",
x_dims[x_dims.size() - 1]));
}
out->set_dims(x_dims);
out->share_lod(in);
out->set_dtype(in.dtype());
}
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
......@@ -190,4 +190,12 @@ void EighInferMeta(const MetaTensor& x,
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
void ShardIndexInferMeta(const MetaTensor& in,
int index_num,
int nshards,
int shard_id,
int ignore_value,
MetaTensor* out,
MetaConfig config = MetaConfig());
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/shard_index_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ShardIndexKernel(const Context& dev_ctx,
const DenseTensor& in,
int index_num,
int nshards,
int shard_id,
int ignore_value,
DenseTensor* out) {
PADDLE_ENFORCE_GT(
index_num,
0,
errors::InvalidArgument(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d.",
index_num));
PADDLE_ENFORCE_GT(
nshards,
0,
errors::InvalidArgument("The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d.",
nshards));
PADDLE_ENFORCE_GE(
shard_id,
0,
errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d.",
shard_id));
PADDLE_ENFORCE_LT(
shard_id,
nshards,
errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d.",
nshards,
shard_id));
int shard_size = (index_num + nshards - 1) / nshards;
out->Resize(in.dims());
out->set_lod(in.lod());
auto* in_data = in.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
int64_t numel = in.numel();
for (int64_t i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(in_data[i],
0,
errors::InvalidArgument(
"The input_index for Op(shard_index) must be "
"greater or equal to 0, but the value given is %d.",
in_data[i]));
PADDLE_ENFORCE_LT(in_data[i],
index_num,
errors::InvalidArgument(
"The input_index for Op(shard_index) must be less "
"than index_num (%d), but the value given is %d.",
index_num,
in_data[i]));
if (in_data[i] / shard_size == shard_id) {
out_data[i] = in_data[i] % shard_size;
} else {
out_data[i] = ignore_value;
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
shard_index, CPU, ALL_LAYOUT, phi::ShardIndexKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/shard_index_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void ShardIndexInner(const T* in_data,
T* out_data,
const int64_t numel,
const int index_num,
const int nshards,
const int shard_id,
const int ignore_value) {
int shard_size = (index_num + nshards - 1) / nshards;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
assert(in_data[idx] >= 0 && in_data[idx] < index_num);
if (in_data[idx] / shard_size == shard_id) {
out_data[idx] = in_data[idx] % shard_size;
} else {
out_data[idx] = ignore_value;
}
}
}
template <typename T, typename Context>
void ShardIndexKernel(const Context& dev_ctx,
const DenseTensor& in,
int index_num,
int nshards,
int shard_id,
int ignore_value,
DenseTensor* out) {
PADDLE_ENFORCE_GT(
index_num,
0,
phi::errors::InvalidArgument(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d.",
index_num));
PADDLE_ENFORCE_GT(nshards,
0,
phi::errors::InvalidArgument(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d.",
nshards));
PADDLE_ENFORCE_GE(
shard_id,
0,
phi::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d.",
shard_id));
PADDLE_ENFORCE_LT(
shard_id,
nshards,
phi::errors::InvalidArgument(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d.",
nshards,
shard_id));
out->Resize(in.dims());
out->set_lod(in.lod());
auto* in_data = in.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
int64_t numel = in.numel();
auto stream = dev_ctx.stream();
ShardIndexInner<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, numel, index_num, nshards, shard_id, ignore_value);
}
} // namespace phi
PD_REGISTER_KERNEL(
shard_index, GPU, ALL_LAYOUT, phi::ShardIndexKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ShardIndexKernel(const Context& dev_ctx,
const DenseTensor& in,
int index_num,
int nshards,
int shard_id,
int ignore_value,
DenseTensor* out);
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册