提交 b2d6610e 编写于 作者: H hanbuhe

softmax: add CPU cache flush and invalidate

上级 3a1c22bd
......@@ -25,21 +25,23 @@ namespace operators {
// FPGA specialization of SoftmaxKernel::Init.
//
// Describes a bypass pass over the input tensor (FP16/HWC in, FP32/CHW out)
// in an fpga::BypassArgs, allocates a Tensor copy-constructed from the input
// to receive the bypass output, and stores both on `param` through
// SetFloatInput() and SetFpgaArgs().
//
// @param param  softmax parameters; read via InputX()/Out() and updated via
//               SetFloatInput()/SetFpgaArgs().
// @return always true.
template <>
bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
  const Tensor *input = param->InputX();
  if (input->type() == typeid(half)) {
    // NOTE(review): the unconditional SetFpgaArgs() call at the end of this
    // function replaces the args stored by this branch. This looks like
    // leftover pre-change code -- confirm before removing.
    auto input_ptr = input->data<half>();
    auto output_ptr = param->Out();
    fpga::BypassArgs args;
    args.input_layout_type = fpga::LAYOUT_HWC;
    args.output_layout_type = fpga::LAYOUT_CHW;
    args.input_data_type = fpga::DATA_TYPE_FP16;
    args.output_data_type = fpga::DATA_TYPE_FP32;
    args.image.address = (void *)(input_ptr);
    args.image.height = (uint32_t)input->dims()[0];
    args.image.width = (uint32_t)input->dims()[1];
    args.image.channels = 1;
    args.output.address = output_ptr;
    param->SetFpgaArgs(args);
  }

  auto input_ptr = input->data<float>();

  // Handed to `param` through SetFloatInput() below; this function does not
  // delete it.
  Tensor *floatInput = new Tensor(*input);

  fpga::BypassArgs args;
  args.input_layout_type = fpga::LAYOUT_HWC;
  args.output_layout_type = fpga::LAYOUT_CHW;
  args.input_data_type = fpga::DATA_TYPE_FP16;
  args.output_data_type = fpga::DATA_TYPE_FP32;
  // `input` is a pointer-to-const Tensor, so strip any const from the data
  // pointer explicitly; reinterpret_cast alone cannot cast away constness.
  args.image.address =
      reinterpret_cast<void *>(const_cast<float *>(input_ptr));
  args.image.height = static_cast<uint32_t>(input->dims()[0]);
  args.image.width = static_cast<uint32_t>(input->dims()[1]);
  args.image.channels = 1;
  // The bypass writes into the new tensor's buffer rather than into Out().
  args.output.address =
      reinterpret_cast<void *>(floatInput->mutable_data<float>());

  param->SetFloatInput(floatInput);
  param->SetFpgaArgs(args);
  return true;
}
......@@ -48,8 +50,13 @@ void SoftmaxKernel<FPGA, float>::Compute(
const SoftmaxParam<FPGA> &param) const {
DLOG << "======================================= FPGA SoftMAX "
"===============================================";
const Tensor *in_x = param.InputX();
const Tensor *in_x = param.FloatInput();
Tensor *out = param.Out();
fpga::fpga_flush(reinterpret_cast<void *> in_x->data<float>(),
in_x->memory_size());
fpga::PerformBypass(param.FpgaArgs());
fpga::fpga_invalidate(out->data<float>(), out->memory_size());
auto x_dims = in_x->dims();
out->Resize(x_dims);
math::SoftmaxFuntor<CPU, float>()(in_x, out);
......
......@@ -785,7 +785,7 @@ class SoftmaxParam : public OpParam {
fpga::BypassArgs fpga_bypass_args;
public:
// Returns the tensor stored through SetFloatInput() when one is present,
// otherwise falls back to input_x_.
// Const-qualified so it can be called on a const SoftmaxParam.
RType *FloatInput() const {
  return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
}
// Hands `input` to float_input_x_ via reset(); FloatInput() returns it from
// then on.
void SetFloatInput(Tensor *input) {
  float_input_x_.reset(input);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册