提交 4741e8fc 编写于 作者: Z zhangyang

correct feed_op for FPGA

上级 1559e7fd
...@@ -36,12 +36,12 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -36,12 +36,12 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
out_dims[0] = param_.BatchSize(); out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims); param_.Out()->Resize(out_dims);
} }
void Init() {}
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); } void RunImpl() const {
void Init() {
const Tensor *input = param_.InputX(); const Tensor *input = param_.InputX();
auto input_ptr = (const_cast<Tensor *>(input))->mutable_data<float>(); auto input_ptr = input->data<float>();
Tensor *output = param_.Out(); Tensor *output = param_.Out();
auto output_ptr = output->mutable_data<half>(); auto output_ptr = output->mutable_data<half>();
fpga::BypassArgs args; fpga::BypassArgs args;
...@@ -52,12 +52,11 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -52,12 +52,11 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
args.image.height = input->dims()[2]; args.image.height = input->dims()[2];
args.image.width = input->dims()[3]; args.image.width = input->dims()[3];
args.output.address = output_ptr; args.output.address = output_ptr;
param_.SetFpgaArgs(args); fpga::PerformBypass(args);
} }
#else #else
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() {}
#endif #endif
protected: protected:
......
...@@ -665,16 +665,6 @@ class FeedParam : public OpParam { ...@@ -665,16 +665,6 @@ class FeedParam : public OpParam {
Tensor *input_x_; Tensor *input_x_;
Tensor *out_; Tensor *out_;
int batch_size; int batch_size;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::BypassArgs fpga_bypass_args;
public:
const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
#endif
}; };
class FetchParam : public OpParam { class FetchParam : public OpParam {
...@@ -1133,7 +1123,6 @@ class FusionConvBNParam : public OpParam { ...@@ -1133,7 +1123,6 @@ class FusionConvBNParam : public OpParam {
FusionConvBNParam(const VariableNameMap &inputs, FusionConvBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { const Scope &scope) {
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<LoDTensor>(inputs, scope); filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope); input_ = InputFrom<LoDTensor>(inputs, scope);
output_y_ = OutputYFrom<LoDTensor>(outputs, scope); output_y_ = OutputYFrom<LoDTensor>(outputs, scope);
...@@ -1150,8 +1139,6 @@ class FusionConvBNParam : public OpParam { ...@@ -1150,8 +1139,6 @@ class FusionConvBNParam : public OpParam {
// is_test_ = GetAttr<bool>("is_test", attrs); // is_test_ = GetAttr<bool>("is_test", attrs);
} }
const int &Axis() const { return axis_; }
const Tensor *Input() const { return input_; } const Tensor *Input() const { return input_; }
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
...@@ -1192,7 +1179,6 @@ class FusionConvBNParam : public OpParam { ...@@ -1192,7 +1179,6 @@ class FusionConvBNParam : public OpParam {
const Tensor *NewBias() const { return new_bias_; } const Tensor *NewBias() const { return new_bias_; }
protected: protected:
int axis_;
Tensor *input_; Tensor *input_;
Tensor *output_y_; Tensor *output_y_;
Tensor *filter_; Tensor *filter_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册