提交 5d5e9960 编写于 作者: Z zhangyang

recover InferShape() function

上级 ab7155f8
...@@ -31,12 +31,14 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -31,12 +31,14 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
scope), scope),
param_(inputs, outputs, attrs, scope.get()) {} param_(inputs, outputs, attrs, scope.get()) {}
protected: void InferShape() const {
FeedParam param_; auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims);
}
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); } void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); }
void Init() { void Init() {
const Tensor *input = param_.InputX(); const Tensor *input = param_.InputX();
auto input_ptr = input->data<float>(); auto input_ptr = input->data<float>();
...@@ -53,22 +55,13 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -53,22 +55,13 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
param_.SetFpgaArgs(args); param_.SetFpgaArgs(args);
} }
void InferShape() const {
auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims);
param_.Out()->ShareDataWith(*param_.InputX()); // TODO How to handle fp16
}
#else #else
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() {} void Init() {}
void InferShape() const {
auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims);
#endif #endif
protected:
FeedParam param_;
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册