提交 d84e14bd 编写于 作者: qnqinan's avatar qnqinan 提交者: GitHub

Merge pull request #778 from zhangyang0701/develop

Correct feed_op for FPGA track close #777
......@@ -36,12 +36,12 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims);
}
void Init() {}
#ifdef PADDLE_MOBILE_FPGA
void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); }
void Init() {
void RunImpl() const {
const Tensor *input = param_.InputX();
auto input_ptr = (const_cast<Tensor *>(input))->mutable_data<float>();
auto input_ptr = input->data<float>();
Tensor *output = param_.Out();
auto output_ptr = output->mutable_data<half>();
fpga::BypassArgs args;
......@@ -52,12 +52,11 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
args.image.height = input->dims()[2];
args.image.width = input->dims()[3];
args.output.address = output_ptr;
param_.SetFpgaArgs(args);
fpga::PerformBypass(args);
}
#else
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() {}
#endif
protected:
......
......@@ -665,16 +665,6 @@ class FeedParam : public OpParam {
Tensor *input_x_;
Tensor *out_;
int batch_size;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::BypassArgs fpga_bypass_args;
public:
const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
#endif
};
class FetchParam : public OpParam {
......@@ -1133,7 +1123,6 @@ class FusionConvBNParam : public OpParam {
FusionConvBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope);
output_y_ = OutputYFrom<LoDTensor>(outputs, scope);
......@@ -1150,8 +1139,6 @@ class FusionConvBNParam : public OpParam {
// is_test_ = GetAttr<bool>("is_test", attrs);
}
const int &Axis() const { return axis_; }
const Tensor *Input() const { return input_; }
#ifdef PADDLE_MOBILE_FPGA
......@@ -1192,7 +1179,6 @@ class FusionConvBNParam : public OpParam {
const Tensor *NewBias() const { return new_bias_; }
protected:
int axis_;
Tensor *input_;
Tensor *output_y_;
Tensor *filter_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册