Commit 84a38b19 authored by hanbuhe

softmax: added CPU cache flush and invalidate

Parent 2af471b7
...
@@ -25,21 +25,23 @@ namespace operators {
 template <>
 bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
   const Tensor *input = param->InputX();
-  if (input->type() == typeid(half)) {
-    auto input_ptr = input->data<half>();
-    auto output_ptr = param->Out();
-    fpga::BypassArgs args;
-    args.input_layout_type = fpga::LAYOUT_HWC;
-    args.output_layout_type = fpga::LAYOUT_CHW;
-    args.input_data_type = fpga::DATA_TYPE_FP16;
-    args.output_data_type = fpga::DATA_TYPE_FP32;
-    args.image.address = (void *)(input_ptr);
-    args.image.height = (uint32_t)input->dims()[0];
-    args.image.width = (uint32_t)input->dims()[1];
-    args.image.channels = 1;
-    args.output.address = output_ptr;
-    param->SetFpgaArgs(args);
-  }
+  auto input_ptr = input->data<float>();
+  auto output_ptr = param->Out();
+  Tensor *floatInput = new Tensor(*input);
+  fpga::BypassArgs args;
+  args.input_layout_type = fpga::LAYOUT_HWC;
+  args.output_layout_type = fpga::LAYOUT_CHW;
+  args.input_data_type = fpga::DATA_TYPE_FP16;
+  args.output_data_type = fpga::DATA_TYPE_FP32;
+  args.image.address = reinterpret_cast<void *>(input_ptr);
+  args.image.height = (uint32_t)input->dims()[0];
+  args.image.width = (uint32_t)input->dims()[1];
+  args.image.channels = 1;
+  args.output.address =
+      reinterpret_cast<void *>(floatInput->mutable_data<float>());
+  param->SetFloatInput(floatInput);
+  param->SetFpgaArgs(args);
   return true;
 }
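For context, the rewritten Init() unconditionally treats the input as the FPGA's FP16 HWC image and sets up a bypass that converts it into an FP32 CHW buffer backed by a newly allocated float tensor, which the CPU-side softmax reads later. Below is a minimal, self-contained sketch of that configuration pattern under stated assumptions: the Layout/DataType enums, the BypassArgs-like struct, and MakeFp16ToFp32Bypass are simplified stand-ins invented for this example, since the real fpga::BypassArgs definition is not part of this diff.

#include <cstdint>
#include <vector>

// Simplified stand-ins for fpga::BypassArgs and the layout/data-type enums;
// the real definitions are not shown in this diff.
enum class Layout { HWC, CHW };
enum class DataType { FP16, FP32 };

struct BypassArgs {
  Layout input_layout_type;
  Layout output_layout_type;
  DataType input_data_type;
  DataType output_data_type;
  struct { void *address; uint32_t height, width, channels; } image;
  struct { void *address; } output;
};

// Hypothetical helper mirroring the Init() change: describe a bypass that
// converts an FP16 HWC image into an FP32 CHW buffer readable by the CPU.
BypassArgs MakeFp16ToFp32Bypass(void *fp16_in, float *fp32_out,
                                uint32_t height, uint32_t width) {
  BypassArgs args{};
  args.input_layout_type = Layout::HWC;
  args.output_layout_type = Layout::CHW;
  args.input_data_type = DataType::FP16;
  args.output_data_type = DataType::FP32;
  args.image.address = fp16_in;
  args.image.height = height;
  args.image.width = width;
  args.image.channels = 1;  // softmax input is treated as a 2-D matrix
  args.output.address = fp32_out;
  return args;
}

int main() {
  std::vector<uint16_t> fp16_buf(2 * 3);  // stands in for the FP16 input
  std::vector<float> fp32_buf(2 * 3);     // destination float tensor data
  BypassArgs args = MakeFp16ToFp32Bypass(fp16_buf.data(), fp32_buf.data(), 2, 3);
  (void)args;  // in the real kernel this would be stored via SetFpgaArgs()
  return 0;
}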
...
@@ -48,8 +50,13 @@ void SoftmaxKernel<FPGA, float>::Compute(
     const SoftmaxParam<FPGA> &param) const {
   DLOG << "======================================= FPGA SoftMAX "
           "===============================================";
-  const Tensor *in_x = param.InputX();
+  const Tensor *in_x = param.FloatInput();
   Tensor *out = param.Out();
+  fpga::fpga_flush(reinterpret_cast<void *>(in_x->data<float>()),
+                   in_x->memory_size());
+  fpga::PerformBypass(param.FpgaArgs());
+  fpga::fpga_invalidate(out->data<float>(), out->memory_size());
   auto x_dims = in_x->dims();
   out->Resize(x_dims);
   math::SoftmaxFuntor<CPU, float>()(in_x, out);
...
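This is the CPU cache maintenance named in the commit message: flush the CPU-side caches for the input so the FPGA reads up-to-date data, let the device run the bypass, then invalidate the CPU's cached view of the output before the CPU softmax reads it. A minimal, self-contained sketch of that ordering follows; device_flush, device_invalidate, and run_device_op are hypothetical placeholders for fpga::fpga_flush, fpga::fpga_invalidate, and fpga::PerformBypass, whose real signatures are not shown here.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-ins for the real cache-maintenance and bypass calls.
void device_flush(void *addr, size_t bytes) { (void)addr; (void)bytes; }       // write back CPU cache lines
void device_invalidate(void *addr, size_t bytes) { (void)addr; (void)bytes; }  // drop stale CPU cache lines
void run_device_op(const float *in, float *out, size_t n) {
  for (size_t i = 0; i < n; ++i) out[i] = in[i];  // pretend the device did this
}

int main() {
  std::vector<float> in(16, 1.0f), out(16, 0.0f);

  // 1) Flush: make the CPU-written input visible to the device.
  device_flush(in.data(), in.size() * sizeof(float));
  // 2) Run the device-side op (the bypass, in the kernel above).
  run_device_op(in.data(), out.data(), out.size());
  // 3) Invalidate: discard stale CPU cache lines before reading the result.
  device_invalidate(out.data(), out.size() * sizeof(float));

  std::printf("out[0] = %f\n", out[0]);
  return 0;
}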
...
@@ -785,7 +785,7 @@ class SoftmaxParam : public OpParam {
   fpga::BypassArgs fpga_bypass_args;

  public:
-  RType *FloatInput() {
+  RType *FloatInput() const {
     return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
   }
   void SetFloatInput(Tensor *input) { float_input_x_.reset(input); }
...
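The final hunk only adds const to FloatInput(). For reference, its fallback behaviour (return the converted float tensor when one has been set via SetFloatInput(), otherwise the original input) is sketched below in isolation; Tensor and SoftmaxParamSketch are simplified stand-ins for illustration, not the framework types.

#include <cassert>
#include <memory>

struct Tensor {};  // stand-in for the framework Tensor type

// Hypothetical sketch of the SoftmaxParam members shown in the hunk above.
class SoftmaxParamSketch {
 public:
  explicit SoftmaxParamSketch(Tensor *input_x) : input_x_(input_x) {}

  // Prefer the converted float tensor once it has been registered,
  // otherwise fall back to the original input.
  Tensor *FloatInput() const {
    return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
  }
  void SetFloatInput(Tensor *input) { float_input_x_.reset(input); }

 private:
  Tensor *input_x_ = nullptr;              // not owned
  std::shared_ptr<Tensor> float_input_x_;  // owns the converted copy
};

int main() {
  Tensor original;
  SoftmaxParamSketch param(&original);
  assert(param.FloatInput() == &original);  // no float copy registered yet
  param.SetFloatInput(new Tensor());        // Init() registers floatInput here
  assert(param.FloatInput() != &original);
  return 0;
}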