提交 52c6945b 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #221 from Eclipsess/develop

fix #220 add prior_box_op and testfile
......@@ -25,7 +25,7 @@ SOFTWARE.
#include "framework/program_desc.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#define PADDLE_MOBILE_DEBUG 1
namespace paddle_mobile {
void ReadBinaryFile(const std::string &filename, std::string *contents) {
......@@ -165,9 +165,9 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
for (const auto &block : program_desc_proto.blocks()) {
LOG(kLOG_DEBUG) << "block: " << block.idx();
for (int j = 0; j < block.ops().size(); ++j) {
if (j == 2) {
break;
}
// if (j == 2) {
// break;
// }
framework::proto::OpDesc op = block.ops()[j];
LOG(kLOG_DEBUG1) << "op: " << op.type();
for (int m = 0; m < op.inputs_size(); ++m) {
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "operators/kernel/prior_box_kernel.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct ClipFunctor {
inline T operator()(T in) const {
return std::min<T>(std::max<T>(in, 0.), 1.);
}
};
template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
const auto *input_image = param.InputImage();
const auto &input_image_dims = input_image->dims();
const auto &min_sizes = param.MinSizes();
const auto &max_sizes = param.MaxSizes();
const auto &variances = param.Variances();
const auto &input_aspect_ratio = param.AspectRatios();
const bool &flip = param.Flip();
const bool &clip = param.Clip();
const float &step_w = param.StepW();
const float &step_h = param.StepH();
const float &offset = param.Offset();
Tensor *output_boxes = param.OutputBoxes();
auto output_boxes_dataptr = output_boxes->mutable_data<float>();
Tensor *output_variances = param.OutputVariances();
auto output_variances_dataptr = output_variances->mutable_data<float>();
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
auto img_width = input_image_dims[3];
auto img_height = input_image_dims[2];
auto feature_width = input_dims[3];
auto feature_height = input_dims[2];
auto stride0 = output_boxes->dims()[1] * output_boxes->dims()[2] *
output_boxes->dims()[3];
auto stride1 = output_boxes->dims()[2] * output_boxes->dims()[3];
auto stride2 = output_boxes->dims()[3];
float step_width, step_height;
/// 300 / 19
if (step_w == 0 || step_h == 0) {
step_width = static_cast<float>(img_width) / feature_width;
step_height = static_cast<float>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) {
/// map origin image
float center_x = (w + offset) * step_width;
float center_y = (h + offset) * step_height;
float box_width, box_height;
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
// priors with different aspect ratios
for (float ar : aspect_ratios) {
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
/// box_width/2 , / img_width 为了得到feature map 相对于
/// 原图的归一化位置的比例。
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] =
(center_x - box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] =
(center_y - box_height) / img_height;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] =
(center_x + box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] =
(center_y + box_height) / img_height;
idx++;
}
if (!max_sizes.empty()) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] =
(center_x - box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] =
(center_y - box_height) / img_height;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] =
(center_x + box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] =
(center_y + box_height) / img_height;
idx++;
}
}
}
}
if (clip) {
math::Transform trans;
ClipFunctor<float> clip_func;
trans(output_boxes_dataptr, output_boxes_dataptr + output_boxes->numel(),
output_boxes_dataptr, clip_func);
}
Tensor var_t;
var_t.mutable_data<float>(make_ddim({1, static_cast<int>(variances.size())}));
int box_num = feature_height * feature_width * num_priors;
// auto var_dim = output_variances->dims();
// output_variances->Resize({box_num, static_cast<int>(variances.size())});
for (int i = 0; i < box_num; i++) {
output_variances_dataptr[4 * i] = variances[0];
output_variances_dataptr[4 * i + 1] = variances[1];
output_variances_dataptr[4 * i + 2] = variances[2];
output_variances_dataptr[4 * i + 3] = variances[3];
}
// output_variances->Resize(var_dim);
}
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include <vector>
#include "framework/operator.h"
#include "operators/math/transform.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
bool flip,
std::vector<float>* output_aspect_ratior) {
constexpr float epsilon = 1e-6;
output_aspect_ratior->clear();
output_aspect_ratior->push_back(1.0f);
for (size_t i = 0; i < input_aspect_ratior.size(); ++i) {
float ar = input_aspect_ratior[i];
bool already_exist = false;
for (size_t j = 0; j < output_aspect_ratior->size(); ++j) {
if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) {
already_exist = true;
break;
}
}
if (!already_exist) {
output_aspect_ratior->push_back(ar);
if (flip) {
output_aspect_ratior->push_back(1.0f / ar);
}
}
}
}
template <typename DeviceType, typename T>
class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public:
void Compute(const PriorBoxParam& param) const;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -65,6 +65,10 @@ class OpParam : PaddleMobileObject {
static T *InputScaleFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Scale", inputs, scope);
}
template <typename T>
static T *InputImageFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Image", inputs, scope);
}
template <typename T>
static std::vector<T *> InputMultiFrom(const VariableNameMap &inputs,
......@@ -87,6 +91,18 @@ class OpParam : PaddleMobileObject {
return GetVarValue<T>("Y", outputs, scope);
}
template <typename T>
static T *OutputBoxesFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("Boxes", outputs, scope);
}
template <typename T>
static T *OutputVariancesFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("Variances", outputs, scope);
}
template <typename T>
static T *MidOutFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("MidOut", outputs, scope);
......@@ -382,5 +398,65 @@ class PoolParam : public OpParam {
bool gloabal_pooling_ = false;
};
class PriorBoxParam : public OpParam {
public:
PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_ = InputFrom<framework::Tensor>(inputs, scope);
input_image_ = InputImageFrom<framework::Tensor>(inputs, scope);
output_boxes_ = OutputBoxesFrom<framework::Tensor>(outputs, scope);
output_variances_ = OutputVariancesFrom<framework::Tensor>(outputs, scope);
min_sizes_ = GetAttr<std::vector<float>>("min_sizes", attrs);
max_sizes_ = GetAttr<std::vector<float>>("max_sizes", attrs);
aspect_ratios_ = GetAttr<std::vector<float>>("aspect_ratios", attrs);
variances_ = GetAttr<std::vector<float>>("variances", attrs);
flip_ = GetAttr<bool>("flip", attrs);
clip_ = GetAttr<bool>("clip", attrs);
step_w_ = GetAttr<float>("step_w", attrs);
step_h_ = GetAttr<float>("step_h", attrs);
offset_ = GetAttr<float>("offset", attrs);
}
const Tensor *Input() const { return input_; }
const Tensor *InputImage() const { return input_image_; }
Tensor *OutputBoxes() const { return output_boxes_; }
Tensor *OutputVariances() const { return output_variances_; }
const std::vector<float> &MinSizes() const { return min_sizes_; }
const std::vector<float> &MaxSizes() const { return max_sizes_; }
const std::vector<float> &AspectRatios() const { return aspect_ratios_; }
const std::vector<float> &Variances() const { return variances_; }
const bool &Flip() const { return flip_; }
const bool &Clip() const { return clip_; }
const float &StepW() const { return step_w_; }
const float &StepH() const { return step_h_; }
const float &Offset() const { return offset_; }
private:
Tensor *input_;
Tensor *input_image_;
Tensor *output_boxes_;
Tensor *output_variances_;
std::vector<float> min_sizes_;
std::vector<float> max_sizes_;
std::vector<float> aspect_ratios_;
std::vector<float> variances_;
bool flip_;
bool clip_;
float step_w_;
float step_h_;
float offset_;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "operators/prior_box_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void PriorBoxOp<Dtype, T>::InferShape() const {
auto input_dims = param_.Input()->dims();
auto input_image_dims = param_.InputImage()->dims();
auto min_sizes = param_.MinSizes();
auto max_sizes = param_.MaxSizes();
auto variances = param_.Variances();
auto aspect_ratios = param_.AspectRatios();
bool flip = param_.Flip();
std::vector<float> aspect_ratios_vec;
ExpandAspectRatios(aspect_ratios, flip, &aspect_ratios_vec);
size_t num_priors = aspect_ratios_vec.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_priors;
dim_vec[3] = 4;
param_.OutputBoxes()->Resize(framework::make_ddim(dim_vec));
param_.OutputVariances()->Resize(framework::make_ddim(dim_vec));
}
template class PriorBoxOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/prior_box_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class PriorBoxOp : public framework::OperatorWithKernel<DeviceType> {
public:
PriorBoxOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::PriorBoxKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
PriorBoxParam param_;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -23,6 +23,10 @@ target_link_libraries(test-lrn-op paddle-mobile)
ADD_EXECUTABLE(test-batchnorm-op operators/test_batchnorm_op.cpp test_helper.h test_include.h)
target_link_libraries(test-batchnorm-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-priorbox-op operators/test_prior_box_op.cpp test_helper.h test_include.h)
target_link_libraries(test-priorbox-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "../test_include.h"
#include "operators/prior_box_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestPriorBoxOp {
public:
explicit TestPriorBoxOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (auto op : ops) {
if (op->Type() == "prior_box" &&
op->Input("Input")[0] == "batch_norm_26.tmp_3") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " Input is : " << op->Input("Input")[0];
DLOG << " Image is : " << op->Input("Image")[0];
DLOG << " Output Boxes is : " << op->Output("Boxes")[0];
DLOG << " Output Variances is : " << op->Output("Variances")[0];
DLOG << " offset : " << op->GetAttrMap().at("offset").Get<float>();
DLOG << " step_h : " << op->GetAttrMap().at("step_h").Get<float>();
DLOG << " step_w : " << op->GetAttrMap().at("step_w").Get<float>();
DLOG << " flip : " << op->GetAttrMap().at("flip").Get<bool>();
DLOG << " clip : " << op->GetAttrMap().at("clip").Get<bool>();
// DLOG << " variances : " <<
// op->GetAttrMap().at("variances").Get<std::vector<float>>();
// DLOG << " aspect_ratios : " <<
// op->GetAttrMap().at("aspect_ratios").Get<std::vector<float>>();
// DLOG << " min_sizes : " <<
// op->GetAttrMap().at("min_sizes").Get<std::vector<float>>();
// DLOG << " max_sizes : " <<
// op->GetAttrMap().at("max_sizes").Get<std::vector<float>>();
std::shared_ptr<operators::PriorBoxOp<Dtype, float>> priorbox =
std::make_shared<operators::PriorBoxOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(priorbox);
}
}
}
}
std::shared_ptr<Tensor> predict_priorbox(const Tensor &t1, const Tensor &t2) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("image");
auto tensor_x1 = x1_feed_value->GetMutable<Tensor>();
tensor_x1->ShareDataWith(t1);
Variable *x2_feed_value = scope->Var("batch_norm_26.tmp_3");
auto tensor_x2 = x2_feed_value->GetMutable<Tensor>();
tensor_x2->ShareDataWith(t2);
Variable *boxes_output = scope->Var("prior_box_1.tmp_0");
auto *boxes_output_tensor = boxes_output->GetMutable<Tensor>();
boxes_output_tensor->mutable_data<float>({10, 10, 6, 4});
Variable *variances_output = scope->Var("prior_box_1.tmp_1");
auto *variances_output_tesnor = variances_output->GetMutable<Tensor>();
variances_output_tesnor->mutable_data<float>({10, 10, 6, 4});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> outboxes_tensor = std::make_shared<LoDTensor>();
outboxes_tensor.reset(boxes_output_tensor);
std::shared_ptr<Tensor> outvars_tensor = std::make_shared<LoDTensor>();
outvars_tensor.reset(variances_output_tesnor);
predict_priorbox(t1, t2, 0);
return outboxes_tensor;
// return outvars_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_priorbox(const Tensor &t1, const Tensor &t2, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestPriorBoxOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run PriorBoxOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/mobilenet+ssd"));
/// input x (1,3,300,300)
paddle_mobile::framework::Tensor input_image;
SetupTensor<float>(&input_image, {1, 3, 300, 300}, static_cast<float>(0),
static_cast<float>(1));
auto *input_image_ptr = input_image.data<float>();
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {1, 1024, 10, 10}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::TestPriorBoxOp<paddle_mobile::CPU> testPriorBoxOp(
program);
auto output_priorbox = testPriorBoxOp.predict_priorbox(input_image, inputx1);
auto *output_priorbox_ptr = output_priorbox->data<float>();
for (int i = 0; i < output_priorbox->numel(); i++) {
DLOG << output_priorbox_ptr[i];
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册