diff --git a/paddle/fluid/operators/collective/alltoall_op.cc b/paddle/fluid/operators/collective/alltoall_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..75cf5fb47f7e529f1c255f5932cc02ec899e00bb
--- /dev/null
+++ b/paddle/fluid/operators/collective/alltoall_op.cc
@@ -0,0 +1,75 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/alltoall_op.h"
+
+namespace paddle {
+namespace operators {
+
+class AllToAllOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CAllToAll");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CAllToAll");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for alltoall_op must be non-negative.", ring_id));
+    framework::DDim dim = ctx->GetInputDim("X");
+    if (dim[0] < 0) dim[0] = -1;
+    ctx->SetOutputDim("Out", dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class AllToAllOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) the tensor to send.");
+    AddOutput("Out", "(Tensor) the result of the alltoall exchange.");
+    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
+        .SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) place CUDA operations on the calculation stream.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+AllToAll Operator
+Scatter slices of the input tensor to all participants and gather one slice from each of them.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker);
+
+REGISTER_OP_CPU_KERNEL(alltoall, ops::AllToAllOpCPUKernel<float>,
+                       ops::AllToAllOpCPUKernel<double>,
+                       ops::AllToAllOpCPUKernel<int>,
+                       ops::AllToAllOpCPUKernel<int64_t>,
+                       ops::AllToAllOpCPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e3171968d94ed3ed164c39fe8961ce9ef57c2626
--- /dev/null
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -0,0 +1,88 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/alltoall_op.h"
+
+#if defined(PADDLE_WITH_NCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_NCCL)
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    int send_numel = x->numel();
+    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
+
+    int ring_id = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
+    int nranks = comm->nranks();
+
+    cudaStream_t stream = nullptr;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    } else {
+      stream = comm->stream();
+    }
+
+    framework::DDim x_dims = x->dims();
+    framework::DDim out_dims(x_dims);
+    PADDLE_ENFORCE_EQ(
+        x_dims[0] % nranks, 0,
+        platform::errors::InvalidArgument(
+            "The first dimension size (%d) of the input tensor must be "
+            "divisible by the number of ranks (%d).",
+            x_dims[0], nranks));
+    auto send_buf = x->data<T>();
+    auto recv_buf = out->mutable_data<T>(out_dims, place);
+    size_t offset = 0;
+    send_numel /= nranks;
+    // Fuse all per-peer send/recv pairs into one NCCL group so that ranks
+    // cannot deadlock waiting on each other's point-to-point calls.
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
+    for (auto i = 0; i < nranks; ++i) {
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+          send_buf + offset, send_numel, dtype, i, comm->comm(), stream));
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+          recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
+      offset += send_numel;
+    }
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("PaddlePaddle should be compiled with GPU."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>,
+                        ops::AllToAllOpCUDAKernel<double>,
+                        ops::AllToAllOpCUDAKernel<int>,
+                        ops::AllToAllOpCUDAKernel<int64_t>,
+                        ops::AllToAllOpCUDAKernel<plat::float16>);
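Note: the grouped loop above is the standard NCCL point-to-point all-to-all pattern. A minimal standalone sketch of the same idea, assuming an already-initialized ncclComm_t, a valid CUDA stream, and equal-sized float chunks; the helper name AllToAllChunks is hypothetical and is not part of Paddle or NCCL:

#include <cuda_runtime.h>
#include <nccl.h>

// Exchange equal-sized chunks with every rank: peer j receives chunk j of
// send_buf, and chunk j of recv_buf is filled by peer j. Both buffers hold
// nranks * count floats. Error handling is omitted for brevity.
void AllToAllChunks(const float* send_buf, float* recv_buf, size_t count,
                    int nranks, ncclComm_t comm, cudaStream_t stream) {
  ncclGroupStart();  // fuse all pairs; nothing is issued until ncclGroupEnd
  for (int peer = 0; peer < nranks; ++peer) {
    ncclSend(send_buf + peer * count, count, ncclFloat, peer, comm, stream);
    ncclRecv(recv_buf + peer * count, count, ncclFloat, peer, comm, stream);
  }
  ncclGroupEnd();  // the fused sends/recvs are launched here
}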
diff --git a/paddle/fluid/operators/collective/alltoall_op.h b/paddle/fluid/operators/collective/alltoall_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f0ee470158d4bd0418c72995c852d53546551cdc
--- /dev/null
+++ b/paddle/fluid/operators/collective/alltoall_op.h
@@ -0,0 +1,43 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include <gloo/alltoall.h>
+#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class AllToAllOpCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The alltoall operator does not support the CPU kernel yet."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
"use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +Recv Operator + +Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(recv_v2, ops::RecvOpV2, ops::RecvOpV2Maker); + +REGISTER_OP_CPU_KERNEL(recv_v2, ops::RecvOpV2CPUKernel, + ops::RecvOpV2CPUKernel, + ops::RecvOpV2CPUKernel, + ops::RecvOpV2CPUKernel, + ops::RecvOpV2CPUKernel); diff --git a/paddle/fluid/operators/collective/recv_op_v2.cu.cc b/paddle/fluid/operators/collective/recv_op_v2.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e830442b1c9df4ea5271adfdd6c32877985a9c3 --- /dev/null +++ b/paddle/fluid/operators/collective/recv_op_v2.cu.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/send_op_v2.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class RecvOpV2CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { +#if defined(PADDLE_WITH_NCCL) + auto out = ctx.Output("Out"); + int data_type = ctx.Attr("dtype"); + framework::proto::VarType::Type type = + framework::proto::VarType::Type(data_type); + ncclDataType_t dtype = platform::ToNCCLDataType(type); + + auto out_dims = out->dims(); + // Recv the number of element first + int numel = 0; + int *numel_ptr = nullptr; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int))); + + int rid = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(rid, place); + int peer = ctx.Attr("peer"); + PADDLE_ENFORCE_LT( + peer, comm->nranks(), + platform::errors::InvalidArgument("The value of peer (%d) you set must " + "be less than comm->nranks (%d).", + peer, comm->nranks())); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclRecv(static_cast(numel_ptr), 1, ncclInt, + peer, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost)); + + int rest_numel = 1; + for (size_t i = 1; i < out_dims.size(); ++i) { + rest_numel = rest_numel * out_dims[i]; + } + out_dims[0] = numel / rest_numel; + out->mutable_data(out_dims, place); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( + out->data(), numel, dtype, peer, comm->comm(), 
diff --git a/paddle/fluid/operators/collective/recv_op_v2.cu.cc b/paddle/fluid/operators/collective/recv_op_v2.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0e830442b1c9df4ea5271adfdd6c32877985a9c3
--- /dev/null
+++ b/paddle/fluid/operators/collective/recv_op_v2.cu.cc
@@ -0,0 +1,101 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/recv_op_v2.h"
+
+#if defined(PADDLE_WITH_NCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+#if defined(PADDLE_WITH_NCCL)
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    int data_type = ctx.Attr<int>("dtype");
+    framework::proto::VarType::Type type =
+        framework::proto::VarType::Type(data_type);
+    ncclDataType_t dtype = platform::ToNCCLDataType(type);
+
+    auto out_dims = out->dims();
+    // Receive the element count first: the sender may resize the tensor
+    // dynamically, so only the trailing dimensions are known here.
+    int numel = 0;
+    int *numel_ptr = nullptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
+
+    int rid = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+    int peer = ctx.Attr<int>("peer");
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
+                                          "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
+
+    cudaStream_t stream = nullptr;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
+    } else {
+      stream = comm->stream();
+    }
+
+    // The count must reach the host before the payload recv can be posted,
+    // so it cannot share an NCCL group with the payload recv: inside a
+    // group, nothing is launched until ncclGroupEnd.
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
+        &numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost, stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaFree(numel_ptr));
+
+    // Recover the leading dimension from the received element count and the
+    // known trailing dimensions, then allocate the output buffer.
+    int rest_numel = 1;
+    for (int i = 1; i < out_dims.size(); ++i) {
+      rest_numel = rest_numel * out_dims[i];
+    }
+    out_dims[0] = numel / rest_numel;
+    out->mutable_data<T>(out_dims, place);
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        out->data<T>(), numel, dtype, peer, comm->comm(), stream));
+    VLOG(3) << "rank " << comm->rank() << " recv "
+            << framework::product(out->dims()) << " from " << peer;
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("PaddlePaddle should be compiled with GPU."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>,
+                        ops::RecvOpV2CUDAKernel<double>,
+                        ops::RecvOpV2CUDAKernel<int>,
+                        ops::RecvOpV2CUDAKernel<int64_t>,
+                        ops::RecvOpV2CUDAKernel<plat::float16>);
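A worked example of the shape-recovery arithmetic above, as a self-contained sketch rather than Paddle code; the shapes and counts are made up:

#include <cassert>
#include <vector>

int main() {
  // The out_shape attribute promised {4, 16}, but the sender actually
  // shipped a tensor with leading dimension 8, i.e. numel = 8 * 16 = 128.
  std::vector<int> out_dims = {4, 16};
  int numel = 128;                  // received ahead of the payload
  int rest_numel = 1;
  for (size_t i = 1; i < out_dims.size(); ++i) rest_numel *= out_dims[i];
  out_dims[0] = numel / rest_numel;  // 128 / 16 = 8
  assert(out_dims[0] == 8);
  return 0;
}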
diff --git a/paddle/fluid/operators/collective/recv_op_v2.h b/paddle/fluid/operators/collective/recv_op_v2.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f46d573c53c00573a43ebd44ad9862d14a324e4
--- /dev/null
+++ b/paddle/fluid/operators/collective/recv_op_v2.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class RecvOpV2CPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The recv_v2 operator does not support the CPU kernel yet."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/collective/send_op_v2.cc b/paddle/fluid/operators/collective/send_op_v2.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d0e00d748576ea73caa2691b431a304677f720c2
--- /dev/null
+++ b/paddle/fluid/operators/collective/send_op_v2.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/send_op_v2.h"
+
+namespace paddle {
+namespace operators {
+
+class SendOpV2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CSend");
+    int peer = ctx->Attrs().Get<int>("peer");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for send_op_v2 must be non-negative.", peer));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for send_op_v2 must be non-negative.", ring_id));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class SendOpV2Maker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) the tensor to be sent.");
+    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
+        .SetDefault(0);
+    AddAttr<int>("peer", "(int default 0) rank id of the receiver.").SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) place CUDA operations on the calculation stream.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Send Operator
+
+Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(send_v2, ops::SendOpV2, ops::SendOpV2Maker);
+
+REGISTER_OP_CPU_KERNEL(send_v2, ops::SendOpV2CPUKernel<float>,
+                       ops::SendOpV2CPUKernel<double>,
+                       ops::SendOpV2CPUKernel<int>,
+                       ops::SendOpV2CPUKernel<int64_t>,
+                       ops::SendOpV2CPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/send_op_v2.cu.cc b/paddle/fluid/operators/collective/send_op_v2.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fd8259b00d72d59b30c96f3730d306d0b9acedec
--- /dev/null
+++ b/paddle/fluid/operators/collective/send_op_v2.cu.cc
@@ -0,0 +1,88 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/send_op_v2.h"
+
+#if defined(PADDLE_WITH_NCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class SendOpV2CUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_NCCL)
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    int numel = x->numel();
+    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
+
+    int rid = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+
+    cudaStream_t stream = nullptr;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    } else {
+      stream = comm->stream();
+    }
+
+    int peer = ctx.Attr<int>("peer");
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
+                                          "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
+
+    // Send the number of elements ahead of the payload, since the receiver
+    // may have no information about the tensor size.
+    int* numel_ptr = nullptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+        numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+        x->data<T>(), numel, dtype, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
+    // cudaFree implicitly synchronizes the device, so the count buffer is
+    // not released before the enqueued send has consumed it.
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaFree(numel_ptr));
+    VLOG(3) << "rank " << comm->rank() << " send "
+            << framework::product(x->dims()) << " to " << peer;
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("PaddlePaddle should be compiled with GPU."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>,
+                        ops::SendOpV2CUDAKernel<double>,
+                        ops::SendOpV2CUDAKernel<int>,
+                        ops::SendOpV2CUDAKernel<int64_t>,
+                        ops::SendOpV2CUDAKernel<plat::float16>);
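Taken together, send_v2 and recv_v2 implement a small dynamic-shape protocol: a one-int header carrying the element count, followed by the payload. A minimal sketch of both sides, assuming already-initialized NCCL communicators and streams; the function names are hypothetical, and error checks are elided:

#include <cuda_runtime.h>
#include <nccl.h>

// Sender: the count lives in device memory because NCCL transfers device
// buffers. Both sends are fused in one group and launched at ncclGroupEnd.
void SendTensor(const float* payload, int numel, int peer, ncclComm_t comm,
                cudaStream_t stream) {
  int* numel_dev;
  cudaMalloc(&numel_dev, sizeof(int));
  cudaMemcpy(numel_dev, &numel, sizeof(int), cudaMemcpyHostToDevice);
  ncclGroupStart();
  ncclSend(numel_dev, 1, ncclInt, peer, comm, stream);
  ncclSend(payload, numel, ncclFloat, peer, comm, stream);
  ncclGroupEnd();
  cudaFree(numel_dev);  // implicit device sync; safe after the sends
}

// Receiver: the count recv must complete before the payload buffer can be
// sized, so it is NOT grouped with the payload recv. Returns the count.
int RecvTensor(float** payload_out, int peer, ncclComm_t comm,
               cudaStream_t stream) {
  int numel = 0;
  int* numel_dev;
  cudaMalloc(&numel_dev, sizeof(int));
  ncclRecv(numel_dev, 1, ncclInt, peer, comm, stream);
  cudaMemcpyAsync(&numel, numel_dev, sizeof(int), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);  // the count is now valid on the host
  cudaFree(numel_dev);
  cudaMalloc(payload_out, numel * sizeof(float));
  ncclRecv(*payload_out, numel, ncclFloat, peer, comm, stream);
  return numel;
}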
diff --git a/paddle/fluid/operators/collective/send_op_v2.h b/paddle/fluid/operators/collective/send_op_v2.h
new file mode 100644
index 0000000000000000000000000000000000000000..6215fb1f3b643b30b42fd9c386a46ca31bc6a54d
--- /dev/null
+++ b/paddle/fluid/operators/collective/send_op_v2.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class SendOpV2CPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The send_v2 operator does not support the CPU kernel yet."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle