Commit 5957f2f2 authored by zhangyang

add PerformBypass inside FeedOp

Parent a9537ca4
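For orientation before the interleaved diff below, here is a compressed, self-contained sketch of the flow this commit introduces: on the FPGA build the feed op no longer aliases its input into its output; instead Init() describes an fp32 CHW to fp16 HWC conversion in a BypassArgs object, caches it in the op's FeedParam, and RunImpl() hands the cached description to fpga::PerformBypass. The types below are stand-ins written for this page, not paddle-mobile code.

// Stand-in types: BypassArgsStub and FeedParamStub only mimic the plumbing
// added by this commit; the real fpga::BypassArgs and FeedParam live in the
// paddle-mobile tree and carry the image/output descriptors shown in the diff.
#include <cstdio>

namespace fpga_stub {
struct BypassArgsStub {
  int channels = 0, height = 0, width = 0;  // filled from the input dims
};
// Stand-in for fpga::PerformBypass: the real call drives the FPGA engine.
inline void PerformBypass(const BypassArgsStub &a) {
  std::printf("bypass fp32 CHW -> fp16 HWC, %d x %d x %d\n", a.channels,
              a.height, a.width);
}
}  // namespace fpga_stub

struct FeedParamStub {
  fpga_stub::BypassArgsStub fpga_bypass_args;
  const fpga_stub::BypassArgsStub &FpgaArgs() const { return fpga_bypass_args; }
  void SetFpgaArgs(const fpga_stub::BypassArgsStub &args) {
    fpga_bypass_args = args;
  }
};

struct FeedOpSketch {
  FeedParamStub param_;
  void Init() {  // setup: describe the conversion once
    fpga_stub::BypassArgsStub args;
    args.channels = 3;  // example dims, stand-ins for input->dims()[1..3]
    args.height = 224;
    args.width = 224;
    param_.SetFpgaArgs(args);
  }
  void RunImpl() const {  // execution: replay the cached description
    fpga_stub::PerformBypass(param_.FpgaArgs());
  }
};

int main() {
  FeedOpSketch op;
  op.Init();
  op.RunImpl();
  return 0;
}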
@@ -30,18 +30,45 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
       : framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
                                             scope),
         param_(inputs, outputs, attrs, scope.get()) {}
-  void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
-  void Init() {}
- protected:
-  FeedParam param_;
+#ifdef PADDLE_MOBILE_FPGA
+  void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); }
+  void Init() {
+    const Tensor *input = param_.InputX();
+    auto input_ptr = input->data<float>();
+    Tensor *output = param_.Out();
+    auto output_ptr = output->mutable_data<half>();
+    fpga::BypassArgs args;
+    args.convert_type = fpga::DATA_FP32_TO_FP16;
+    args.layout_type = fpga::LAYOUT_CHW_TO_HWC;
+    args.image.address = (void *)input_ptr;
+    args.image.channels = input->dims()[1];
+    args.image.height = input->dims()[2];
+    args.image.width = input->dims()[3];
+    args.output.address = output_ptr;
+    param_.SetFpgaArgs(args);
+  }
+  void InferShape() const {
+    auto out_dims = param_.Out()->dims();
+    out_dims[0] = param_.BatchSize();
+    param_.Out()->Resize(out_dims);
+    param_.Out()->ShareDataWith(*param_.InputX());  // TODO How to handle fp16
+  }
+#else
+  void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
+ protected:
+  FeedParam param_;
+  void Init() {}
   void InferShape() const {
     auto out_dims = param_.Out()->dims();
     out_dims[0] = param_.BatchSize();
     param_.Out()->Resize(out_dims);
   }
+#endif
 };
 } // namespace operators
......
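A note on the two bypass flags used above: DATA_FP32_TO_FP16 and LAYOUT_CHW_TO_HWC describe a pure data-movement job, namely narrowing each fp32 element to fp16 and reordering the image from channel-major (CHW) to pixel-major (HWC). The sketch below is a host-side reference of that transformation written for this page; it is not paddle-mobile code, and on the device the conversion is done by the FPGA bypass engine. The float_to_half helper is deliberately simplified (truncating, no denormal or NaN handling).

// Host-side reference of the conversion the bypass args describe: fp32 CHW in,
// fp16 (stored as uint16_t bit patterns) HWC out, for a single image.
#include <cstdint>
#include <cstring>
#include <vector>

static uint16_t float_to_half(float f) {
  uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  uint32_t sign = (x >> 16) & 0x8000u;
  int32_t exp = static_cast<int32_t>((x >> 23) & 0xFFu) - 127 + 15;
  uint32_t mant = (x >> 13) & 0x3FFu;                       // truncated mantissa
  if (exp <= 0) return static_cast<uint16_t>(sign);          // flush small values to zero
  if (exp >= 31) return static_cast<uint16_t>(sign | 0x7C00u);  // overflow -> infinity
  return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) | mant);
}

// in:  channels * height * width fp32 values in CHW order
// out: the same values as fp16 bit patterns in HWC order
std::vector<uint16_t> BypassReference(const std::vector<float> &in, int channels,
                                      int height, int width) {
  std::vector<uint16_t> out(in.size());
  for (int h = 0; h < height; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < channels; ++c) {
        int src = c * height * width + h * width + w;  // CHW index
        int dst = (h * width + w) * channels + c;      // HWC index
        out[dst] = float_to_half(in[src]);
      }
    }
  }
  return out;
}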
@@ -24,13 +24,13 @@ template <>
 bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
   bool relu_enabled = false;
   const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
   const Tensor *bias = param->Bias();
   auto bias_ptr = bias->data<float>();
   const Tensor *filter = param->Filter();
   auto filter_ptr = filter->data<float>();
   Tensor *out = param->Output();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
   auto bn_mean_ptr = param->InputMean()->data<float>();
   auto bn_var_ptr = param->InputVariance()->data<float>();
   auto bn_scale_ptr = param->InputScale()->data<float>();
......
@@ -24,13 +24,13 @@ template <>
 bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
   bool relu_enabled = true;
   const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
   const Tensor *bias = param->Bias();
   auto bias_ptr = bias->data<float>();
   const Tensor *filter = param->Filter();
   auto filter_ptr = filter->data<float>();
   Tensor *out = param->Output();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
   auto bn_mean_ptr = param->InputMean()->data<float>();
   auto bn_var_ptr = param->InputVariance()->data<float>();
   auto bn_scale_ptr = param->InputScale()->data<float>();
......
@@ -24,13 +24,13 @@ template <>
 bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) {
   bool relu_enabled = true;
   const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
   const Tensor *bias = param->Bias();
   auto bias_ptr = bias->data<float>();
   const Tensor *filter = param->Filter();
   auto filter_ptr = filter->data<float>();
   Tensor *out = param->Output();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
   PADDLE_MOBILE_ENFORCE(input->dims()[1] == bias->dims()[0],
                         "Image channel should be equal to bias number");
......
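The three kernel hunks above, together with the feed change, switch the FPGA convolution path to half-precision activations: inputs are read with data<half>() and outputs allocated with mutable_data<half>(), while bias, filter, and batch-norm statistics are still read as fp32. The tiny standalone demo below (not paddle-mobile code) shows why the accessor type has to follow the storage type: fp16 bit patterns misread as fp32 are unrelated numbers, which is the kind of mismatch the old data<float>() calls would have created once the feed op started writing fp16 (depending on the Tensor implementation that either trips a type check or silently reinterprets the bytes, as simulated here).

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Two fp16 values of 1.0 (bit pattern 0x3C00), as the feed op now stores them.
  uint16_t half_vals[2] = {0x3C00, 0x3C00};

  // Reinterpreting the same 4 bytes as a single fp32 yields an unrelated value
  // instead of 1.0, illustrating the fp16/fp32 mismatch.
  float misread;
  std::memcpy(&misread, half_vals, sizeof(misread));
  std::printf("fp16 bits 0x3C00 encode 1.0; the same bytes misread as fp32: %g\n",
              misread);
  return 0;
}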
@@ -665,6 +665,16 @@ class FeedParam : public OpParam {
   Tensor *input_x_;
   Tensor *out_;
   int batch_size;
+#ifdef PADDLE_MOBILE_FPGA
+ private:
+  fpga::BypassArgs fpga_bypass_args;
+
+ public:
+  const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
+  void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
+#endif
 };
 class FetchParam : public OpParam {
......