未验证 提交 d1004d3e 编写于 作者: K kk12333 提交者: GitHub

Merge pull request #7 from PaddlePaddle/develop

merge paddle mobile develop
...@@ -20,14 +20,12 @@ limitations under the License. */ ...@@ -20,14 +20,12 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "framework/tensor.h" #include "framework/tensor.h"
#include "operators/math/poly_util.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
constexpr int kOutputDim = 6;
constexpr int kBBoxSize = 4;
template <class T> template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1, bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) { const std::pair<float, T>& pair2) {
...@@ -90,6 +88,21 @@ static inline T JaccardOverlap(const T* box1, const T* box2, ...@@ -90,6 +88,21 @@ static inline T JaccardOverlap(const T* box1, const T* box2,
} }
} }
template <class T>
static inline T PolyIoU(const T* box1, const T* box2, const size_t box_size,
const bool normalized) {
T bbox1_area = math::PolyArea<T>(box1, box_size, normalized);
T bbox2_area = math::PolyArea<T>(box2, box_size, normalized);
T inter_area = math::PolyOverlapArea<T>(box1, box2, box_size, normalized);
if (bbox1_area == 0 || bbox2_area == 0 || inter_area == 0) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
return static_cast<T>(0.);
} else {
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T> template <typename T>
static inline void NMSFast(const framework::Tensor& bbox, static inline void NMSFast(const framework::Tensor& bbox,
const framework::Tensor& scores, const framework::Tensor& scores,
...@@ -116,8 +129,14 @@ static inline void NMSFast(const framework::Tensor& bbox, ...@@ -116,8 +129,14 @@ static inline void NMSFast(const framework::Tensor& bbox,
for (size_t k = 0; k < selected_indices->size(); ++k) { for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) { if (keep) {
const int kept_idx = (*selected_indices)[k]; const int kept_idx = (*selected_indices)[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size, T overlap = T(0.);
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, true); bbox_data + kept_idx * box_size, true);
} else {
overlap = PolyIoU<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, box_size, true);
}
keep = overlap <= adaptive_threshold; keep = overlap <= adaptive_threshold;
} else { } else {
break; break;
...@@ -190,6 +209,8 @@ void MultiClassOutput(const framework::Tensor& scores, ...@@ -190,6 +209,8 @@ void MultiClassOutput(const framework::Tensor& scores,
const std::map<int, std::vector<int>>& selected_indices, const std::map<int, std::vector<int>>& selected_indices,
framework::Tensor* outs) { framework::Tensor* outs) {
int predict_dim = scores.dims()[1]; int predict_dim = scores.dims()[1];
int box_size = bboxes.dims()[1];
int out_dim = bboxes.dims()[1] + 2;
auto* scores_data = scores.data<T>(); auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>(); auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->data<T>(); auto* odata = outs->data<T>();
...@@ -202,11 +223,11 @@ void MultiClassOutput(const framework::Tensor& scores, ...@@ -202,11 +223,11 @@ void MultiClassOutput(const framework::Tensor& scores,
const std::vector<int>& indices = it.second; const std::vector<int>& indices = it.second;
for (size_t j = 0; j < indices.size(); ++j) { for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j]; int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize; const T* bdata = bboxes_data + idx * box_size;
odata[count * kOutputDim] = label; // label odata[count * out_dim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score odata[count * out_dim + 1] = sdata[idx]; // score
// xmin, ymin, xmax, ymax // xmin, ymin, xmax, ymax
std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T)); std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
count++; count++;
} }
} }
...@@ -256,7 +277,8 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) { ...@@ -256,7 +277,8 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
float* od = outs->mutable_data<float>({1}); float* od = outs->mutable_data<float>({1});
od[0] = -1; od[0] = -1;
} else { } else {
outs->mutable_data<float>({num_kept, kOutputDim}); int64_t out_dim = box_dim + 2;
outs->mutable_data<float>({num_kept, out_dim});
for (int64_t i = 0; i < batch_size; ++i) { for (int64_t i = 0; i < batch_size; ++i) {
framework::Tensor ins_score = input_scores->Slice(i, i + 1); framework::Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim}); ins_score.Resize({class_num, predict_dim});
......
...@@ -1667,7 +1667,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1667,7 +1667,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const int w_times = (out_w - 2) / 3; const int w_times = (out_w - 2) / 3;
float32x4_t zero = vdupq_n_f32(0.0); float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) { for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for #pragma omp parallel for
for (int j = 0; j < c; j++) { for (int j = 0; j < c; j++) {
const float *input_row_ptr; const float *input_row_ptr;
float *output_row_ptr; float *output_row_ptr;
...@@ -1912,9 +1912,7 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, ...@@ -1912,9 +1912,7 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
float w20 = filter_data[6]; float w20 = filter_data[6];
float w21 = filter_data[7]; float w21 = filter_data[7];
float w22 = filter_data[8]; float w22 = filter_data[8];
float32x4_t biasv = vld1q_dup_f32(bias_data); float32x4_t biasv = vld1q_dup_f32(bias_data);
for (int i = 0; i < output_height; i += 1) { for (int i = 0; i < output_height; i += 1) {
for (int m = 0; m < output_width - 2; m += 3) { for (int m = 0; m < output_width - 2; m += 3) {
float *output_ptr = output_data + i * output_width + m; float *output_ptr = output_data + i * output_width + m;
...@@ -1949,8 +1947,9 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, ...@@ -1949,8 +1947,9 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
out0 = vmlaq_n_f32(out0, in4, w20); out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21); out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22); out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, biasv); if (if_bias) {
out0 = vaddq_f32(out0, biasv);
}
vst1q_lane_f32(output_ptr, out0, 0); vst1q_lane_f32(output_ptr, out0, 0);
vst1q_lane_f32(output_ptr + 1, out0, 1); vst1q_lane_f32(output_ptr + 1, out0, 1);
vst1q_lane_f32(output_ptr + 2, out0, 2); vst1q_lane_f32(output_ptr + 2, out0, 2);
...@@ -1960,16 +1959,18 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, ...@@ -1960,16 +1959,18 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
} }
for (int j = m; j < output_width; j++) { for (int j = m; j < output_width; j++) {
output_data[i * output_width + j] = output_data[i * output_width + j] =
input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 + input_data[(2 * i) * input_width + 2 * j] * w00 +
input_data[(2 * i - 1) * input_width + 2 * j] * w01 + input_data[(2 * i) * input_width + 2 * j + 1] * w01 +
input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 + input_data[(2 * i) * input_width + 2 * j + 2] * w02 +
input_data[(2 * i) * input_width + 2 * j - 1] * w10 + input_data[(2 * i + 1) * input_width + 2 * j] * w10 +
input_data[(2 * i) * input_width + 2 * j] * w11 + input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 +
input_data[(2 * i) * input_width + 2 * j + 1] * w12 + input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 +
input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 + input_data[(2 * i + 2) * input_width + 2 * j] * w20 +
input_data[(2 * i + 1) * input_width + 2 * j] * w21 + input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 +
input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22; input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22;
output_data[i * output_width + j] += *bias_data; if (if_bias) {
output_data[i * output_width + j] += *bias_data;
}
} }
} }
} }
......
此差异已折叠。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#pragma once
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
namespace gpc {
typedef enum { // Set operation type
GPC_DIFF, // Difference
GPC_INT, // Intersection
GPC_XOR, // Exclusive or
GPC_UNION // Union
} gpc_op;
typedef struct { // Polygon vertex structure
double x; // Vertex x component
double y; // vertex y component
} gpc_vertex;
typedef struct { // Vertex list structure
int num_vertices; // Number of vertices in list
gpc_vertex *vertex; // Vertex array pointer
} gpc_vertex_list;
typedef struct { // Polygon set structure
int num_contours; // Number of contours in polygon
int *hole; // Hole external contour flags
gpc_vertex_list *contour; // Contour array pointer
} gpc_polygon;
typedef struct { // Tristrip set structure
int num_strips; // Number of tristrips
gpc_vertex_list *strip; // Tristrip array pointer
} gpc_tristrip;
typedef enum { LEFT, RIGHT } gpc_left_right;
typedef enum { ABOVE, BELOW } gpc_above_below;
typedef enum { CLIP, SUBJ } gpc_clip_subj;
typedef enum { /* Edge intersection classes */
NUL, /* Empty non-intersection */
EMX, /* External maximum */
ELI, /* External left intermediate */
TED, /* Top edge */
ERI, /* External right intermediate */
RED, /* Right edge */
IMM, /* Internal maximum and minimum */
IMN, /* Internal minimum */
EMN, /* External minimum */
EMM, /* External maximum and minimum */
LED, /* Left edge */
ILI, /* Internal left intermediate */
BED, /* Bottom edge */
IRI, /* Internal right intermediate */
IMX, /* Internal maximum */
FUL /* Full non-intersection */
} vertex_type;
typedef enum { /* Horizontal edge states */
NH, /* No horizontal edge */
BH, /* Bottom horizontal edge */
TH /* Top horizontal edge */
} h_state;
typedef enum { /* Edge bundle state */
UNBUNDLED, /* Isolated edge not within a bundle */
BUNDLE_HEAD, /* Bundle head node */
BUNDLE_TAIL /* Passive bundle tail node */
} bundle_state;
typedef struct v_shape { /* Internal vertex list datatype */
double x; /* X coordinate component */
double y; /* Y coordinate component */
struct v_shape *next; /* Pointer to next vertex in list */
} vertex_node;
typedef struct p_shape { /* Internal contour / tristrip type */
int active; /* Active flag / vertex count */
int hole; /* Hole / external contour flag */
vertex_node *v[2]; /* Left and right vertex list ptrs */
struct p_shape *next; /* Pointer to next polygon contour */
struct p_shape *proxy; /* Pointer to actual structure used */
} polygon_node;
typedef struct edge_shape {
gpc_vertex vertex; /* Piggy-backed contour vertex data */
gpc_vertex bot; /* Edge lower (x, y) coordinate */
gpc_vertex top; /* Edge upper (x, y) coordinate */
double xb; /* Scanbeam bottom x coordinate */
double xt; /* Scanbeam top x coordinate */
double dx; /* Change in x for a unit y increase */
int type; /* Clip / subject edge flag */
int bundle[2][2]; /* Bundle edge flags */
int bside[2]; /* Bundle left / right indicators */
bundle_state bstate[2]; /* Edge bundle state */
polygon_node *outp[2]; /* Output polygon / tristrip pointer */
struct edge_shape *prev; /* Previous edge in the AET */
struct edge_shape *next; /* Next edge in the AET */
struct edge_shape *pred; /* Edge connected at the lower end */
struct edge_shape *succ; /* Edge connected at the upper end */
struct edge_shape *next_bound; /* Pointer to next bound in LMT */
} edge_node;
inline bool gpc_eq(float a, float b) { return (fabs(a - b) <= 1e-6); }
inline bool gpc_prev_index(float a, float b) { return (fabs(a - b) <= 1e-6); }
inline int gpc_prev_index(int i, int n) { return ((i - 1 + n) % n); }
inline int gpc_next_index(int i, int n) { return ((i + 1) % n); }
inline int gpc_optimal(gpc_vertex *v, int i, int n) {
return (v[(i + 1) % n].y != v[i].y || v[(i - 1 + n) % n].y != v[i].y);
}
inline int gpc_fwd_min(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y > v[i].vertex.y &&
v[(i - 1 + n) % n].vertex.y >= v[i].vertex.y);
}
inline int gpc_not_fmax(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y > v[i].vertex.y);
}
inline int gpc_rev_min(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y >= v[i].vertex.y &&
v[(i - 1 + n) % n].vertex.y > v[i].vertex.y);
}
inline int gpc_not_rmax(edge_node *v, int i, int n) {
return (v[(i - 1 + n) % n].vertex.y > v[i].vertex.y);
}
// inline void gpc_p_edge(edge_node *d, edge_node *e, int p, double i, double j)
// {
inline void gpc_p_edge(edge_node *d, edge_node *e, int p) {
d = e;
do {
d = d->prev;
} while (!d->outp[p]);
// i = d->bot.x + d->dx * (j - d->bot.y);
}
// inline void gpc_n_edge(edge_node *d, edge_node *e, int p, double i, double j)
// {
inline void gpc_n_edge(edge_node *d, edge_node *e, int p) {
d = e;
do {
d = d->next;
} while (!d->outp[p]);
// i = d->bot.x + d->dx * (j - d->bot.y);
}
template <typename T>
void gpc_malloc(T *&p, int b, char *s) { // NOLINT
if (b > 0) {
p = reinterpret_cast<T *>(malloc(b));
if (!p) {
fprintf(stderr, "gpc malloc failure: %s\n", s);
exit(0);
}
} else {
p = NULL;
}
}
template <typename T>
void gpc_free(T *&p) { // NOLINT
if (p) {
free(p);
p = NULL;
}
}
/*
===========================================================================
Public Function Prototypes
===========================================================================
*/
void add_vertex(vertex_node **t, double x, double y);
void gpc_vertex_create(edge_node *e, int p, int s, double x, double y);
void gpc_add_contour(gpc_polygon *polygon, gpc_vertex_list *contour, int hole);
void gpc_polygon_clip(gpc_op set_operation, gpc_polygon *subject_polygon,
gpc_polygon *clip_polygon, gpc_polygon *result_polygon);
void gpc_tristrip_clip(gpc_op set_operation, gpc_polygon *subject_polygon,
gpc_polygon *clip_polygon,
gpc_tristrip *result_tristrip);
void gpc_polygon_to_tristrip(gpc_polygon *polygon, gpc_tristrip *tristrip);
void gpc_free_polygon(gpc_polygon *polygon);
void gpc_free_tristrip(gpc_tristrip *tristrip);
} // namespace gpc
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#include "operators/math/poly_util.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <class T>
void Array2PointVec(const T* box, const size_t box_size,
std::vector<Point_<T>>* vec) {
size_t pts_num = box_size / 2;
vec->resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
vec->at(i).x = box[2 * i];
vec->at(i).y = box[2 * i + 1];
}
}
template <class T>
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly) {
size_t pts_num = box_size / 2;
poly->num_contours = 1;
poly->hole = reinterpret_cast<int*>(malloc(sizeof(int)));
poly->hole[0] = 0;
poly->contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly->contour->num_vertices = pts_num;
poly->contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
poly->contour->vertex[i].x = box[2 * i];
poly->contour->vertex[i].y = box[2 * i + 1];
}
}
template void Array2Poly(const float* box, const size_t box_size,
gpc::gpc_polygon* poly);
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>* vec) {
int pts_num = contour.num_vertices;
vec->resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
vec->at(i).x = contour.vertex[i].x;
vec->at(i).y = contour.vertex[i].y;
}
}
template <class T>
T GetContourArea(const std::vector<Point_<T>>& vec) {
int pts_num = vec.size();
if (pts_num < 3) return T(0.);
T area = T(0.);
for (size_t i = 0; i < pts_num; ++i) {
area += vec[i].x * vec[(i + 1) % pts_num].y -
vec[i].y * vec[(i + 1) % pts_num].x;
}
return fabs(area / 2.0);
}
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
std::vector<Point_<T>> vec;
Array2PointVec<T>(box, box_size, &vec);
return GetContourArea<T>(vec);
}
template float PolyArea(const float* box, const size_t box_size,
const bool normalized);
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
const bool normalized) {
gpc::gpc_polygon poly1;
gpc::gpc_polygon poly2;
Array2Poly<T>(box1, box_size, &poly1);
Array2Poly<T>(box2, box_size, &poly2);
gpc::gpc_polygon respoly;
gpc::gpc_op op = gpc::GPC_INT;
gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly);
T inter_area = T(0.);
int contour_num = respoly.num_contours;
for (int i = 0; i < contour_num; ++i) {
std::vector<Point_<T>> resvec;
Poly2PointVec<T>(respoly.contour[i], &resvec);
inter_area += GetContourArea<T>(resvec);
}
gpc::gpc_free_polygon(&poly1);
gpc::gpc_free_polygon(&poly2);
gpc::gpc_free_polygon(&respoly);
return inter_area;
}
template float PolyOverlapArea(const float* box1, const float* box2,
const size_t box_size, const bool normalized);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#pragma once
#include <vector>
#include "operators/math/gpc.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <class T>
class Point_ {
public:
// default constructor
Point_() {}
Point_(T _x, T _y) {}
Point_(const Point_& pt) {}
Point_& operator=(const Point_& pt);
// conversion to another data type
// template<typename _T> operator Point_<_T>() const;
// conversion to the old-style C structures
// operator Vec<T, 2>() const;
// checks whether the point is inside the specified rectangle
// bool inside(const Rect_<T>& r) const;
T x; //!< x coordinate of the point
T y; //!< y coordinate of the point
};
template <class T>
void Array2PointVec(const T* box, const size_t box_size,
std::vector<Point_<T>>* vec);
template <class T>
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly);
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>* vec);
template <class T>
T GetContourArea(const std::vector<Point_<T>>& vec);
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized);
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
const bool normalized);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -25,8 +25,8 @@ void MultiClassNMSOp<Dtype, T>::InferShape() const { ...@@ -25,8 +25,8 @@ void MultiClassNMSOp<Dtype, T>::InferShape() const {
if (input_scores_dims.size() != 3) { if (input_scores_dims.size() != 3) {
LOG(kLOG_ERROR) << "Input Scores size must be 3"; LOG(kLOG_ERROR) << "Input Scores size must be 3";
} }
if (input_bboxes_dims[2] != 4) { if (input_bboxes_dims[2] % 4 != 0 || input_bboxes_dims[2] < 4) {
LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be 4"; LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be multiples of 4";
} }
if (input_bboxes_dims[1] != input_scores_dims[2]) { if (input_bboxes_dims[1] != input_scores_dims[2]) {
LOG(kLOG_ERROR) << "Predict bboxes must be equal"; LOG(kLOG_ERROR) << "Predict bboxes must be equal";
......
...@@ -127,18 +127,25 @@ int main() { ...@@ -127,18 +127,25 @@ int main() {
DLOG << "----------**********----------"; DLOG << "----------**********----------";
DLOG << "begin to run MulticlassNMS Test"; DLOG << "begin to run MulticlassNMS Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader; paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); auto program = loader.Load(std::string(g_mobilenet_ssd));
/// input x (1,3,300,300)
paddle_mobile::framework::Tensor inputx1; paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {10, 1917, 4}, static_cast<float>(0), SetupTensor<float>(&inputx1, {1, 2, 4}, static_cast<float>(0),
static_cast<float>(1)); static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>(); auto *inputx1_ptr = inputx1.data<float>();
const float x1[] = {0, 0, 100, 100, 50, 50, 150, 150};
for (int i = 0; i < 8; ++i) {
*(inputx1_ptr + i) = x1[i];
}
paddle_mobile::framework::Tensor inputx2; paddle_mobile::framework::Tensor inputx2;
SetupTensor<float>(&inputx2, {10, 21, 1917}, static_cast<float>(0), SetupTensor<float>(&inputx2, {1, 2, 2}, static_cast<float>(0),
static_cast<float>(1)); static_cast<float>(1));
auto *inputx2_ptr = inputx2.data<float>(); auto *inputx2_ptr = inputx2.data<float>();
const float x2[] = {0.4, 0.3, 0.6, 0.7};
for (int i = 0; i < 4; ++i) {
*(inputx2_ptr + i) = x2[i];
}
paddle_mobile::framework::TestMultiClassNMSOp<paddle_mobile::CPU> paddle_mobile::framework::TestMultiClassNMSOp<paddle_mobile::CPU>
testMultiClassNMSOp(program); testMultiClassNMSOp(program);
...@@ -146,8 +153,26 @@ int main() { ...@@ -146,8 +153,26 @@ int main() {
auto output = testMultiClassNMSOp.predict(inputx1, inputx2); auto output = testMultiClassNMSOp.predict(inputx1, inputx2);
auto *output_ptr = output->data<float>(); auto *output_ptr = output->data<float>();
for (int i = 0; i < output->numel(); i++) { for (int i = 0; i < output->numel(); ++i) {
DLOG << output_ptr[i]; DLOG << output_ptr[i];
} }
// test multi point
paddle_mobile::framework::Tensor inputx3;
SetupTensor<float>(&inputx3, {1, 2, 8}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx3_ptr = inputx3.data<float>();
const float x3[] = {0, 0, 100, 0, 100, 100, 0, 100,
50, 50, 150, 50, 150, 150, 50, 150};
for (int i = 0; i < 16; ++i) {
*(inputx3_ptr + i) = x3[i];
}
auto output2 = testMultiClassNMSOp.predict(inputx3, inputx2);
auto *output_ptr2 = output2->data<float>();
for (int i = 0; i < output2->numel(); ++i) {
DLOG << output_ptr2[i];
}
return 0; return 0;
} }
...@@ -33,6 +33,7 @@ if (CON GREATER -1) ...@@ -33,6 +33,7 @@ if (CON GREATER -1)
set(POOL_OP ON) set(POOL_OP ON)
set(RESHAPE_OP ON) set(RESHAPE_OP ON)
set(FUSION_CONVADDBNRELU_OP ON) set(FUSION_CONVADDBNRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
set(FUSION_CONVADD_OP ON) set(FUSION_CONVADD_OP ON)
set(FOUND_MATCH ON) set(FOUND_MATCH ON)
......
...@@ -45,13 +45,13 @@ def combine_bgrs_nchw(bgrs, means_b_g_r, scale, channel_type=ChannelType.BGR): ...@@ -45,13 +45,13 @@ def combine_bgrs_nchw(bgrs, means_b_g_r, scale, channel_type=ChannelType.BGR):
print '------------------' print '------------------'
print bgrs_float_array[0] print bgrs_float_array[0]
print bgrs_float_array[416 * 416 * 2 + 416 * 2 + 2] print bgrs_float_array[224 * 224 * 2 + 224 * 2 + 2]
# for i in range(0, 9): # for i in range(0, 9):
# print'bs %d' % i # print'bs %d' % i
# print bs[i] / 255. # print bs[i] / 255.
print bs[416 * 2 + 2] / 255. print bs[224 * 2 + 2] / 255.
print '--------------combine_bgrs_nchw-----------------end' print '--------------combine_bgrs_nchw-----------------end'
return bgrs_float_array return bgrs_float_array
...@@ -64,6 +64,6 @@ def combine_bgrs_nchw(bgrs, means_b_g_r, scale, channel_type=ChannelType.BGR): ...@@ -64,6 +64,6 @@ def combine_bgrs_nchw(bgrs, means_b_g_r, scale, channel_type=ChannelType.BGR):
# cv2.waitKey(0) # cv2.waitKey(0)
bgrs = tools.resize_take_rgbs('datas/newyolo.jpg', (416, 416, 3)) bgrs = tools.resize_take_rgbs('datas/jpgs/0000_0.9834-148196_82452-0ad4b83ec6bc0f9c5f28101539267054.jpg_p0_0.126571263346.jpg', (224, 224, 3))
array = combine_bgrs_nchw(bgrs, (0, 0, 0), 1. / 255, ChannelType.RGB) array = combine_bgrs_nchw(bgrs, (0, 0, 0), 1. / 255, ChannelType.RGB)
tools.save_to_file('datas/desktop_1_3_416_416_nchw_float', array) tools.save_to_file('datas/desktop_1_3_224_224_nchw_float', array)
...@@ -15,11 +15,11 @@ from array import array ...@@ -15,11 +15,11 @@ from array import array
# image.resize(shape_h_w) # image.resize(shape_h_w)
data = np.fromfile('datas/img.res') data = np.fromfile('/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/imagetools/datas/jpgs2/0000_0.9834-148196_82452-0ad4b83ec6bc0f9c5f28101539267054.jpg_p0_0.126571263346.jpg.input.npfile','f')
print data.size print data.size
print data[0] print data
data.reshape(1, 3, 416, 416) data.reshape(1, 3, 224, 224)
out_array = array('f') out_array = array('f')
print'--------------------' print'--------------------'
print data.size print data.size
...@@ -27,12 +27,12 @@ print data[0] ...@@ -27,12 +27,12 @@ print data[0]
print '如果是nhwc --------' print '如果是nhwc --------'
# rgb rgb rgb rgb rgb # rgb rgb rgb rgb rgb
print data[416 * 3 * 2 + 3 * 2 + 2] print data[224 * 3 * 2 + 3 * 2 + 2]
# print data[2] # print data[2]
print '如果是nchw --------' print '如果是nchw --------'
# rgb rgb rgb rgb rgb # rgb rgb rgb rgb rgb
print data[416 * 416 * 2 + 416 * 2 + 2] print data[224 * 224 * 2 + 224 * 2 + 2]
# print data[2] # print data[2]
# 明明是nchw # 明明是nchw
...@@ -42,6 +42,8 @@ for i in range(0, data.size): ...@@ -42,6 +42,8 @@ for i in range(0, data.size):
print len(out_array) print len(out_array)
print out_array[416 * 416 * 2 + 416 * 2 + 2] print out_array[224 * 224 * 2 + 224 * 2 + 2]
# print out_array
tools.save_to_file('datas/in_put_1_3_416_416_2', out_array) tools.save_to_file('datas/in_put_1_3_224_224_nchw', out_array)
...@@ -77,6 +77,14 @@ fusion_conv_add_attrs_dict = { ...@@ -77,6 +77,14 @@ fusion_conv_add_attrs_dict = {
'strides': 'stride', 'strides': 'stride',
'groups': 'group' 'groups': 'group'
} }
# fluid attr key --- mdl params key
pool2d_attrs_dict = {
'global_pooling': 'global_pooling',
'pooling_type': 'type'
}
# fluid attr key --- mdl params key # fluid attr key --- mdl params key
fluid_attrs_type_dict = { fluid_attrs_type_dict = {
'paddings': 0, 'paddings': 0,
......
# coding=utf-8
import json import json
import os import os
...@@ -12,13 +13,25 @@ def load_mdl(mdl_json_path): ...@@ -12,13 +13,25 @@ def load_mdl(mdl_json_path):
return json.load(f) return json.load(f)
def create_if_not_exit(target_dir):
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir, 0777)
class Converter: class Converter:
'convert mdlmodel to fluidmodel' 'convert mdlmodel to fluidmodel'
def __init__(self, base_dir, mdl_json_path): def __init__(self, base_dir, mdl_json_path):
print 'base_dir: ' + base_dir
self.mdl_json_path = base_dir + mdl_json_path self.mdl_json_path = base_dir + mdl_json_path
self.base_dir = base_dir self.base_dir = base_dir
print mdl_json_path print mdl_json_path
self.source_weights_dir = self.base_dir + 'datas/sourcemodels/source_weights/'
self.target_weight_dir = self.base_dir + 'datas/target/target_weights/'
create_if_not_exit(self.target_weight_dir)
self.mdl_json = load_mdl(self.mdl_json_path) self.mdl_json = load_mdl(self.mdl_json_path)
self.program_desc = framework_pb2.ProgramDesc() self.program_desc = framework_pb2.ProgramDesc()
self.weight_list_ = [] self.weight_list_ = []
...@@ -41,16 +54,18 @@ class Converter: ...@@ -41,16 +54,18 @@ class Converter:
print 'convert end.....' print 'convert end.....'
desc_serialize_to_string = self.program_desc.SerializeToString() desc_serialize_to_string = self.program_desc.SerializeToString()
outputmodel_ = self.base_dir + 'datas/target/outputmodel/' outputmodel_dir = self.base_dir + 'datas/target/mobilenet_classfication/'
if os.path.exists(outputmodel_): if os.path.exists(outputmodel_dir):
shutil.rmtree(outputmodel_) shutil.rmtree(outputmodel_dir)
os.makedirs(outputmodel_, 0777) os.makedirs(outputmodel_dir, 0777)
# todo copy weight files
# if os.path.exists(outputmodel_):
# shutil.rmtree(outputmodel_)
# shutil.copytree('yolo/datas/multiobjects/float32s_nchw_with_head/', 'mobilenet/datas/target/outputmodel/')
f = open(outputmodel_ + "__model__", "wb") if os.path.exists(outputmodel_dir):
shutil.rmtree(outputmodel_dir)
# create_if_not_exit(outputmodel_dir)
shutil.copytree(self.target_weight_dir, outputmodel_dir)
f = open(outputmodel_dir + "__model__", "wb")
f.write(desc_serialize_to_string) f.write(desc_serialize_to_string)
f.close() f.close()
...@@ -63,26 +78,30 @@ class Converter: ...@@ -63,26 +78,30 @@ class Converter:
layers_ = self.mdl_json['layer'] layers_ = self.mdl_json['layer']
for layer in layers_: for layer in layers_:
desc_ops_add = block_desc.ops.add()
# print layer if layer['type'] == 'SoftmaxLayer':
# for i in layer: pass
# print i else:
if 'name' in layer: desc_ops_add = block_desc.ops.add()
l_name = layer['name']
if 'type' in layer: # print layer
self.package_ops_type(desc_ops_add, layer) # for i in layer:
# print i
if 'name' in layer:
l_name = layer['name']
if 'type' in layer:
self.package_ops_type(desc_ops_add, layer)
if 'weight' in layer: if 'weight' in layer:
self.package_ops_weight2inputs(desc_ops_add, layer) self.package_ops_weight2inputs(desc_ops_add, layer)
if 'output' in layer: if 'output' in layer:
self.package_ops_outputs(desc_ops_add, layer) self.package_ops_outputs(desc_ops_add, layer)
if 'input' in layer: if 'input' in layer:
self.package_ops_inputs(desc_ops_add, layer) self.package_ops_inputs(desc_ops_add, layer)
self.package_ops_attrs(desc_ops_add, layer) self.package_ops_attrs(desc_ops_add, layer)
self.add_op_fetch(block_desc) self.add_op_fetch(block_desc)
...@@ -105,7 +124,8 @@ class Converter: ...@@ -105,7 +124,8 @@ class Converter:
desc_ops_add = block_desc.ops.add() desc_ops_add = block_desc.ops.add()
inputs_add = desc_ops_add.inputs.add() inputs_add = desc_ops_add.inputs.add()
inputs_add.parameter = 'X' inputs_add.parameter = 'X'
inputs_add.arguments.append('conv_pred_87') # todo pick last layer --> op output
inputs_add.arguments.append('fc7')
desc_ops_add.type = 'fetch' desc_ops_add.type = 'fetch'
outputs_add = desc_ops_add.outputs.add() outputs_add = desc_ops_add.outputs.add()
outputs_add.parameter = 'Out' outputs_add.parameter = 'Out'
...@@ -129,6 +149,128 @@ class Converter: ...@@ -129,6 +149,128 @@ class Converter:
# boolean # boolean
attrs_add.type = 6 attrs_add.type = 6
attrs_add.b = 0 attrs_add.b = 0
elif desc_ops_add.type == types.op_fluid_pooling:
Converter.pack_pooling_attr(desc_ops_add, layer)
pass
elif desc_ops_add.type == types.op_fluid_softmax:
pass
@staticmethod
def pack_pooling_attr(desc_ops_add, layer):
print layer
l_params = layer['param']
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'use_mkldnn'
# boolean
attrs_add.type = 6
attrs_add.b = 0
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'use_cudnn'
# boolean
attrs_add.type = 6
attrs_add.b = 1
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'paddings'
# ints
attrs_add.type = 3
attrs_add.ints.append(0)
attrs_add.ints.append(0)
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'strides'
# ints
attrs_add.type = 3
attrs_add.ints.append(1)
attrs_add.ints.append(1)
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'global_pooling'
# boolean
attrs_add.type = 6
attrs_add.b = (l_params[types.pool2d_attrs_dict.get('global_pooling')])
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'pooling_type'
# 2-->STRING
attrs_add.type = 2
# 注意这里 avg but mdl is ave
attrs_add.s = l_params[types.pool2d_attrs_dict.get('pooling_type')]
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'ceil_mode'
# boolean
attrs_add.type = 6
attrs_add.b = 1
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'ksize'
# ints
attrs_add.type = 3
attrs_add.ints.append(7)
attrs_add.ints.append(7)
# type: "pool2d"
# attrs
# {
# name: "use_mkldnn"
# type: BOOLEAN
# b: false
# }
# attrs
# {
# name: "ceil_mode"
# type: BOOLEAN
# b: true
# }
# attrs
# {
# name: "use_cudnn"
# type: BOOLEAN
# b: true
# }
# attrs
# {
# name: "paddings"
# type: INTS
# ints: 0
# ints: 0
# }
# attrs
# {
# name: "strides"
# type: INTS
# ints: 1
# ints: 1
# }
# attrs
# {
# name: "global_pooling"
# type: BOOLEAN
# b: false
# }
# attrs
# {
# name: "data_format"
# type: STRING
# s: "AnyLayout"
# }
# attrs
# {
# name: "ksize"
# type: INTS
# ints: 7
# ints: 7
# }
# attrs
# {
# name: "pooling_type"
# type: STRING
# s: "avg"
# }
# is_target: false
@staticmethod @staticmethod
def pack_fusion_conv_add_attr(desc_ops_add, layer): def pack_fusion_conv_add_attr(desc_ops_add, layer):
...@@ -181,6 +323,13 @@ class Converter: ...@@ -181,6 +323,13 @@ class Converter:
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('paddings')]) attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('paddings')])
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('paddings')]) attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('paddings')])
# attrs_add = desc_ops_add.attrs.add()
# attrs_add.name = 'paddings'
# # ints
# attrs_add.type = 3
# attrs_add.ints.append(0)
# attrs_add.ints.append(0)
attrs_add = desc_ops_add.attrs.add() attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'strides' attrs_add.name = 'strides'
# ints # ints
...@@ -188,6 +337,13 @@ class Converter: ...@@ -188,6 +337,13 @@ class Converter:
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('strides')]) attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('strides')])
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('strides')]) attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('strides')])
# attrs_add = desc_ops_add.attrs.add()
# attrs_add.name = 'strides'
# # ints
# attrs_add.type = 3
# attrs_add.ints.append(6)
# attrs_add.ints.append(6)
attrs_add = desc_ops_add.attrs.add() attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'groups' attrs_add.name = 'groups'
# int # int
...@@ -232,8 +388,8 @@ class Converter: ...@@ -232,8 +388,8 @@ class Converter:
# print o # print o
outputs_add = desc_ops_add.outputs.add() outputs_add = desc_ops_add.outputs.add()
dict = types.op_io_dict.get(desc_ops_add.type) dict = types.op_io_dict.get(desc_ops_add.type)
print 'desc_ops_add.type: ' + desc_ops_add.type # print 'desc_ops_add.type: ' + desc_ops_add.type
print dict # print dict
outputs_add.parameter = dict.get(types.mdl_outputs_key) outputs_add.parameter = dict.get(types.mdl_outputs_key)
outputs_add.arguments.append(o) outputs_add.arguments.append(o)
...@@ -305,7 +461,7 @@ class Converter: ...@@ -305,7 +461,7 @@ class Converter:
# issues in mdl model filter swich n and c # issues in mdl model filter swich n and c
if j in self.deepwise_weight_list_ and len(dims_of_matrix) == 4: if j in self.deepwise_weight_list_ and len(dims_of_matrix) == 4:
print j print "deep wise issue fit: " + j
tensor.dims.append(dims_of_matrix[1]) tensor.dims.append(dims_of_matrix[1])
tensor.dims.append(dims_of_matrix[0]) tensor.dims.append(dims_of_matrix[0])
tensor.dims.append(dims_of_matrix[2]) tensor.dims.append(dims_of_matrix[2])
...@@ -320,6 +476,12 @@ class Converter: ...@@ -320,6 +476,12 @@ class Converter:
vars_add.persistable = 1 vars_add.persistable = 1
dims_size = len(dims_of_matrix) dims_size = len(dims_of_matrix)
# print dims_size # print dims_size
# print 'weight name : ' + j
Swichter().copy_add_head(
self.source_weights_dir + j + '.bin',
self.target_weight_dir + j
)
# if dims_size == 4: # if dims_size == 4:
# # convert weight from nhwc to nchw # # convert weight from nhwc to nchw
# Swichter().nhwc2nchw_one_slice_add_head( # Swichter().nhwc2nchw_one_slice_add_head(
...@@ -341,7 +503,7 @@ class Converter: ...@@ -341,7 +503,7 @@ class Converter:
vars_add.persistable = 0 vars_add.persistable = 0
mdl_path = "datas/sourcemodels/cls231_0802/mobileNetModel.json" mdl_path = "datas/sourcemodels/source_profile/mobileNetModel.json"
base_dir = "/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/modeltools/mobilenet/" base_dir = "/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/modeltools/mobilenet/"
converter = Converter(base_dir, mdl_path) converter = Converter(base_dir, mdl_path)
converter.convert() converter.convert()
import os
import shutil
from array import array from array import array
...@@ -58,7 +60,7 @@ class Swichter: ...@@ -58,7 +60,7 @@ class Swichter:
to_file = open(to_file_name, "wb") to_file = open(to_file_name, "wb")
tmp = tmp_file.read() tmp = tmp_file.read()
head = self.read_head('yolo/datas/yolo/conv1_biases') head = self.read_head('yolo/datas/yolo/head')
to_file.write(head) to_file.write(head)
to_file.write(tmp) to_file.write(tmp)
tmp_file.close() tmp_file.close()
...@@ -72,12 +74,14 @@ class Swichter: ...@@ -72,12 +74,14 @@ class Swichter:
# print read # print read
return read return read
def copy_add_head(self, from_file_name, to_file_name, tmp_file_name): def copy_add_head(self, from_file_name, to_file_name):
from_file = open(from_file_name, "rb") from_file = open(from_file_name, "rb")
to_file = open(to_file_name, "wb") to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb") # tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases') head = self.read_head(
'/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/modeltools/mobilenet/datas/sourcemodels/head/head')
to_file.write(head) to_file.write(head)
to_file.write(from_file.read()) to_file.write(from_file.read())
from_file.close() from_file.close()
...@@ -96,7 +100,7 @@ class Swichter: ...@@ -96,7 +100,7 @@ class Swichter:
to_file = open(to_file_name, "wb") to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb") # tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases') head = self.read_head('yolo/datas/yolo/head')
to_file.write(head) to_file.write(head)
to_file.write(read) to_file.write(read)
from_file.close() from_file.close()
...@@ -110,6 +114,6 @@ class Swichter: ...@@ -110,6 +114,6 @@ class Swichter:
# 32, # 32,
# 3, 3, 3) # 3, 3, 3)
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/conv1_biases') # Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/head')
# Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '') # Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '')
...@@ -58,7 +58,7 @@ class Swichter: ...@@ -58,7 +58,7 @@ class Swichter:
to_file = open(to_file_name, "wb") to_file = open(to_file_name, "wb")
tmp = tmp_file.read() tmp = tmp_file.read()
head = self.read_head('yolo/datas/yolo/conv1_biases') head = self.read_head('yolo/datas/yolo/head')
to_file.write(head) to_file.write(head)
to_file.write(tmp) to_file.write(tmp)
tmp_file.close() tmp_file.close()
...@@ -77,7 +77,7 @@ class Swichter: ...@@ -77,7 +77,7 @@ class Swichter:
to_file = open(to_file_name, "wb") to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb") # tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases') head = self.read_head('yolo/datas/yolo/head')
to_file.write(head) to_file.write(head)
to_file.write(from_file.read()) to_file.write(from_file.read())
from_file.close() from_file.close()
...@@ -96,7 +96,7 @@ class Swichter: ...@@ -96,7 +96,7 @@ class Swichter:
to_file = open(to_file_name, "wb") to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb") # tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases') head = self.read_head('yolo/datas/yolo/head')
to_file.write(head) to_file.write(head)
to_file.write(read) to_file.write(read)
from_file.close() from_file.close()
...@@ -110,6 +110,6 @@ class Swichter: ...@@ -110,6 +110,6 @@ class Swichter:
# 32, # 32,
# 3, 3, 3) # 3, 3, 3)
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/conv1_biases') # Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/head')
# Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '') # Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册