/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include <gloo/gather.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
S
sandyhouse 已提交
34
class GatherOpV2CPUKernel : public framework::OpKernel<T> {
L
lilong12 已提交
35 36
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
37 38 39 40
#if defined(PADDLE_WITH_GLOO)
    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    auto root_id = ctx.Attr<int>("root");
S
sandyhouse 已提交
41
    auto nranks = ctx.Attr<int>("nranks");
42 43 44 45 46 47 48

    auto gloo = paddle::framework::GlooWrapper::GetInstance();
    PADDLE_ENFORCE_EQ(
        gloo->IsInitialized(), true,
        platform::errors::PreconditionNotMet(
            "You must initialize the gloo environment first to use it."));

S
sandyhouse 已提交
49 50 51 52 53 54 55 56 57 58
    PADDLE_ENFORCE_EQ(nranks, gloo->Size(),
                      platform::errors::InvalidArgument(
                          "The number of ranks (%d) you set must "
                          "be equal to gloo->Size() (%d).",
                          nranks, gloo->Size()));
    int64_t send_numel = in->numel();
    int64_t recv_numel = out->numel();
    auto in_dim = x->dims();
    auto out_dim = framework::DDim(in_dim);
    out_dim[0] *= nranks;
59 60
    auto nranks = gloo->Size();
    auto rank = gloo->Rank();
S
sandyhouse 已提交
61
    gloo::GatherOptions opts(gloo->GetContext());
62
    if (root_id == rank) {
S
sandyhouse 已提交
63 64
      T* recv_buff = out->mutable_data<T>(place, out_dim);
      opts.setOutput(recv_buff, recv_numel);
65
    }
S
sandyhouse 已提交
66
    opts.setInput(send_buff, send_numel);
67 68
    opts.setRoot(root_id);

S
sandyhouse 已提交
69
    gloo::gather(opts);
70 71 72 73
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
L
lilong12 已提交
74 75 76 77 78
  }
};

}  // namespace operators
}  // namespace paddle