// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

template <typename T, typename Context>
void FcXPUKernel(const Context& ctx,
                 const DenseTensor& x,
                 const DenseTensor& w,
                 const DenseTensor& w_max,
                 const paddle::optional<DenseTensor>& bias,
                 int in_num_col_dims,
                 bool transpose_x,
                 float alpha,
                 float beta,
                 int act_type,
                 float act_alpha,
                 DenseTensor* out) {
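  // Flatten x into a 2-D matrix [m, k] at axis in_num_col_dims. w is stored
  // as [n, k], so it is fed to the GEMM transposed (w_trans = true,
  // ldw = k below).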
  auto in_mat_dims = flatten_to_2d(x.dims(), in_num_col_dims);
  int m = in_mat_dims[0];
  int k = in_mat_dims[1];
  int n = w.dims()[0];
  // bias is optional; XDNN takes a float bias pointer (or nullptr if absent).
  const float* bias_data =
      bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
  xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
  if (act_type == 5) {  // leaky_relu
    act.leaky_alpha = act_alpha;
  } else if (act_type == 15) {  // hard_sigmoid
    act.hard_sigmoid_slope = act_alpha;
  }
  ctx.template Alloc<T>(out);
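  // Fused GEMM + bias + activation. With the template arguments below this
  // computes, roughly, out = act(alpha * op(x) * op(w) + beta * out + bias),
  // dequantizing the int16 weights with the per-tensor scale read from w_max.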
  int r = xpu::fc_fusion<T, int16_t, T, int16_t>(  // TX, TW, TY, TGEMM
      ctx.x_context(),                             // ctx
      x.data<T>(),                                 // x
      w.data<int16_t>(),                           // w
      out->data<T>(),                              // y
      m,                                           // m
      n,                                           // n
      k,                                           // k
      transpose_x,                                 // x_trans
      true,                                        // w_trans
      nullptr,                                     // x_maxptr
      w_max.data<float>(),                         // w_maxptr
      nullptr,                                     // y_maxptr
      transpose_x ? m : k,                         // ldx
      k,                                           // ldw
      n,                                           // ldy
      alpha,                                       // alpha
      beta,                                        // beta
      bias_data,                                   // bias
      act);
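  // Surface a non-zero XDNN return code as a Paddle error tagged "fc_xpu".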
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
}

}  // namespace fusion
}  // namespace phi

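// Register the fused FC kernel for the XPU backend. Only float activations
// are registered here; the weights arrive pre-quantized to int16 with their
// scale in w_max.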
PD_REGISTER_KERNEL(fc_xpu, XPU, ALL_LAYOUT, phi::fusion::FcXPUKernel, float) {}