//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace operators {

using paddle::platform::MKLDNNDeviceContext;
using phi::CPUContext;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;

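// Builds the oneDNN sum primitive descriptor for a variadic list of inputs:
// one memory descriptor per non-empty input, each paired with a scale of
// 1.0f (i.e. a plain, unweighted sum). The destination layout is left as
// MKLDNNMemoryFormat::any so oneDNN can pick the most efficient format.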
template <typename T>
class SumMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
 public:
  SumMKLDNNHandler(dnnl::engine engine,
                   platform::Place cpu_place,
                   const std::vector<framework::Variable*>& in_vars,
                   framework::LoDTensor* z)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
        num_inputs_(0) {
    auto dst_tz = phi::vectorize<int64_t>(z->dims());
    auto src_tz = dst_tz;

    std::vector<dnnl::memory::desc> srcs_md;
    srcs_md.reserve(in_vars.size());
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_md.push_back(input_it.mem_desc());
      ++num_inputs_;
    }
    std::vector<float> scales(num_inputs_, 1.0f);

    auto dst_md = dnnl::memory::desc(
        dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);

    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
  }

  // (jczaja) The sum oneDNN primitive does not expose a .desc attribute, so
  // we cannot reuse the base AcquireForwardPrimitiveDescriptor here.
  void AcquireForwardPrimitiveDescriptor(
      const dnnl::memory::desc& dst_md,
      const std::vector<float>& scales,
      const std::vector<dnnl::memory::desc>& srcs_md) {
    this->fwd_pd_.reset(
        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
  }

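  // Wraps the i-th input's data pointer as a dnnl::memory with the layout
  // the sum primitive descriptor expects for that source (no data is copied).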
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
                                                 int i) {
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                            to_void_cast<T>(input_data));
  }

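  // Re-expose the base-class AcquireDstMemory overloads, which the
  // no-argument variant below would otherwise hide.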
  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;

  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
  }

  inline int GetNumInputs(void) { return num_inputs_; }

 private:
  int num_inputs_;
};

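// Sums a variable number of LoDTensor inputs via the oneDNN sum primitive.
// If the output shares its buffer with the first input, the primitive is
// executed in-place, reusing that input's memory as the destination.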
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Sum must use CPUPlace"));
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    auto in_vars = ctx.MultiInputVar("X");

    PADDLE_ENFORCE_NE(
        in_vars.empty(),
        true,
        platform::errors::InvalidArgument("Input variable is empty."));
    auto& input0 = in_vars[0]->Get<LoDTensor>();
    LoDTensor* output = ctx.Output<LoDTensor>("Out");

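    // In-place execution is only valid when the first input actually shares
    // its buffer with the output tensor.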
    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);

    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);

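    // Note: "i" walks every input variable, while "input_index" counts only
    // the non-empty inputs registered with the handler; the two diverge
    // whenever an input has zero elements.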
    // Create list of SRC MEMs
    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
    srcs_mem.reserve(handler.GetNumInputs());
    int input_index = 0;
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
      ++input_index;
    }

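    // The oneDNN sum primitive takes its inputs as DNNL_ARG_MULTIPLE_SRC + i
    // execution arguments, one entry per source memory.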
    std::unordered_map<int, dnnl::memory> args;
    std::shared_ptr<dnnl::memory> dst_mem;

    for (size_t i = 0; i < srcs_mem.size(); ++i) {
      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
    }

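    // In-place: the first source memory doubles as the destination;
    // otherwise bind the destination to the output tensor's buffer.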
    if (in_place) {
      dst_mem = srcs_mem[0];
    } else {
      dst_mem = handler.AcquireDstMemory(output);
    }
    args.insert({DNNL_ARG_DST, *dst_mem});

    auto sum_p = handler.AcquireForwardPrimitive();

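    // Run the primitive on the thread-local oneDNN stream and wait for it to
    // finish before reading the result.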
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    sum_p->execute(astream, args);
    astream.wait();

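    // Record the layout oneDNN actually chose (dst format was "any") so that
    // downstream operators see the correct memory descriptor.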
    output->set_mem_desc(dst_mem->get_desc());
  }
};

}  // namespace operators
}  // namespace paddle

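// Register the oneDNN sum kernel on CPUPlace for float and bfloat16.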
REGISTER_OP_KERNEL(
    sum,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
    paddle::operators::SumMKLDNNOpKernel<float>);