// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"

namespace paddle {
namespace framework {
class ExecutionContext;
}  // namespace framework
namespace platform {
class CPUDeviceContext;
struct CPUPlace;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {
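// oneDNN gradient kernel for elementwise_add. Since out = x + y, the
// gradient w.r.t. each input equals dOut, so Compute() only needs to copy
// (reorder) dOut into dX and dY when they are requested.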
template <typename T>
class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);
    using Tensor = framework::Tensor;

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

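    // Describe dOut once and build a cached reorder handler keyed on its
    // shape, memory format and data type; the same source memory is reused
    // for both dX and dY below.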
    auto tz = paddle::framework::vectorize<int64_t>(dout->dims());
    memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type());
    std::string key = platform::CreateKey(dev_ctx, tz, dout->format(),
                                          dout->format(), dout_type);
    platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx,
                                           onednn_engine, key);

    mkldnn::stream astream(onednn_engine);
    auto reorder_src_memory_p = handler.AcquireSrcMemory(
        dout->format(), platform::to_void_cast(dout->data<T>()));

    if (dx) {
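      // dX = dOut: reorder dOut straight into dX, keeping dOut's format.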
      auto reorder_dst_memory_p =
          handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
      auto reorder_p =
          handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
      platform::RecordEvent record_reorder("int_reorder",
                                           platform::EventRole::kUniqueOp);
      reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
      astream.wait();
    }

    if (dy) {
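      // dY = dOut: same reorder, this time writing into dY.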
      auto reorder_dst_memory_p =
          handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
      auto reorder_p =
          handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
      platform::RecordEvent record_reorder("int_reorder",
                                           platform::EventRole::kUniqueOp);
      reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
      astream.wait();
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

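// Forward elementwise_add uses the generic oneDNN binary kernel
// (binary_add) for float, bfloat16, int8_t and uint8_t; the gradient
// kernel above is registered for float only.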
REGISTER_OP_KERNEL(
    elementwise_add, MKLDNN, ::paddle::platform::CPUPlace,
    ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_add>,
    ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
                             dnnl::algorithm::binary_add>,
    ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_add>,
    ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>)

REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::EltwiseAddMKLDNNGradKernel<float>)