diff --git a/src/operators/kernel/fpga/softmax_kernel.cpp b/src/operators/kernel/fpga/softmax_kernel.cpp
index 7c784ce474bbb2588dcf78ecded740777445fc80..f218edbec677845c56fffc63166cc3b5665c6fa4 100644
--- a/src/operators/kernel/fpga/softmax_kernel.cpp
+++ b/src/operators/kernel/fpga/softmax_kernel.cpp
@@ -25,21 +25,23 @@ namespace operators {
 template <>
 bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam *param) {
   const Tensor *input = param->InputX();
-  if (input->type() == typeid(half)) {
-    auto input_ptr = input->data<half>();
-    auto output_ptr = param->Out();
-    fpga::BypassArgs args;
-    args.input_layout_type = fpga::LAYOUT_HWC;
-    args.output_layout_type = fpga::LAYOUT_CHW;
-    args.input_data_type = fpga::DATA_TYPE_FP16;
-    args.output_data_type = fpga::DATA_TYPE_FP32;
-    args.image.address = (void *)(input_ptr);
-    args.image.height = (uint32_t)input->dims()[0];
-    args.image.width = (uint32_t)input->dims()[1];
-    args.image.channels = 1;
-    args.output.address = output_ptr;
-    param->SetFpgaArgs(args);
-  }
+  auto input_ptr = input->data<half>();
+  auto output_ptr = param->Out();
+  Tensor *floatInput = new Tensor(*input);
+  fpga::BypassArgs args;
+  args.input_layout_type = fpga::LAYOUT_HWC;
+  args.output_layout_type = fpga::LAYOUT_CHW;
+  args.input_data_type = fpga::DATA_TYPE_FP16;
+  args.output_data_type = fpga::DATA_TYPE_FP32;
+  args.image.address = reinterpret_cast<void *>(input_ptr);
+  args.image.height = (uint32_t)input->dims()[0];
+  args.image.width = (uint32_t)input->dims()[1];
+  args.image.channels = 1;
+  args.output.address =
+      reinterpret_cast<void *>(floatInput->mutable_data<float>());
+
+  param->SetFloatInput(floatInput);
+  param->SetFpgaArgs(args);
   return true;
 }
 
@@ -48,8 +50,13 @@ void SoftmaxKernel<FPGA, float>::Compute(
     const SoftmaxParam &param) const {
   DLOG << "======================================= FPGA SoftMAX "
           "===============================================";
-  const Tensor *in_x = param.InputX();
+  const Tensor *in_x = param.FloatInput();
   Tensor *out = param.Out();
+  fpga::fpga_flush(reinterpret_cast<void *>(in_x->data<float>()),
+                   in_x->memory_size());
+  fpga::PerformBypass(param.FpgaArgs());
+  fpga::fpga_invalidate(out->data<float>(), out->memory_size());
+
   auto x_dims = in_x->dims();
   out->Resize(x_dims);
   math::SoftmaxFuntor<CPU, float>()(in_x, out);
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 1728e6a6cc778ec223c3f14c971404ba3a5cc0f7..1c5815c64236f1b67fb6ab7752d0c4caef7c2646 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -785,7 +785,7 @@ class SoftmaxParam : public OpParam {
   fpga::BypassArgs fpga_bypass_args;
 
  public:
-  RType *FloatInput() {
+  RType *FloatInput() const {
     return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
   }
   void SetFloatInput(Tensor *input) { float_input_x_.reset(input); }
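
Note on the ordering added to Compute(): the bypass engine accesses DRAM directly, so the CPU cache must be flushed for the engine's input before fpga::PerformBypass runs, and invalidated for its output before the CPU reads the result; otherwise either side can observe stale data. Below is a self-contained sketch of that flush -> bypass -> invalidate pattern, with hypothetical no-op stand-ins for the fpga:: calls (only the ordering is the point here, not the real signatures, which live in paddle-mobile's FPGA API):

    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-ins for fpga::fpga_flush, fpga::fpga_invalidate and
    // fpga::PerformBypass; the bodies are no-ops so the sketch runs anywhere.
    static void flush(void *addr, size_t bytes) {}       // write back CPU cache lines
    static void invalidate(void *addr, size_t bytes) {}  // drop stale CPU cache lines
    static void perform_bypass(const float *src, float *dst, size_t n) {
      // The hardware would DMA-convert FP16 -> FP32; a plain copy stands in.
      for (size_t i = 0; i < n; ++i) dst[i] = src[i];
    }

    int main() {
      float in[4] = {1, 2, 3, 4}, out[4] = {0};
      flush(in, sizeof(in));         // 1) make CPU writes visible to the device
      perform_bypass(in, out, 4);    // 2) device reads DRAM, writes the FP32 copy
      invalidate(out, sizeof(out));  // 3) discard stale lines before the CPU reads
      std::printf("%g\n", out[3]);   // prints 4
      return 0;
    }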
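
After the bypass has produced the FP32 copy, the softmax itself still runs on the host via math::SoftmaxFuntor<CPU, float>. For reference, a minimal sketch of the numerically stable softmax such a functor conventionally computes (illustrative only, not paddle-mobile's implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Numerically stable softmax over one row: subtracting the row maximum
    // before exponentiating keeps std::exp from overflowing.
    static std::vector<float> Softmax(const std::vector<float> &x) {
      const float max_v = *std::max_element(x.begin(), x.end());
      std::vector<float> y(x.size());
      float sum = 0.f;
      for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(x[i] - max_v);
        sum += y[i];
      }
      for (float &v : y) v /= sum;  // normalize so the row sums to 1
      return y;
    }

    int main() {
      for (float v : Softmax({1.f, 2.f, 3.f})) std::printf("%.4f ", v);
      // prints roughly: 0.0900 0.2447 0.6652
      return 0;
    }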