diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc index 54555e494ffe5f2c226c7aabd47b4ce991dab2ec..053a90f2fc9fa2f93c2647c420a046401198bc28 100644 --- a/paddle/fluid/operators/shard_index_op.cc +++ b/paddle/fluid/operators/shard_index_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/shard_index_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,27 +23,6 @@ namespace operators { class ShardIndexOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShardIndex"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShardIndex"); - - auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_GE(x_dims.size(), 2, - platform::errors::InvalidArgument( - "Rank of Input(X) should be at least 2, " - "but the value given is %d.", - x_dims.size())); - if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) { - PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U, - platform::errors::InvalidArgument( - "The last dimension of Input(X) should be 1, " - "but the value given is %d.", - x_dims[x_dims.size() - 1])); - } - - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", /* --> */ "Out"); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -114,7 +96,10 @@ Examples: } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(shard_index, ops::ShardIndexOp, - ops::ShardIndexOpMaker); -REGISTER_OP_CPU_KERNEL(shard_index, ops::ShardIndexCPUKernel, - ops::ShardIndexCPUKernel); +DECLARE_INFER_SHAPE_FUNCTOR(shard_index, ShardIndexInferShapeFunctor, + PD_INFER_META(phi::ShardIndexInferMeta)); +REGISTER_OPERATOR( + shard_index, ops::ShardIndexOp, ops::ShardIndexOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + ShardIndexInferShapeFunctor); diff --git a/paddle/fluid/operators/shard_index_op.cu b/paddle/fluid/operators/shard_index_op.cu deleted file mode 100644 index 115b3f47d664ba00228343d221d5be70d13a7ff1..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/shard_index_op.cu +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/shard_index_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; - -template -__global__ void ShardIndexInner(const T* in_data, T* out_data, - const int64_t numel, const int index_num, - const int nshards, const int shard_id, - const int ignore_value) { - int shard_size = (index_num + nshards - 1) / nshards; - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel) { - assert(in_data[idx] >= 0 && in_data[idx] < index_num); - if (in_data[idx] / shard_size == shard_id) { - out_data[idx] = in_data[idx] % shard_size; - } else { - out_data[idx] = ignore_value; - } - } -} - -using LoDTensor = framework::LoDTensor; - -template -class ShardIndexCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - int index_num = context.Attr("index_num"); - int nshards = context.Attr("nshards"); - int shard_id = context.Attr("shard_id"); - int ignore_value = context.Attr("ignore_value"); - PADDLE_ENFORCE_GT( - index_num, 0, - platform::errors::InvalidArgument( - "The value 'index_num' for Op(shard_index) must be greater than 0, " - "but the value given is %d.", - index_num)); - PADDLE_ENFORCE_GT(nshards, 0, - platform::errors::InvalidArgument( - "The value 'nshard' for Op(shard_index) must be " - "greater than 0, but the value given is %d.", - nshards)); - PADDLE_ENFORCE_GE( - shard_id, 0, - platform::errors::InvalidArgument( - "The value 'shard_id' for Op(shard_index) must be greater or " - "equal to 0, but the value given is %d.", - shard_id)); - PADDLE_ENFORCE_LT( - shard_id, nshards, - platform::errors::InvalidArgument( - "The value 'shard_id' for Op(shard_index) must be less than " - "nshards (%d), but the value given is %d.", - nshards, shard_id)); - - out->Resize(in->dims()); - out->set_lod(in->lod()); - auto* in_data = in->data(); - auto* out_data = out->mutable_data(context.GetPlace()); - int64_t numel = in->numel(); - auto stream = - context.template device_context().stream(); - ShardIndexInner<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / - PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>( - in_data, out_data, numel, index_num, nshards, shard_id, ignore_value); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(shard_index, ops::ShardIndexCUDAKernel, - ops::ShardIndexCUDAKernel); diff --git a/paddle/fluid/operators/shard_index_op.h b/paddle/fluid/operators/shard_index_op.h deleted file mode 100644 index c2fe3711686d4c4c802fadd66d4bc994232ef5ec..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/shard_index_op.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using LoDTensor = framework::LoDTensor; -template -class ShardIndexCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - int index_num = context.Attr("index_num"); - int nshards = context.Attr("nshards"); - int shard_id = context.Attr("shard_id"); - int ignore_value = context.Attr("ignore_value"); - PADDLE_ENFORCE_GT( - index_num, 0, - platform::errors::InvalidArgument( - "The value 'index_num' for Op(shard_index) must be greater than 0, " - "but the value given is %d.", - index_num)); - PADDLE_ENFORCE_GT(nshards, 0, - platform::errors::InvalidArgument( - "The value 'nshard' for Op(shard_index) must be " - "greater than 0, but the value given is %d.", - nshards)); - PADDLE_ENFORCE_GE( - shard_id, 0, - platform::errors::InvalidArgument( - "The value 'shard_id' for Op(shard_index) must be greater or " - "equal to 0, but the value given is %d.", - shard_id)); - PADDLE_ENFORCE_LT( - shard_id, nshards, - platform::errors::InvalidArgument( - "The value 'shard_id' for Op(shard_index) must be less than " - "nshards (%d), but the value given is %d.", - nshards, shard_id)); - - int shard_size = (index_num + nshards - 1) / nshards; - - out->Resize(in->dims()); - out->set_lod(in->lod()); - auto* in_data = in->data(); - auto* out_data = out->mutable_data(context.GetPlace()); - int64_t numel = in->numel(); - for (int64_t i = 0; i < numel; ++i) { - PADDLE_ENFORCE_GE(in_data[i], 0, - platform::errors::InvalidArgument( - "The input_index for Op(shard_index) must be " - "greater or equal to 0, but the value given is %d.", - in_data[i])); - PADDLE_ENFORCE_LT(in_data[i], index_num, - platform::errors::InvalidArgument( - "The input_index for Op(shard_index) must be less " - "than index_num (%d), but the value given is %d.", - index_num, in_data[i])); - if (in_data[i] / shard_size == shard_id) { - out_data[i] = in_data[i] % shard_size; - } else { - out_data[i] = ignore_value; - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/shard_index_op_npu.cc b/paddle/fluid/operators/shard_index_op_npu.cc index dc2e8ad58f31ce8fe845ecb1f368544704e1d9ad..c875448424a24e686b9a6285725f801d604abc46 100644 --- a/paddle/fluid/operators/shard_index_op_npu.cc +++ b/paddle/fluid/operators/shard_index_op_npu.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/shard_index_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d7e2bc1767a704e84c4a3c02c9e6b16dc1e3ee05..c26af34f771d5c1e2476489fc40ad0844933bad5 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1312,6 +1312,34 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { out->set_dtype(DataType::INT64); } +void ShardIndexInferMeta(const MetaTensor& in, + int index_num, + int nshards, + int shard_id, + int ignore_value, + MetaTensor* out, + MetaConfig config) { + auto x_dims = in.dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), + 2, + phi::errors::InvalidArgument("Rank of Input(X) should be at least 2, " + "but the value given is %d.", + x_dims.size())); + if (config.is_runtime || x_dims[x_dims.size() - 1] > 0) { + PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], + 1U, + phi::errors::InvalidArgument( + "The last dimension of Input(X) should be 1, " + "but the value given is %d.", + x_dims[x_dims.size() - 1])); + } + + out->set_dims(x_dims); + out->share_lod(in); + out->set_dtype(in.dtype()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a3e5628a4d7b4aa297b1275b1e50c1ce9acceacc..59ee613b8b0eb0f7650d00b4ab04a68e1f3b7027 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -190,4 +190,12 @@ void EighInferMeta(const MetaTensor& x, void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); +void ShardIndexInferMeta(const MetaTensor& in, + int index_num, + int nshards, + int shard_id, + int ignore_value, + MetaTensor* out, + MetaConfig config = MetaConfig()); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/shard_index_kernel.cc b/paddle/phi/kernels/cpu/shard_index_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a82bb8ce5929dcae46700dc4bd7013ab1fee0b39 --- /dev/null +++ b/paddle/phi/kernels/cpu/shard_index_kernel.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/shard_index_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ShardIndexKernel(const Context& dev_ctx, + const DenseTensor& in, + int index_num, + int nshards, + int shard_id, + int ignore_value, + DenseTensor* out) { + PADDLE_ENFORCE_GT( + index_num, + 0, + errors::InvalidArgument( + "The value 'index_num' for Op(shard_index) must be greater than 0, " + "but the value given is %d.", + index_num)); + PADDLE_ENFORCE_GT( + nshards, + 0, + errors::InvalidArgument("The value 'nshard' for Op(shard_index) must be " + "greater than 0, but the value given is %d.", + nshards)); + PADDLE_ENFORCE_GE( + shard_id, + 0, + errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be greater or " + "equal to 0, but the value given is %d.", + shard_id)); + PADDLE_ENFORCE_LT( + shard_id, + nshards, + errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be less than " + "nshards (%d), but the value given is %d.", + nshards, + shard_id)); + + int shard_size = (index_num + nshards - 1) / nshards; + + out->Resize(in.dims()); + out->set_lod(in.lod()); + auto* in_data = in.data(); + auto* out_data = dev_ctx.template Alloc(out); + int64_t numel = in.numel(); + for (int64_t i = 0; i < numel; ++i) { + PADDLE_ENFORCE_GE(in_data[i], + 0, + errors::InvalidArgument( + "The input_index for Op(shard_index) must be " + "greater or equal to 0, but the value given is %d.", + in_data[i])); + PADDLE_ENFORCE_LT(in_data[i], + index_num, + errors::InvalidArgument( + "The input_index for Op(shard_index) must be less " + "than index_num (%d), but the value given is %d.", + index_num, + in_data[i])); + if (in_data[i] / shard_size == shard_id) { + out_data[i] = in_data[i] % shard_size; + } else { + out_data[i] = ignore_value; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + shard_index, CPU, ALL_LAYOUT, phi::ShardIndexKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/shard_index_kernel.cu b/paddle/phi/kernels/gpu/shard_index_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0bd7b93f68928ae268feb2aad9c2aa0304ca4790 --- /dev/null +++ b/paddle/phi/kernels/gpu/shard_index_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/shard_index_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void ShardIndexInner(const T* in_data, + T* out_data, + const int64_t numel, + const int index_num, + const int nshards, + const int shard_id, + const int ignore_value) { + int shard_size = (index_num + nshards - 1) / nshards; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + assert(in_data[idx] >= 0 && in_data[idx] < index_num); + if (in_data[idx] / shard_size == shard_id) { + out_data[idx] = in_data[idx] % shard_size; + } else { + out_data[idx] = ignore_value; + } + } +} + +template +void ShardIndexKernel(const Context& dev_ctx, + const DenseTensor& in, + int index_num, + int nshards, + int shard_id, + int ignore_value, + DenseTensor* out) { + PADDLE_ENFORCE_GT( + index_num, + 0, + phi::errors::InvalidArgument( + "The value 'index_num' for Op(shard_index) must be greater than 0, " + "but the value given is %d.", + index_num)); + PADDLE_ENFORCE_GT(nshards, + 0, + phi::errors::InvalidArgument( + "The value 'nshard' for Op(shard_index) must be " + "greater than 0, but the value given is %d.", + nshards)); + PADDLE_ENFORCE_GE( + shard_id, + 0, + phi::errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be greater or " + "equal to 0, but the value given is %d.", + shard_id)); + PADDLE_ENFORCE_LT( + shard_id, + nshards, + phi::errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be less than " + "nshards (%d), but the value given is %d.", + nshards, + shard_id)); + + out->Resize(in.dims()); + out->set_lod(in.lod()); + auto* in_data = in.data(); + auto* out_data = dev_ctx.template Alloc(out); + int64_t numel = in.numel(); + auto stream = dev_ctx.stream(); + ShardIndexInner< + T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>( + in_data, out_data, numel, index_num, nshards, shard_id, ignore_value); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + shard_index, GPU, ALL_LAYOUT, phi::ShardIndexKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/shard_index_kernel.h b/paddle/phi/kernels/shard_index_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..54ad9a14fa023e41d5261cfe9f03039a1fa031ec --- /dev/null +++ b/paddle/phi/kernels/shard_index_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ShardIndexKernel(const Context& dev_ctx, + const DenseTensor& in, + int index_num, + int nshards, + int shard_id, + int ignore_value, + DenseTensor* out); + +} // namespace phi