sum_mkldnn_op.cc 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*Licensed under the Apache License, Version 2.0(the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"
J
Jacek Czaja 已提交
28
#include "paddle/fluid/platform/mkldnn_reuse.h"
29

W
wanghuancoder 已提交
30 31 32 33 34 35 36 37 38 39
// Forward declarations of framework/platform types referenced in this file,
// so their full definitions need not be included here.
namespace paddle {
namespace framework {
class Tensor;
}  // namespace framework
namespace platform {
class CPUDeviceContext;
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

40 41 42
namespace paddle {
namespace operators {

T
tangwei12 已提交
43 44
using paddle::platform::CPUDeviceContext;
using paddle::platform::MKLDNNDeviceContext;
45 46
using platform::to_void_cast;

J
Jacek Czaja 已提交
47
template <typename T>
48 49
class SumMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
J
Jacek Czaja 已提交
50
 public:
51
  SumMKLDNNHandler(dnnl::engine engine, platform::Place cpu_place,
J
Jacek Czaja 已提交
52
                   const std::vector<framework::Variable*>& in_vars,
53
                   framework::LoDTensor* z)
J
Jacek Czaja 已提交
54

55
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
J
Jacek Czaja 已提交
56
        num_inputs_(0) {
57 58
    auto dst_tz = framework::vectorize<int64_t>(z->dims());
    auto src_tz = dst_tz;
J
Jacek Czaja 已提交
59

60
    std::vector<dnnl::memory::desc> srcs_md;
61 62 63 64
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
J
Jacek Czaja 已提交
65
      }
66
      MKLDNNMemoryFormat input_format = input_it.format();
67
      srcs_md.push_back(dnnl::memory::desc(
68 69 70 71
          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
      ++num_inputs_;
    }
    std::vector<float> scales(num_inputs_, 1.0);
J
Jacek Czaja 已提交
72

73 74
    auto dst_md = dnnl::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
                                     MKLDNNMemoryFormat::any);
J
Jacek Czaja 已提交
75

76
    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
J
Jacek Czaja 已提交
77 78 79 80 81
  }

  // (jczaja) sum oneDNN prim is not having .desc attribute so
  // we cannot use base AcquireForwardPrimitiveDescriptor
  void AcquireForwardPrimitiveDescriptor(
82 83
      const dnnl::memory::desc& dst_md, const std::vector<float>& scales,
      const std::vector<dnnl::memory::desc>& srcs_md) {
84 85
    this->fwd_pd_.reset(
        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
J
Jacek Czaja 已提交
86 87
  }

88 89
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
                                                 int i) {
J
Jacek Czaja 已提交
90 91
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
92
                                            to_void_cast<T>(input_data));
J
Jacek Czaja 已提交
93 94
  }

95
  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
J
Jacek Czaja 已提交
96

97
  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
98
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
J
Jacek Czaja 已提交
99 100 101 102 103 104 105 106
  }

  inline int GetNumInputs(void) { return num_inputs_; }

 private:
  int num_inputs_;
};

107 108 109 110
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  // Computes Out = X[0] + X[1] + ... with a single oneDNN sum primitive.
  // Inputs with zero elements are skipped.
  //
  // When Out shares its buffer with a non-empty X[0] ("in-place"), the sum is
  // first written to a separate destination memory. It is then reordered into
  // Out using X[0]'s memory format, because the sum primitive is not executed
  // in place here. Out's format metadata is set to the layout of the data it
  // actually holds afterwards.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Sum must use CPUPlace"));
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    auto in_vars = ctx.MultiInputVar("X");

    PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(
                                                 "Input variable is empty."));
    auto& input0 = in_vars[0]->Get<LoDTensor>();
    LoDTensor* output = ctx.Output<LoDTensor>("Out");

    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);

    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);

    // Create the list of SRC MEMs. Empty inputs are skipped exactly as in the
    // handler's constructor, so `input_index` matches the index of the
    // primitive's source descriptor.
    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
    srcs_mem.reserve(handler.GetNumInputs());
    int input_index = 0;
    for (size_t i = 0; i < in_vars.size(); i++) {
      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
      if (input_it.numel() == 0) {
        continue;
      }
      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
      ++input_index;
    }

    // In-place: compute into a separate destination memory, which is
    // reordered into `output` below. Otherwise: write straight into `output`.
    std::shared_ptr<dnnl::memory> dst_mem = nullptr;
    if (in_place) {
      dst_mem = handler.AcquireDstMemory();
      output->mutable_data<T>(ctx.GetPlace());
    } else {
      dst_mem = handler.AcquireDstMemory(output);
    }

    auto sum_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> args;
    for (size_t i = 0; i < srcs_mem.size(); ++i) {
      args.insert(
          {DNNL_ARG_MULTIPLE_SRC + static_cast<int>(i), *(srcs_mem[i])});
    }
    args.insert({DNNL_ARG_DST, *dst_mem});

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    sum_p->execute(astream, args);
    astream.wait();

    // Memory format of the data that `output` finally holds. By default this
    // is the layout oneDNN chose for the sum destination (requested as `any`).
    MKLDNNMemoryFormat out_format = platform::GetMKLDNNFormat(*dst_mem);

    // For in-place execution, which sum does not provide, we need to fake it:
    // reorder the data from the oneDNN dst memory into the input buffer.
    if (in_place) {
      auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
      // The reorder writes `output` in X[0]'s format. That can differ from the
      // sum destination format, so record it and keep the tensor's format
      // metadata consistent with its contents.
      out_format = in_out.format();
      auto output_tz = framework::vectorize<int64_t>(output->dims());
      platform::ReorderMKLDNNHandler reorder_handler(
          output_tz, output->type(), framework::ToMKLDNNDataType(in_out.type()),
          dev_ctx.GetEngine());

      auto target_mem = reorder_handler.AcquireDstMemory(output, out_format,
                                                         ctx.GetPlace());

      auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);
      {
        platform::RecordEvent record_reorder("int_reorder",
                                             platform::EventRole::kUniqueOp);
        reorder_p->execute(astream, *dst_mem, *target_mem);
        astream.wait();
      }
    }
    output->set_layout(framework::DataLayout::kMKLDNN);
    output->set_format(out_format);
  }
};

}  // namespace operators
}  // namespace paddle

J
Jacek Czaja 已提交
188 189 190 191
// Registers SumMKLDNNOpKernel as the MKLDNN kernel of the `sum` operator on
// CPUPlace, for bfloat16 and float element types.
REGISTER_OP_KERNEL(
    sum, MKLDNN, ::paddle::platform::CPUPlace,
    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
    paddle::operators::SumMKLDNNOpKernel<float>);