diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h index 8056e78aed3b67daf7dead7987e80529b85f0d60..b34c7cf78b0b808e512e68e5429671bf8e9d8c4a 100644 --- a/src/operators/feed_op.h +++ b/src/operators/feed_op.h @@ -31,12 +31,14 @@ class FeedOp : public framework::OperatorBase { scope), param_(inputs, outputs, attrs, scope.get()) {} - protected: - FeedParam param_; + void InferShape() const { + auto out_dims = param_.Out()->dims(); + out_dims[0] = param_.BatchSize(); + param_.Out()->Resize(out_dims); + } #ifdef PADDLE_MOBILE_FPGA void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); } - void Init() { const Tensor *input = param_.InputX(); auto input_ptr = input->data(); @@ -53,22 +55,13 @@ class FeedOp : public framework::OperatorBase { param_.SetFpgaArgs(args); } - void InferShape() const { - auto out_dims = param_.Out()->dims(); - out_dims[0] = param_.BatchSize(); - param_.Out()->Resize(out_dims); - param_.Out()->ShareDataWith(*param_.InputX()); // TODO How to handle fp16 - } #else void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } - void Init() {} - - void InferShape() const { - auto out_dims = param_.Out()->dims(); - out_dims[0] = param_.BatchSize(); - param_.Out()->Resize(out_dims); #endif + + protected: + FeedParam param_; }; } // namespace operators