// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "lite/kernels/xpu/search_fc_compute.h" #include #include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { namespace xpu { void SearchFcCompute::PrepareForRun() { maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(float)); } void SearchFcCompute::Run() { auto& param = this->template Param(); auto& ctx = this->ctx_->template As(); auto* bottom = param.X; auto* w = param.W; auto* b = param.b; auto* top = param.Out; float w_max = param.__xpu__w_max; int out_size = param.out_size; bool fuse_relu = param.fuse_relu; bool float_to_fix = param.__xpu__float_to_fix; CHECK(float_to_fix) << "W should be fixed point"; int batch = bottom->dims()[0]; int _out = w->dims()[0]; int _in = w->dims()[1]; xdnn::Activation_t act = xdnn::Activation_t::LINEAR; if (fuse_relu) { act = xdnn::Activation_t::RELU; } std::vector top_dims{bottom->dims()[0], out_size}; top->Resize(top_dims); const auto* bottom_data = bottom->data(); const auto* weights = w->data(); const auto* bias_data = b->data(); auto* top_data = top->mutable_data(TARGET(kXPU)); float* maxs_xpu = reinterpret_cast(maxs_xpu_guard_->addr_); float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f}; xpu_memcpy(maxs_xpu, &maxs_cpu[0], 8 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE); int r = xdnn::findmax( ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu); CHECK_EQ(r, 0); r = xdnn::gemm_int16_maxptr( ctx.GetRawContext(), /* ctx */ false, true, /*trans_a, trans_b*/ batch, _out, _in, /*m, n, k*/ 1.0f, bottom_data, _in, /*alpha, data_a, lda*/ weights, _in, 0.0f, /*data_b, ldb, beta*/ top_data, _out, bias_data, /* data_c, ldc, bias*/ act, maxs_xpu, maxs_xpu + 4, nullptr /*act, max_a, max_b, max_c*/); CHECK_EQ(r, 0); } } // namespace xpu } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL(search_fc, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::SearchFcCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("b", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) .Finalize();