//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using paddle::platform::CPUDeviceContext;
using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;

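// Builds the oneDNN sum primitive descriptor over all non-empty inputs.
// It derives from the non-caching handler variant, so the primitive is
// constructed per kernel invocation rather than fetched from the
// device-context cache.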
template <typename T>
class SumMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
 public:
  SumMKLDNNHandler(dnnl::engine engine,
                   platform::Place cpu_place,
                   const std::vector<framework::Variable*>& in_vars,
                   framework::LoDTensor* z)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
        num_inputs_(0) {
    auto dst_tz = phi::vectorize<int64_t>(z->dims());

    std::vector<dnnl::memory::desc> srcs_md;
    srcs_md.reserve(in_vars.size());
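    // Only non-empty inputs contribute a source memory descriptor;
    // zero-element tensors are skipped entirely.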
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_md.push_back(input_it.mem_desc());
      ++num_inputs_;
    }
    // Plain element-wise sum: every non-empty input is weighted by 1.0f.
    std::vector<float> scales(num_inputs_, 1.0f);

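    // MKLDNNMemoryFormat::any lets oneDNN choose the destination layout
    // when the primitive descriptor is created.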
    auto dst_md = dnnl::memory::desc(
        dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);

    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
  }

  // (jczaja) The oneDNN sum primitive has no .desc attribute, so the
  // base-class AcquireForwardPrimitiveDescriptor cannot be reused here.
  void AcquireForwardPrimitiveDescriptor(
      const dnnl::memory::desc& dst_md,
      const std::vector<float>& scales,
      const std::vector<dnnl::memory::desc>& srcs_md) {
    this->fwd_pd_.reset(
        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
  }

  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
                                                 int i) {
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                            to_void_cast<T>(input_data));
  }

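  // Bring the base-class AcquireDstMemory overloads back into scope; the
  // no-argument overload below would otherwise hide them.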
  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;

  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
  }

  inline int GetNumInputs(void) { return num_inputs_; }

 private:
  int num_inputs_;
};

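// Kernel computing Out = X[0] + X[1] + ... with the oneDNN sum primitive.
// When Out already shares X[0]'s buffer, the sum runs in place.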
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Sum must use CPUPlace"));
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    auto in_vars = ctx.MultiInputVar("X");

    PADDLE_ENFORCE_NE(
        in_vars.empty(),
        true,
        platform::errors::InvalidArgument("Input variable is empty."));
    auto& input0 = in_vars[0]->Get<LoDTensor>();
    LoDTensor* output = ctx.Output<LoDTensor>("Out");

    // Run in place only when the first input is non-empty and already
    // shares its buffer with the output.
    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);

    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);

    // Build the list of source memories; keep a separate index because the
    // handler numbered only the non-empty inputs.
    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
    srcs_mem.reserve(handler.GetNumInputs());
    int input_index = 0;
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
      ++input_index;
    }

    std::unordered_map<int, dnnl::memory> args;
    std::shared_ptr<dnnl::memory> dst_mem;

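    // The sum primitive consumes its inputs at consecutive execution-argument
    // slots starting at DNNL_ARG_MULTIPLE_SRC.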
    for (size_t i = 0; i < srcs_mem.size(); ++i) {
      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
    }

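    // In-place execution binds the first source memory as DST, so the
    // primitive writes the result directly into input 0's buffer.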
    if (in_place) {
      dst_mem = srcs_mem[0];
    } else {
      dst_mem = handler.AcquireDstMemory(output);
    }
    args.insert({DNNL_ARG_DST, *dst_mem});

    auto sum_p = handler.AcquireForwardPrimitive();

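    // Execute on the thread-local oneDNN stream and block until the
    // primitive has finished.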
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    sum_p->execute(astream, args);
    astream.wait();

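    // Record the destination memory descriptor on the output tensor so
    // downstream ops see the layout oneDNN actually produced.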
    output->set_mem_desc(dst_mem->get_desc());
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_KERNEL(
    sum,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
    paddle::operators::SumMKLDNNOpKernel<float>);