//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace operators {

using paddle::platform::MKLDNNDeviceContext;
using phi::CPUContext;
using platform::to_void_cast;

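// Handler that configures a oneDNN sum primitive accumulating all non-empty
// inputs into the output tensor.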
template <typename T>
class SumMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
 public:
  SumMKLDNNHandler(dnnl::engine engine,
                   platform::Place cpu_place,
                   const std::vector<framework::Variable*>& in_vars,
                   framework::LoDTensor* z)

      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
        num_inputs_(0) {
    auto dst_tz = phi::vectorize<int64_t>(z->dims());
    auto src_tz = dst_tz;

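    // Collect memory descriptors of the non-empty inputs; tensors with zero
    // elements are skipped so the primitive only sums real buffers.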
    std::vector<dnnl::memory::desc> srcs_md;
    srcs_md.reserve(in_vars.size());
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_md.push_back(input_it.mem_desc());
      ++num_inputs_;
    }
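    // All inputs are accumulated with a scale of 1.0f (plain, unweighted sum).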
    std::vector<float> scales(num_inputs_, 1.0f);

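    // MKLDNNMemoryFormat::any lets oneDNN choose the optimal destination
    // layout for the primitive.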
    auto dst_md = dnnl::memory::desc(
        dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);

    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
  }

  // (jczaja) The oneDNN sum primitive has no .desc attribute, so the base
  // AcquireForwardPrimitiveDescriptor cannot be used here.
  void AcquireForwardPrimitiveDescriptor(
      const dnnl::memory::desc& dst_md,
      const std::vector<float>& scales,
      const std::vector<dnnl::memory::desc>& srcs_md) {
    this->fwd_pd_.reset(
        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
  }

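  // Wrap the i-th non-empty input's data in a dnnl::memory matching the
  // primitive's i-th source descriptor.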
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
                                                 int i) {
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                            to_void_cast<T>(input_data));
  }

  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;

  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
  }

  inline int GetNumInputs(void) { return num_inputs_; }

 private:
  int num_inputs_;
};

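// oneDNN kernel for the sum operator: accumulates all inputs X into Out.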
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Sum must use CPUPlace"));
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    auto in_vars = ctx.MultiInputVar("X");

    PADDLE_ENFORCE_NE(
        in_vars.empty(),
        true,
        platform::errors::InvalidArgument("Input variable is empty."));
    auto& input0 = in_vars[0]->Get<LoDTensor>();
    LoDTensor* output = ctx.Output<LoDTensor>("Out");

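    // The sum may run in place when the first input already shares its
    // buffer with the output tensor.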
    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);

    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);

    // Create the list of source memory objects
    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
    srcs_mem.reserve(handler.GetNumInputs());
    int input_index = 0;
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
      ++input_index;
    }

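    // Bind each source memory to a consecutive DNNL_ARG_MULTIPLE_SRC slot.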
    std::unordered_map<int, dnnl::memory> args;
    std::shared_ptr<dnnl::memory> dst_mem;

    for (size_t i = 0; i < srcs_mem.size(); ++i) {
      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
    }

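    // In place: reuse the first source memory as the destination; otherwise
    // acquire a separate destination buffer for the output.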
    if (in_place) {
      dst_mem = srcs_mem[0];
    } else {
      dst_mem = handler.AcquireDstMemory(output);
    }
    args.insert({DNNL_ARG_DST, *dst_mem});

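    // Obtain the sum primitive and execute it on the per-thread oneDNN stream.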
    auto sum_p = handler.AcquireForwardPrimitive();

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    sum_p->execute(astream, args);
    astream.wait();

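    // Record the memory descriptor chosen by oneDNN on the output tensor.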
    output->set_mem_desc(dst_mem->get_desc());
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_KERNEL(
    sum,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
    paddle::operators::SumMKLDNNOpKernel<float>);