提交 c61af417 编写于 作者: Z zhangyang0701 提交者: GitHub

Merge pull request #962 from chonwhite/develop

fix:#961
...@@ -25,21 +25,22 @@ namespace operators { ...@@ -25,21 +25,22 @@ namespace operators {
template <> template <>
bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) { bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
const Tensor *input = param->InputX(); const Tensor *input = param->InputX();
if (input->type() == typeid(half)) { auto input_ptr = input->data<float>();
auto input_ptr = input->data<half>(); auto output_ptr = param->Out();
auto output_ptr = param->Out(); Tensor *floatInput = new Tensor(*input);
fpga::BypassArgs args; fpga::BypassArgs args;
args.input_layout_type = fpga::LAYOUT_HWC; args.input_layout_type = fpga::LAYOUT_HWC;
args.output_layout_type = fpga::LAYOUT_CHW; args.output_layout_type = fpga::LAYOUT_CHW;
args.input_data_type = fpga::DATA_TYPE_FP16; args.input_data_type = fpga::DATA_TYPE_FP16;
args.output_data_type = fpga::DATA_TYPE_FP32; args.output_data_type = fpga::DATA_TYPE_FP32;
args.image.address = (void *)(input_ptr); args.image.address = (void *)(input_ptr);
args.image.height = (uint32_t)input->dims()[0]; args.image.height = (uint32_t)input->dims()[0];
args.image.width = (uint32_t)input->dims()[1]; args.image.width = (uint32_t)input->dims()[1];
args.image.channels = 1; args.image.channels = 1;
args.output.address = output_ptr; args.output.address = (void *)floatInput->mutable_data<float>();
param->SetFpgaArgs(args);
} param->SetFloatInput(floatInput);
param->SetFpgaArgs(args);
return true; return true;
} }
...@@ -48,8 +49,12 @@ void SoftmaxKernel<FPGA, float>::Compute( ...@@ -48,8 +49,12 @@ void SoftmaxKernel<FPGA, float>::Compute(
const SoftmaxParam<FPGA> &param) const { const SoftmaxParam<FPGA> &param) const {
DLOG << "======================================= FPGA SoftMAX " DLOG << "======================================= FPGA SoftMAX "
"==============================================="; "===============================================";
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.FloatInput();
Tensor *out = param.Out(); Tensor *out = param.Out();
fpga::fpga_flush((void *)in_x->data<float>(), in_x->memory_size());
fpga::PerformBypass(param.FpgaArgs());
fpga::fpga_invalidate(out->data<float>(), out->memory_size());
auto x_dims = in_x->dims(); auto x_dims = in_x->dims();
out->Resize(x_dims); out->Resize(x_dims);
math::SoftmaxFuntor<CPU, float>()(in_x, out); math::SoftmaxFuntor<CPU, float>()(in_x, out);
......
...@@ -785,7 +785,7 @@ class SoftmaxParam : public OpParam { ...@@ -785,7 +785,7 @@ class SoftmaxParam : public OpParam {
fpga::BypassArgs fpga_bypass_args; fpga::BypassArgs fpga_bypass_args;
public: public:
RType *FloatInput() { RType *FloatInput() const {
return float_input_x_ == nullptr ? input_x_ : float_input_x_.get(); return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
} }
void SetFloatInput(Tensor *input) { float_input_x_.reset(input); } void SetFloatInput(Tensor *input) { float_input_x_.reset(input); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册