diff --git a/src/common/types.cpp b/src/common/types.cpp index c25c5db30c7183b6685db03386ca9a9355ca6958..444789237f573f8da3eaf915abf61493967aabf8 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -37,6 +37,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add"; const char *G_OP_TYPE_LRN = "lrn"; const char *G_OP_TYPE_MUL = "mul"; const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms"; +const char *G_OP_TYPE_NORM = "norm"; const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform"; const char *G_OP_TYPE_POOL2D = "pool2d"; const char *G_OP_TYPE_PRIOR_BOX = "prior_box"; @@ -169,5 +170,6 @@ std::unordered_map< {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}, - {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}}; + {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}, + {G_OP_TYPE_NORM, {{"X"}, {"Out", "Norm"}}}}; } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/norm_arm_func.h b/src/operators/kernel/central-arm-func/norm_arm_func.h index 217a8cac5a435f317a8b16f49c26068bfbee6b20..e43c03484712cfb1f8baf96a9ca8cccc062672ca 100644 --- a/src/operators/kernel/central-arm-func/norm_arm_func.h +++ b/src/operators/kernel/central-arm-func/norm_arm_func.h @@ -52,10 +52,7 @@ void NormCompute(const NormParam ¶m) { int pre, n, post; GetDims(x_dims, axis, &pre, &n, &post); - - framework::DDim shape = {pre, n, post}; - framework::DDim norm_shape = {pre, post}; - square.Resize(shape); + square.Resize(input->dims()); const float *input_ptr = input->data(); float *square_ptr = square.mutable_data(); @@ -106,7 +103,7 @@ void NormCompute(const NormParam ¶m) { norm_tmp++; out_tmp++; } - out_tmp = out_ptr + i * post; + norm_tmp = norm_ptr + i * post; } } } diff --git a/src/operators/norm_op.cpp b/src/operators/norm_op.cpp index 65630ed0d1d1819dc9688bdf3a285e0290ca42e2..deed9f69d1cf40ee70a211b0c9a84e4afeef6623 100644 --- a/src/operators/norm_op.cpp +++ b/src/operators/norm_op.cpp @@ -42,13 +42,11 @@ namespace ops = paddle_mobile::operators; REGISTER_OPERATOR_CPU(norm, ops::NormOp); #endif #ifdef PADDLE_MOBILE_MALI_GPU -REGISTER_OPERATOR_MALI_GPU(norm, ops::NormOp); #endif #ifdef PADDLE_MOBILE_FPGA #endif #ifdef PADDLE_MOBILE_CL -REGISTER_OPERATOR_CL(norm, ops::NormOp); #endif #endif