/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"

namespace paddle {
namespace operators {

using framework::ExecutionContext;
using phi::OneDNNContext;
using phi::funcs::OneDNNGetDataType;
using phi::funcs::to_void_cast;

// Cached oneDNN primitive and memory objects, reused across executions of the
// same FC configuration.
struct InnerProductCache {
  dnnl::inner_product_forward inner_product_p;
  dnnl::memory src_mem;
  dnnl::memory weights_mem;
  dnnl::memory bias_mem;
  dnnl::memory dst_mem;
  dnnl::memory src_scales_mem;
  dnnl::memory wei_scales_mem;
  dnnl::memory dst_scales_mem;
};

std::tuple<std::vector<float>,
           std::vector<float>,
           std::vector<float>,
           std::vector<float>>
GetDNNLScales(const ExecutionContext& ctx) {
  auto scale_in_data = ctx.Attr<float>("Scale_in");
  auto scale_out = ctx.Attr<float>("Scale_out");
  auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
  auto scale_in_eltwise_data = ctx.HasAttr("Scale_in_eltwise")
                                   ? ctx.Attr<float>("Scale_in_eltwise")
                                   : 1.0f;

  std::vector<float> dnnl_src_scales = {1.f / scale_in_data};
  size_t count = scale_weights_data.size();
  std::vector<float> dnnl_wei_scales(count);
#pragma omp parallel for if (count > 50)
  for (size_t i = 0; i < count; i++) {
    dnnl_wei_scales[i] = 1.f / scale_weights_data[i];
  }
  std::vector<float> dnnl_psum_scales = {1.f / scale_in_eltwise_data};
  std::vector<float> dnnl_dst_scales = {1.f / scale_out};

  return std::make_tuple(
      dnnl_src_scales, dnnl_wei_scales, dnnl_psum_scales, dnnl_dst_scales);
}
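
// GetDNNLScales() returns the reciprocals of the Scale_* attributes; those
// reciprocal values are what this kernel later hands to oneDNN (effectively,
// the attributes are quantization multipliers and oneDNN receives the
// corresponding dequantization factors). A purely illustrative example with
// hypothetical values:
//
//   Scale_in         = 127.0f          -> src scale  = 1.f / 127.0f
//   Scale_weights    = {63.5f, 127.f}  -> wei scales = {1.f / 63.5f, 1.f / 127.f}
//   Scale_in_eltwise = 63.5f           -> psum scale = 1.f / 63.5f
//   Scale_out        = 63.5f           -> dst scale  = 1.f / 63.5f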

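// FCMKLDNNHandler sets up a oneDNN inner_product primitive for the FC op. The
// input is flattened to a 2-D [MB, IC] matrix: the first `in_num_col_dims`
// dimensions of x form the batch size MB and the remaining ones form the input
// channels IC, while OC is taken from the second dimension of W. A purely
// illustrative example (hypothetical shapes):
//
//   x dims = [2, 3, 4], in_num_col_dims = 1, W dims = [12, 10]
//     -> MB = 2, IC = 3 * 4 = 12, OC = 10
//     -> src_md {2, 12}, weights_md {10, 12}, dst_md {2, 10}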
template <typename T_in, typename T_w, typename T_out>
class FCMKLDNNHandler
    : public phi::funcs::OneDNNHandlerNoCachingT<T_in,
                                                 dnnl::inner_product_forward> {
 public:
  FCMKLDNNHandler(const ExecutionContext& ctx,
                  const OneDNNContext& dev_ctx,
                  const phi::DenseTensor* x,
                  const phi::DenseTensor* weights,
                  const phi::DenseTensor* bias,
                  phi::DenseTensor* out UNUSED,
                  const int in_num_col_dims,
                  dnnl::engine onednn_engine,
                  platform::Place cpu_place)
      : phi::funcs::OneDNNHandlerNoCachingT<T_in, dnnl::inner_product_forward>(
            onednn_engine, cpu_place),
        dev_ctx_(dev_ctx) {
    this->memory_key_ = ctx.InputName("W");

    auto x_vec_dims = phi::vectorize(x->dims());
    auto weights_vec_dims = phi::vectorize(weights->dims());

    int MB = 1;
    for (int i = 0; i < in_num_col_dims; ++i) {
      MB *= x_vec_dims[i];
    }

    int IC = 1;
    for (size_t i = in_num_col_dims; i < x_vec_dims.size(); ++i) {
      IC *= x_vec_dims[i];
    }

    int OC = weights_vec_dims[1];

    dnnl::memory::desc bias_md;

    auto src_md = dnnl::memory::desc(
        {MB, IC}, OneDNNGetDataType<T_in>(), dnnl::memory::format_tag::any);
    auto weights_md = dnnl::memory::desc(
        {OC, IC}, OneDNNGetDataType<T_w>(), dnnl::memory::format_tag::any);
    auto dst_md = dnnl::memory::desc(
        {MB, OC}, OneDNNGetDataType<T_out>(), dnnl::memory::format_tag::any);
    if (bias) {
      bias_md = dnnl::memory::desc({bias->numel()},
                                   OneDNNGetDataType<float>(),
                                   dnnl::memory::format_tag::a);
    }

    const auto attrs = CreateFCAttrs(ctx);

    this->AcquireForwardPrimitiveDescriptor(attrs,
                                            dnnl::prop_kind::forward_inference,
                                            src_md,
                                            weights_md,
                                            bias_md,
                                            dst_md);
  }

 private:
  dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
    dnnl::primitive_attr attributes;
    dnnl::post_ops post_operations;

    float sum_scale = 1.0f;
    float activation_scale = 1.0f;
    if (phi::funcs::is_int8<T_w>()) {
      std::vector<float> src_scales, wei_scales, psum_scales, dst_scales;
      std::tie(src_scales, wei_scales, psum_scales, dst_scales) =
          GetDNNLScales(ctx);

      bool force_fp32_output = ctx.HasAttr("force_fp32_output") &&
                               ctx.Attr<bool>("force_fp32_output");

      attributes.set_scales_mask(DNNL_ARG_SRC, 0);

      dnnl::memory::desc src_scales_md(
          {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
      src_scales_mem_ = dnnl::memory(src_scales_md, this->engine_);
      memcpy(src_scales_mem_.get_data_handle(),
             src_scales.data(),
             src_scales.size() * sizeof(float));

      int mask = wei_scales.size() > 1 ? 1 : 0;
      attributes.set_scales_mask(DNNL_ARG_WEIGHTS, mask);

      dnnl::memory::desc wei_scales_md(
          {static_cast<int64_t>(wei_scales.size())},
          dnnl::memory::data_type::f32,
          dnnl::memory::format_tag::x);
      wei_scales_mem_ = dnnl::memory(wei_scales_md, this->engine_);
      memcpy(wei_scales_mem_.get_data_handle(),
             wei_scales.data(),
             wei_scales.size() * sizeof(float));

      if (!force_fp32_output) {
        attributes.set_scales_mask(DNNL_ARG_DST, 0);

        dnnl::memory::desc dst_scales_md(
            {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
        dst_scales_mem_ = dnnl::memory(dst_scales_md, this->engine_);
        memcpy(dst_scales_mem_.get_data_handle(),
               dst_scales.data(),
               dst_scales.size() * sizeof(float));
      }

      sum_scale = psum_scales[0];
    }

    if (ctx.HasAttr("fuse_residual_connection") &&
        ctx.Attr<bool>("fuse_residual_connection")) {
      post_operations.append_sum(sum_scale);
    }

    // ReLU from "fc_fuse_pass"
    if (ctx.Attr<std::string>("activation_type") == "relu") {
      post_operations.append_eltwise(dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
    }
    AppendActivation(ctx, post_operations, activation_scale);

    if (ctx.HasAttr("fused_output_scale")) {
      float scale_alpha = ctx.Attr<float>("fused_output_scale");
      post_operations.append_eltwise(
          dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
    }

    attributes.set_post_ops(post_operations);
    return attributes;
  }

  void AppendActivation(const ExecutionContext& ctx,
                        dnnl::post_ops& post_ops,  // NOLINT
                        float activation_scale = 1.0f) {
    const auto invalid_attribute =
        ctx.HasAttr("fuse_activation")
            ? ctx.Attr<std::string>("fuse_activation").empty()
            : true;
    if (invalid_attribute) return;

    const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
    const auto fuse_alpha =
        ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
    const auto fuse_beta =
        ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;

    const auto activation_map = phi::funcs::OneDNNActivationMap();
    const auto& activation_type = activation_map.find(fuse_activation);

    PADDLE_ENFORCE_NE(
        activation_type,
        activation_map.end(),
        phi::errors::InvalidArgument(
            "Activation '%s' not found in oneDNN algorithms mapper",
            fuse_activation));

    post_ops.append_eltwise(activation_type->second, fuse_alpha, fuse_beta);
    post_ops.append_eltwise(
        dnnl::algorithm::eltwise_linear, activation_scale, 0.0f);
  }

  // Compute oneDNN's scaling mask, which determines along which dimension
  // slice the scaling should be applied.
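  // For example (illustrative only): CreateMask(0, true) returns 1 << 0 == 1,
  // i.e. per-channel scales along dimension 0, while CreateMask(0, false)
  // returns 0, i.e. a single common scale.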
  int CreateMask(int slice_dimension, bool is_multi_channel_quantized) {
    return is_multi_channel_quantized ? 1 << slice_dimension : 0;
  }

  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorderAndAttrs(
      const dnnl::memory::desc& user_md,
      const dnnl::memory::desc& target_md,
      void* ptr,
      const dnnl::primitive_attr& attrs,
      const std::vector<float>& scale_data) {
    std::shared_ptr<dnnl::memory> target_memory_p;

    auto user_memory_p =
        std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
    target_memory_p = std::make_shared<dnnl::memory>(target_md, this->engine_);
    auto reorder_p = std::make_shared<dnnl::reorder>(
        *user_memory_p, *target_memory_p, attrs);

    auto scales_md =
        dnnl::memory::desc({static_cast<int64_t>(scale_data.size())},
                           dnnl::memory::data_type::f32,
                           dnnl::memory::format_tag::x);
    auto scale_mem =
        dnnl::memory(scales_md,
                     this->engine_,
                     phi::funcs::to_void_cast<float>(scale_data.data()));

    auto& astream = OneDNNContext::tls().get_stream();
    {
      reorder_p->execute(astream,
                         {{DNNL_ARG_FROM, *user_memory_p},
                          {DNNL_ARG_TO, *target_memory_p},
                          {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}});
      astream.wait();
    }

    return target_memory_p;
  }

  std::string memory_key_;
  const OneDNNContext& dev_ctx_;
  dnnl::memory src_scales_mem_;
  dnnl::memory wei_scales_mem_;
  dnnl::memory dst_scales_mem_;

 public:
  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
      const phi::DenseTensor* x) {
    const T_in* x_data = x->data<T_in>();

    auto user_md = x->mem_desc();
    if (x->dims().size() != 2) {
      // The reshape restrictions are always satisfied because, for 3-D or
      // 4-D input, a plain layout is enforced.
      user_md = user_md.reshape(this->fwd_pd_->src_desc().get_dims());
    }

    return this->AcquireMemoryWithReorder(
        user_md, this->fwd_pd_->src_desc(), to_void_cast<T_in>(x_data));
  }

  std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
      const ExecutionContext& ctx, const phi::DenseTensor* bias) {
    const float* bias_data = bias->data<float>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(),
                                            to_void_cast<float>(bias_data));
  }

  std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
      const phi::DenseTensor* weights, const std::vector<float>& scale_data) {
    const std::string weights_key = this->memory_key_ + "@weights";
    auto memory_p = std::static_pointer_cast<dnnl::memory>(
        this->dev_ctx_.GetBlob(weights_key));

    if (!memory_p) {
      const float* weights_data = weights->data<float>();
      auto weights_dims = this->fwd_pd_->weights_desc().get_dims();

      auto user_md = dnnl::memory::desc(weights_dims,
                                        OneDNNGetDataType<float>(),
                                        dnnl::memory::format_tag::io);

      if (phi::funcs::is_int8<T_w>()) {
        dnnl::primitive_attr attrs;
        int mask = CreateMask(0, scale_data.size() > 1);
        attrs.set_scales_mask(DNNL_ARG_SRC, mask);

        memory_p = this->AcquireMemoryWithReorderAndAttrs(
            user_md,
            this->fwd_pd_->weights_desc(),
            to_void_cast<float>(weights_data),
            attrs,
            scale_data);
      } else {
        memory_p =
            this->AcquireMemoryWithReorder(user_md,
                                           this->fwd_pd_->weights_desc(),
                                           to_void_cast<float>(weights_data));
      }

      this->dev_ctx_.SetBlob(weights_key, memory_p);
    }
    return memory_p;
  }

  std::shared_ptr<dnnl::memory> AcquireCustomDstMemory(
      const ExecutionContext& ctx, phi::DenseTensor* out) {
    if (ctx.HasAttr("fuse_residual_connection") &&
        ctx.Attr<bool>("fuse_residual_connection")) {
      auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");

      PADDLE_ENFORCE_EQ(
          out->dims(),
          residual_param->dims(),
          phi::errors::InvalidArgument(
              "Output and elementwise parameter need to have the "
              "same dimensions, but got output's rank = %d and "
              "residual param's rank = %d.",
              out->dims().size(),
              residual_param->dims().size()));

      out->ShareDataWith(*residual_param);
    }
    return this->template AcquireDstMemory<T_out>(out);
  }

  void SetScalesIfNeeded(std::unordered_map<int, dnnl::memory>* args) {
    if (!src_scales_mem_.get_desc().is_zero()) {
      args->insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_mem_});
      args->insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_mem_});
    }
    // dst scales may be empty when forcing fp32 output
    if (dst_scales_mem_.get(true)) {
      args->insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_mem_});
    }
  }
};  // class FCMKLDNNHandler

#define IF_CHANGE_FC_TW_TYPENAME(condition, ...) \
  if (condition) {                               \
    using T_w = int8_t;                          \
    __VA_ARGS__();                               \
  } else {                                       \
    using T_w = T_in;                            \
    __VA_ARGS__();                               \
  }
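
// Illustrative note: with T_in == uint8_t the condition used below is true and
// the lambda is instantiated with T_w = int8_t (the assumption being that
// quantized FC stores its weights as int8); for every other T_in the weights
// type simply follows the input type, T_w = T_in.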

template <typename T_in>
class FCMKLDNNKernel : public framework::OpKernel<T_in> {
 public:
  void Compute(const ExecutionContext& ctx) const override {
    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";

    IF_CHANGE_FC_TW_TYPENAME((std::is_same<T_in, uint8_t>::value), ([&] {
                               if (force_fp32_output) {
                                 this->RunKernel<float, T_w>(ctx);
                               } else if (phi::funcs::is_int8<T_in>()) {
                                 if (fuse_relu) {
                                   this->RunKernel<uint8_t, T_w>(ctx);
                                 } else {
                                   this->RunKernel<int8_t, T_w>(ctx);
                                 }
                               } else {
                                 this->RunKernel<T_in, T_w>(ctx);
                               }
                             }));
  }

  void PrepareSrcMem(const std::shared_ptr<dnnl::inner_product_forward>& fc_p
                         UNUSED,
                     const std::shared_ptr<dnnl::memory>& src_mem,
                     const phi::DenseTensor* x,
                     const dnnl::engine& engine) const {
    auto x_md = x->mem_desc().reshape(src_mem->get_desc().get_dims());
    if (x_md != src_mem->get_desc()) {
      dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
      auto reorder_p = dnnl::reorder(x_mem, *src_mem);

      auto& astream = OneDNNContext::tls().get_stream();
      reorder_p.execute(astream, x_mem, *src_mem);
      astream.wait();
    } else {
      src_mem->set_data_handle(to_void_cast<T_in>(x->data<T_in>()));
    }
  }

  template <typename T_out, typename T_w>
  void RunKernel(const ExecutionContext& ctx) const {
    const auto& dev_ctx = ctx.template device_context<OneDNNContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    const auto* x = ctx.Input<phi::DenseTensor>("Input");
    const auto* weights = ctx.Input<phi::DenseTensor>("W");
    const auto* bias = ctx.Input<phi::DenseTensor>("Bias");
    auto out = ctx.Output<phi::DenseTensor>("Out");

    const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");

    std::shared_ptr<dnnl::inner_product_forward> fc_p;
    std::shared_ptr<dnnl::memory> src_memory_p;
    std::shared_ptr<dnnl::memory> weights_memory_p;
    std::shared_ptr<dnnl::memory> bias_memory_p;
    std::shared_ptr<dnnl::memory> dst_memory_p;

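    // The primitive and its memory objects are cached in the device context.
    // The key below combines the Input/W variable names with the input shape
    // and, if needed, thread information appended by
    // ExtendKeyWithThreadInfoIfNeeded; the exact key layout is an
    // implementation detail of phi::funcs.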
    std::string cache_key;
    cache_key.reserve(64);
    cache_key = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
        dev_ctx,
        phi::funcs::CreateKey(dev_ctx,
                              ctx.InputName("Input"),
                              ctx.InputName("W"),
                              phi::vectorize(x->dims())));

    auto inner_product_cache =
        std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));

    RecomputeOutputDims(ctx, x, weights, out);

    std::unordered_map<int, dnnl::memory> fc_args;

    if (inner_product_cache) {
      fc_p = std::make_shared<dnnl::inner_product_forward>(
          inner_product_cache->inner_product_p);
      src_memory_p =
          std::make_shared<dnnl::memory>(inner_product_cache->src_mem);
      PrepareSrcMem(fc_p, src_memory_p, x, onednn_engine);

      weights_memory_p =
          std::make_shared<dnnl::memory>(inner_product_cache->weights_mem);

      dst_memory_p =
          std::make_shared<dnnl::memory>(inner_product_cache->dst_mem);
      if (ctx.HasAttr("fuse_residual_connection") &&
          ctx.Attr<bool>("fuse_residual_connection")) {
        auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
        out->ShareDataWith(*residual_param);
      }
      auto out_ptr = out->mutable_data<T_out>(
          ctx.GetPlace(), dst_memory_p->get_desc().get_size());
      dst_memory_p->set_data_handle(out_ptr);

      fc_args.insert({DNNL_ARG_SRC, *src_memory_p});
      fc_args.insert({DNNL_ARG_WEIGHTS, *weights_memory_p});
      fc_args.insert({DNNL_ARG_DST, *dst_memory_p});

      if (bias) {
        bias_memory_p =
            std::make_shared<dnnl::memory>(inner_product_cache->bias_mem);
        fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
      }

      if (inner_product_cache->src_scales_mem.get(true)) {
        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC,
                        inner_product_cache->src_scales_mem});
        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
                        inner_product_cache->wei_scales_mem});
      }
      if (inner_product_cache->dst_scales_mem.get(true)) {
        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
                        inner_product_cache->dst_scales_mem});
      }
    } else {
      auto in_col_dims = ctx.Attr<int>("in_num_col_dims");

      FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
                                                dev_ctx,
                                                x,
                                                weights,
                                                bias,
                                                out,
                                                in_col_dims,
                                                onednn_engine,
                                                ctx.GetPlace());

      src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
      weights_memory_p =
          handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
      dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
      fc_args.insert({DNNL_ARG_SRC, *src_memory_p});
      fc_args.insert({DNNL_ARG_WEIGHTS, *weights_memory_p});
      fc_args.insert({DNNL_ARG_DST, *dst_memory_p});

      if (bias) {
        bias_memory_p = handler.AcquireBiasMemoryWithReorder(ctx, bias);
        fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
      }

      if (phi::funcs::is_int8<T_in>()) {
        handler.SetScalesIfNeeded(&fc_args);
      }

      fc_p = handler.AcquireForwardPrimitive();
    }

    auto& astream = OneDNNContext::tls().get_stream();
    fc_p->execute(astream, fc_args);
    astream.wait();

    if (!inner_product_cache) {
      auto ip_cache = std::make_shared<InnerProductCache>();
      ip_cache->inner_product_p = *fc_p;
      ip_cache->src_mem = *src_memory_p;
      ip_cache->weights_mem = *weights_memory_p;
      ip_cache->dst_mem = *dst_memory_p;
      if (bias) {
        ip_cache->bias_mem = *bias_memory_p;
      }
      if (fc_args.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)) {
        ip_cache->src_scales_mem =
            fc_args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
        ip_cache->wei_scales_mem =
            fc_args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
      }

      if (fc_args.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)) {
        ip_cache->dst_scales_mem =
            fc_args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
      }

      dev_ctx.SetBlob(cache_key, ip_cache);
    }

551 552 553 554 555 556 557 558 559
    const auto out_md =
        dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()));

    if (ctx.HasAttr("fused_reshape2_shape")) {
      phi::funcs::SetOutMemDescWithReshape2FuseSupport(
          ctx.Attr<std::vector<int>>("fused_reshape2_shape"), out, out_md);
    } else {
      out->set_mem_desc(out_md);
    }
  }

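  // Recomputes the FC output shape from the input and weight shapes. A purely
  // illustrative example (hypothetical shapes): for x dims [2, 3, 4],
  // W dims [12, 10] and in_num_col_dims = 1, FCOutputSize keeps the first
  // in_num_col_dims input dimensions and appends the output channel count,
  // so out is resized to [2, 10].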
  void RecomputeOutputDims(const ExecutionContext& ctx,
                           const phi::DenseTensor* x,
                           const phi::DenseTensor* weights,
                           phi::DenseTensor* out) const {
    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
    bool padding_weights = ctx.Attr<bool>("padding_weights");
    PADDLE_ENFORCE_EQ(padding_weights,
                      false,
                      phi::errors::PermissionDenied(
                          "Weight padding in fc cannot be used in oneDNN."));
    std::vector<int64_t> output_dims;
    FCOutputSize(x->dims(),
                 weights->dims(),
                 output_dims,
                 in_num_col_dims,
                 padding_weights);
    out->Resize(phi::make_ddim(output_dims));
    out->set_lod(x->lod());
  }
};

}  // namespace operators
}  // namespace paddle

// FC weights are stored as fp32 by default; the weight-type template argument
// denotes their destination data type, i.e. the type that is eventually used
// during the kernel's computation.
namespace ops = paddle::operators;

REGISTER_OP_KERNEL(fc,
                   MKLDNN,
                   ::phi::CPUPlace,
                   ops::FCMKLDNNKernel<float>,
                   ops::FCMKLDNNKernel<paddle::platform::bfloat16>,
                   ops::FCMKLDNNKernel<uint8_t>,
                   ops::FCMKLDNNKernel<int8_t>);