提交 488453ad 编写于 作者: Z zhaojiaying01

fix norm_op

上级 d9ad3056
...@@ -37,6 +37,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add"; ...@@ -37,6 +37,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
const char *G_OP_TYPE_LRN = "lrn"; const char *G_OP_TYPE_LRN = "lrn";
const char *G_OP_TYPE_MUL = "mul"; const char *G_OP_TYPE_MUL = "mul";
const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms"; const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms";
const char *G_OP_TYPE_NORM = "norm";
const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform"; const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform";
const char *G_OP_TYPE_POOL2D = "pool2d"; const char *G_OP_TYPE_POOL2D = "pool2d";
const char *G_OP_TYPE_PRIOR_BOX = "prior_box"; const char *G_OP_TYPE_PRIOR_BOX = "prior_box";
...@@ -169,5 +170,6 @@ std::unordered_map< ...@@ -169,5 +170,6 @@ std::unordered_map<
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}}; {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}},
{G_OP_TYPE_NORM, {{"X"}, {"Out", "Norm"}}}};
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -52,10 +52,7 @@ void NormCompute(const NormParam<CPU> &param) { ...@@ -52,10 +52,7 @@ void NormCompute(const NormParam<CPU> &param) {
int pre, n, post; int pre, n, post;
GetDims(x_dims, axis, &pre, &n, &post); GetDims(x_dims, axis, &pre, &n, &post);
square.Resize(input->dims());
framework::DDim shape = {pre, n, post};
framework::DDim norm_shape = {pre, post};
square.Resize(shape);
const float *input_ptr = input->data<float>(); const float *input_ptr = input->data<float>();
float *square_ptr = square.mutable_data<float>(); float *square_ptr = square.mutable_data<float>();
...@@ -106,7 +103,7 @@ void NormCompute(const NormParam<CPU> &param) { ...@@ -106,7 +103,7 @@ void NormCompute(const NormParam<CPU> &param) {
norm_tmp++; norm_tmp++;
out_tmp++; out_tmp++;
} }
out_tmp = out_ptr + i * post; norm_tmp = norm_ptr + i * post;
} }
} }
} }
......
...@@ -42,13 +42,11 @@ namespace ops = paddle_mobile::operators; ...@@ -42,13 +42,11 @@ namespace ops = paddle_mobile::operators;
REGISTER_OPERATOR_CPU(norm, ops::NormOp); REGISTER_OPERATOR_CPU(norm, ops::NormOp);
#endif #endif
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(norm, ops::NormOp);
#endif #endif
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
#endif #endif
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(norm, ops::NormOp);
#endif #endif
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册