//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {}  // namespace framework
namespace platform {
class CPUDeviceContext;
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using paddle::platform::CPUDeviceContext;
using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;

template <typename T>
class SumMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
 public:
  SumMKLDNNHandler(dnnl::engine engine, platform::Place cpu_place,
                   const std::vector<framework::Variable*>& in_vars,
                   framework::LoDTensor* z)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
        num_inputs_(0) {
    auto dst_tz = phi::vectorize<int64_t>(z->dims());
    auto src_tz = dst_tz;

    std::vector<dnnl::memory::desc> srcs_md;
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      MKLDNNMemoryFormat input_format = input_it.format();
      srcs_md.push_back(dnnl::memory::desc(
          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
      ++num_inputs_;
    }
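    // oneDNN sum computes dst = sum_i(scales[i] * src[i]); unit scales
    // give the plain elementwise sum this operator implements.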
    std::vector<float> scales(num_inputs_, 1.0);

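    // MKLDNNMemoryFormat::any lets oneDNN pick the most efficient
    // destination layout for the given sources.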
    auto dst_md = dnnl::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
                                     MKLDNNMemoryFormat::any);

    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
  }

  // (jczaja) The oneDNN sum primitive has no .desc attribute, so the
  // base-class AcquireForwardPrimitiveDescriptor cannot be used here.
  void AcquireForwardPrimitiveDescriptor(
      const dnnl::memory::desc& dst_md, const std::vector<float>& scales,
      const std::vector<dnnl::memory::desc>& srcs_md) {
    this->fwd_pd_.reset(
        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
  }

  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
                                                 int i) {
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                            to_void_cast<T>(input_data));
  }

  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;

  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
  }

  inline int GetNumInputs(void) { return num_inputs_; }

 private:
  int num_inputs_;
};

template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
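    // oneDNN kernels are CPU-only, so reject any other placement up front.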
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Sum must use CPUPlace"));
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    auto in_vars = ctx.MultiInputVar("X");

    PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(
                                                 "Input variable is empty."));
    auto& input0 = in_vars[0]->Get<LoDTensor>();
    LoDTensor* output = ctx.Output<LoDTensor>("Out");

    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);

    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);

    // Create the list of source (SRC) memory objects.
    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
    srcs_mem.reserve(handler.GetNumInputs());
    int input_index = 0;
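    // input_index must follow the same ordering as the source descriptors
    // registered in the handler, where empty inputs were also skipped.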
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
      ++input_index;
    }

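    // In in-place mode the primitive writes into a freshly allocated oneDNN
    // buffer (reordered back into the output below); otherwise the dst
    // memory wraps the output tensor's own allocation.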
    std::shared_ptr<dnnl::memory> dst_mem = nullptr;
    if (in_place) {
      dst_mem = handler.AcquireDstMemory();
      output->mutable_data<T>(ctx.GetPlace());
    } else {
      dst_mem = handler.AcquireDstMemory(output);
    }

    auto sum_p = handler.AcquireForwardPrimitive();

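    // The sum primitive takes a variable number of sources, addressed as
    // DNNL_ARG_MULTIPLE_SRC + i.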
    std::unordered_map<int, dnnl::memory> args;
    for (size_t i = 0; i < srcs_mem.size(); ++i) {
      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
    }
    args.insert({DNNL_ARG_DST, *dst_mem});

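    // Execute the primitive on the thread-local oneDNN stream.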
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    sum_p->execute(astream, args);
    astream.wait();

    // The oneDNN sum primitive has no true in-place execution, so we emulate
    // it: reorder the result from the oneDNN dst memory back into the shared
    // input/output buffer.
    if (in_place) {
      auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
      auto output_tz = phi::vectorize<int64_t>(output->dims());
      platform::ReorderMKLDNNHandler reorder_handler(
          output_tz, framework::TransToProtoVarType(output->dtype()),
          framework::ToMKLDNNDataType(
              framework::TransToProtoVarType(in_out.dtype())),
          dev_ctx.GetEngine());

      auto target_mem = reorder_handler.AcquireDstMemory(
          output, in_out.format(), ctx.GetPlace());

      auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);

      reorder_p->execute(astream, *dst_mem, *target_mem);
      astream.wait();
    }
    output->set_layout(framework::DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_mem));
  }
};

}  // namespace operators
}  // namespace paddle

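// The oneDNN sum kernel is registered for float and bfloat16 only.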
REGISTER_OP_KERNEL(
    sum, MKLDNN, ::paddle::platform::CPUPlace,
    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
    paddle::operators::SumMKLDNNOpKernel<float>);