diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 26b189dbf8b38d812ab1f410dcb10956ce01565c..54527a1c610892eb8cbaae2957e034c3ba294b87 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -350,7 +350,7 @@ PMStatus Executor::Predict() { _tp[ops_list_[i]->Type()] += timeCost; } } - DLOG << "====================[ profile ]======================"; + printf("====================[ profile ]======================\n"); typedef std::pair prof_t; std::vector _tv(_tp.begin(), _tp.end()); uint64_t _ptotal = 0; @@ -367,7 +367,7 @@ PMStatus Executor::Predict() { static_cast(p.second), static_cast(p.second) / _ptotal * 100.0); } - DLOG << "====================[---------]======================"; + printf("====================[---------]======================\n"); #endif return PMSuccess; } diff --git a/src/operators/fill_constant_op.h b/src/operators/fill_constant_op.h index e24cecd363630a845f147e2e429b973dad24f63d..c9af1e6f02ef022975fc4434f38856a7a97a4136 100644 --- a/src/operators/fill_constant_op.h +++ b/src/operators/fill_constant_op.h @@ -25,12 +25,11 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -using std::string; template class FillConstantOp : public framework::OperatorBase { public: - FillConstantOp(const string &type, const VariableNameMap &inputs, + FillConstantOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) @@ -58,7 +57,7 @@ class FillConstantOp : public framework::OperatorBase { tensor->Resize(framework::make_ddim(param_.Shape())); tensor->mutable_data(framework::ToTypeIndex(data_type)); - math::set_constant(tensor, value); + math::SetConstant(tensor, value); } void Init() {} diff --git a/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h index a19c67e68366fc57a305e0dbb955229a763737d9..7660ed46155f34bb08e9526e71969eb0f2cf8254 100644 --- a/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef FUSION_CONVADDADDPRELU_OP - #pragma once + +#include #include #include "operators/math/conv_func.h" #include "operators/math/im2col.h" @@ -115,20 +116,7 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam ¶m) { Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor bias1_slice = bias1_batch.Slice(g * out_step, (g + 1) * out_step); float *biase_data1 = bias1_slice.data(); - // int n = bias1_slice.dims()[0]; - // int m = bias1_slice.dims()[1]; - // for(int i=0;i(filter_slice, false, col_matrix, - // false, - // static_cast(1), - // &out_slice, - // static_cast(1), true, - // biase_data); - math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice, + math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice, p, mode, biase_data, biase_data1); } } @@ -137,4 +125,4 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam ¶m) { } // namespace operators } // namespace paddle_mobile -#endif +#endif // FUSION_CONVADDADDPRELU_OP diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index d65f89ede012ea083e4a73e4647079c248e33fe0..ebc014da4f841bb90ccd6a1582a3d3043141c151 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -107,7 +107,7 @@ void ConvAddBasic(const FusionConvAddParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, + math::MatMul(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, static_cast(1), false, biase_data); } diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h index 5374eab51f315ee8baa4f4effe04fc97240aabff..9a5bbbbd6a0ab3f134ec3b7f2354a9e1d139aa7c 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h @@ -25,6 +25,7 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { + void ConvAddBNReluBasic(const FusionConvAddBNReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); @@ -105,12 +106,13 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam ¶m) { Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmulWithBn( - filter_slice, false, col_matrix, false, static_cast(1), - &out_slice, static_cast(0), true, &new_scale, &new_bias, g); + math::MatMulWithBn(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(0), true, &new_scale, &new_bias, g); } } } + template void ConvAddBNReluCompute(const FusionConvAddBNReluParam ¶m) { Tensor Bias; @@ -126,9 +128,6 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam ¶m) { param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), - // param.Output(), param.NewScale(), - // param.NewBias(), 1); math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); diff --git a/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h index df63379d967606e15106937534bb82496ee83b4e..4a97b2fa81e19f62633e355cdaa768c1d8ddeb52 100644 --- a/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef FUSION_CONVADDPRELU_OP - #pragma once + +#include #include #include "operators/math/conv_func.h" #include "operators/math/im2col.h" @@ -30,8 +31,6 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor bias = *param.Bias(); - // DLOG<<"yangfei"; - // DLOG<(); @@ -112,13 +111,7 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - // math::matmul(filter_slice, false, col_matrix, - // false, - // static_cast(1), - // &out_slice, - // static_cast(1), true, - // biase_data); - math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice, + math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice, p, mode, biase_data, nullptr); } } @@ -127,4 +120,4 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam ¶m) { } // namespace operators } // namespace paddle_mobile -#endif +#endif // FUSION_CONVADDPRELU_OP diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h index 860a3746e2918b05aa0f09f1536589b7dc62899c..9f251b3d7189b36b13ae9ccc27c55b136d8ab511 100644 --- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h @@ -112,7 +112,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, alpha, + math::MatMul(filter_slice, false, col_matrix, false, alpha, &out_slice, beta, true, bias_data); } } diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index d29789293906312d03918ba7ebd03be65f19a722..848f1b113b78d145f730255a33a82006b94e03f0 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -106,7 +106,7 @@ inline void GemmConv(const ConvParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, + math::MatMul(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, static_cast(0), false, static_cast(nullptr)); diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index e3fe37e19bd10ec5cbbfb59b556df5af9fecd09e..65aee43c59fcceeb142c858f2883c80a0d6e004e 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -108,10 +108,10 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam ¶m) { Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor bias_data = bias_batch.Slice(g * out_step, (g + 1) * out_step); - math::matmulWithBn(filter_slice, false, col_matrix, false, - static_cast(1), &out_slice, - static_cast(1), true, &new_scale, - &new_bias, g, bias_data.data()); + math::MatMulWithBn(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(1), true, &new_scale, &new_bias, g, + bias_data.data()); } } } diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h index 4c8cf393345d16e79799bc5ce9ecd1be1fc0a15a..6e8aec99e5f595381efa98e7fb04501c13ddf4de 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -107,9 +107,9 @@ void ConvBNReluBasic(const FusionConvBNReluParam ¶m) { Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmulWithBn( - filter_slice, false, col_matrix, false, static_cast(1), - &out_slice, static_cast(0), true, &new_scale, &new_bias, g); + math::MatMulWithBn(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(0), true, &new_scale, &new_bias, g); } } } diff --git a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h index 300cb8e84b0703951b5305d684eb2f7bb652d669..f7ebe571f9b8f3f032ef8d5902861f1b5b5d3002 100644 --- a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h @@ -93,7 +93,7 @@ void ConvTransposeCompute(const ConvTransposeParam ¶m) { Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, true, in_slice, false, + math::MatMul(filter_slice, true, in_slice, false, static_cast

(1.0), &col_matrix, static_cast

(0.0)); if (data_dim == 2U) { col2im(col, dilations, strides, diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index a5c08c26237345320fef89e8f0fdd148534dfc8a..cef297daad3c83253105ccf2c44d195e01d074ae 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -106,9 +106,9 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmulWithBn( - filter_slice, false, col_matrix, false, static_cast(1), - &out_slice, static_cast(0), true, &new_scale, &new_bias, g); + math::MatMulWithBn(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(0), true, &new_scale, &new_bias, g); } } } diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h index 30eb30ca3cc0ef416a70f657c0c2e6bde5e7e9ba..843e2d5119939f2641b4ccaf3542566465761c32 100644 --- a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h +++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h @@ -57,7 +57,7 @@ void FusionFcCompute(const FusionFcParam ¶m) { for (int i = 0; i < out_dim[0]; i++) { memory::Copy(out_data + i * classes, input_z_data, sizeof(Otype) * classes); } - math::matmul(x_matrix, false, y_matrix, false, + math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(1), false); } diff --git a/src/operators/kernel/central-arm-func/gru_arm_func.h b/src/operators/kernel/central-arm-func/gru_arm_func.h index 2e00e839ff10da0d40612c9f63d5d0f7e059a0fe..897538273232b4379b93dbb34651906e3bc9058c 100644 --- a/src/operators/kernel/central-arm-func/gru_arm_func.h +++ b/src/operators/kernel/central-arm-func/gru_arm_func.h @@ -25,18 +25,16 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -using LoDTensor = framework::LoDTensor; -using Tensor = framework::Tensor; - -template +template inline void ReorderInitState(const framework::Tensor& src, std::vector index_lod, framework::Tensor* dst, bool indexed_src) { - math::CopyMatrixRowsFunctor row_shuffle; + math::CopyMatrixRowsFunctor row_shuffle; dst->mutable_data(src.dims()); row_shuffle(src, index_lod, dst, indexed_src); } -template + +template void GruCompute(const GruParam& param) { auto* input = param.InputInput(); auto* h0 = param.InputH0(); @@ -57,8 +55,6 @@ void GruCompute(const GruParam& param) { bool is_reverse = param.IsReverse(); math::LoDTensor2BatchFunctor to_batch; to_batch(*input, batch_gate, true, is_reverse); - // math::ClearTensor clearTensor; - // clearTensor(batch_gate); if (bias) { math::RowwiseAdd add_bias; add_bias(*batch_gate, *bias, batch_gate); @@ -68,7 +64,7 @@ void GruCompute(const GruParam& param) { gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); - Tensor ordered_h0; + framework::Tensor ordered_h0; std::vector order(batch_gate->lod()[2]); if (h0) { // Since the batch computing for GRU reorders the input sequences @@ -87,9 +83,10 @@ void GruCompute(const GruParam& param) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); int cur_batch_size = bend - bstart; - Tensor gate_t = batch_gate->Slice(bstart, bend); // BUG - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); + framework::Tensor gate_t = batch_gate->Slice(bstart, bend); + framework::Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + framework::Tensor hidden_t = batch_hidden->Slice(bstart, bend); gru_value.output_value = hidden_t.data(); gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); @@ -105,7 +102,6 @@ void GruCompute(const GruParam& param) { } } // namespace operators - } // namespace paddle_mobile -#endif +#endif // GRU_OP diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index 316f78a43f27c17eec1b31741b2b6bc678c41af2..1f22ab98989644430e2484ca8d57fe2c4047e2f7 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -19,40 +19,6 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -// 1、如果x,y维度都是2维, -// x = [[1,2], y = [[5,6], -// [3,4]] [7,8]] -// 运算结果为正常矩阵相乘。结果 out = -// [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]] -// -// 2、如果x的维度大于2或者y的维度大于2,x的维度(2,3,4) ,y的维度(4,1,2) -// x = [[[1,2,3,4], -// [2,3,4,5], -// [3,4,5,6]], -// [[1,2,3,4], -// [2,3,4,5], -// [3,4,5,6]]] -// y = [[[1,2]], -// [[3,4]], -// [[5,6]], -// [[7,8]]] -// 需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维 -// 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭 -// (1) 将x = (2,3,4)的index [0,x_num_col_dims)部分2,3相乘,得到6, -// [x_num_col_dims,xdim.size())部分4相乘,得到4, -// 将Tensor x的dims重写成(6,4) -// (2) 将y = (4,1,2)的index [0,y_num_col_dims)部分4相乘,得到4, -// [y_num_col_dims,ydim.size())部分1,2相乘,得到2, -// 将Tensor y的dims重写成(4,2) -// 并不影响x,y在内存中的分布。 -// x = [[1,2,3,4], y = [[1,2], -// [2,3,4,5], [3,4], -// [3,4,5,6], 矩阵乘法 [5,6], -// [1,2,3,4], [7,8]] -// [2,3,4,5], -// [3,4,5,6]] -// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列) - template void MulCompute(const MulParam ¶m) { const Tensor *input_x = param.InputX(); @@ -73,12 +39,12 @@ void MulCompute(const MulParam ¶m) { } if (param.InputX()->type() == typeid(int8_t)) { out->mutable_data(); - math::matmul(x_matrix, false, y_matrix, false, + math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(0)); } else { out->mutable_data(); - math::matmul(x_matrix, false, y_matrix, false, + math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(0)); } diff --git a/src/operators/kernel/cl/fusion_fc_kernel.cpp b/src/operators/kernel/cl/fusion_fc_kernel.cpp index 7d85becea601878de577b59a5c671b3ea04f9370..34f36b56bc156f898555534374a79b50643f4784 100644 --- a/src/operators/kernel/cl/fusion_fc_kernel.cpp +++ b/src/operators/kernel/cl/fusion_fc_kernel.cpp @@ -94,27 +94,19 @@ void FusionFcCompute(const FusionFcParam ¶m, cl_context context, memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes); } - // for (int i = 0; i < out->numel(); i++) { - // DLOG << out_data[i]; - // } - // bias_data的维度和out的维度一致 - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(1), false); out_image->InitEmptyImage(context, commandQueue, out->dims()); framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1); - DLOG << *out; - delete (input_x); delete (input_y); delete (input_z); delete (out); PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); - // if (out_dim.size() != 2) { - // out->Resize(out_dim); - // } } + template <> void FusionFcKernel::Compute( const FusionFcParam ¶m) { diff --git a/src/operators/kernel/mali/fushion_fc_kernel.cpp b/src/operators/kernel/mali/fushion_fc_kernel.cpp index 5e59215834ce00e902deb19e54e149b3b4cfb8ac..39b36d756734a69320060d99297f8d6f3acaeef9 100755 --- a/src/operators/kernel/mali/fushion_fc_kernel.cpp +++ b/src/operators/kernel/mali/fushion_fc_kernel.cpp @@ -61,7 +61,7 @@ void FusionFcKernel::Compute( for (int i = 0; i < out->numel(); i++) { DLOG << out_data[i]; } - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(1)); PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); // if (out_dim.size() != 2) { diff --git a/src/operators/kernel/mali/mul_kernel.cpp b/src/operators/kernel/mali/mul_kernel.cpp index da69f5e6fe5a4ec95373011d360cd4d9e20a8a61..6148ae702558f2d1dc28e68d733938510db1082b 100644 --- a/src/operators/kernel/mali/mul_kernel.cpp +++ b/src/operators/kernel/mali/mul_kernel.cpp @@ -44,7 +44,7 @@ void MulKernel::Compute(const MulParam ¶m) { if (out_dim.size() != 2) { out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + math::MatMul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(0)); if (out_dim.size() != 2) { out->Resize(out_dim); diff --git a/src/operators/math/math_func_neon.h b/src/operators/math/math_func_neon.h index 5bb3fd0f5ae3f6349ab52535348f6310e4096951..3f9245351d3bce49f852b90a4d14bab7e6a826f5 100644 --- a/src/operators/math/math_func_neon.h +++ b/src/operators/math/math_func_neon.h @@ -38,7 +38,11 @@ limitations under the License. */ * * (this is the zlib license) */ + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + #pragma once + #include #define c_inv_mant_mask ~0x7f800000u @@ -316,11 +320,11 @@ static inline float32x4_t cos_ps(float32x4_t x) { static inline float32x4_t div_ps(float32x4_t a, float32x4_t b) { float32x4_t reciprocal = vrecpeq_f32(b); reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); - // reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); return vmulq_f32(a, reciprocal); } static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { - // pow(x, m) = exp(m * log(x)) return exp_ps(vmulq_f32(b, log_ps(a))); } + +#endif // __ARM_NEON__ diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index b9ce977e0c84b148b27a02624baa05e6ab150672..d672dbc607e940450f5fe5f1ffa6d2093e3715f0 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "operators/math/math_function.h" -#include #include +#include "common/enforce.h" #include "framework/data_type.h" #include "framework/tensor.h" #include "operators/math/gemm.h" @@ -35,13 +35,13 @@ struct TensorSetConstant { float value_; }; -void set_constant(framework::Tensor *tensor, float value) { +void SetConstant(framework::Tensor *tensor, float value) { framework::VisitDataType(framework::ToDataType(tensor->type()), TensorSetConstant(tensor, value)); } template <> -void matmul(const framework::Tensor &matrix_a, bool trans_a, +void MatMul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, float *bias) { @@ -50,7 +50,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, auto dim_out = matrix_out->dims(); PADDLE_MOBILE_ENFORCE( dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); + "The input and output of MatMul be matrix"); int M = dim_out[0]; int N = dim_out[1]; @@ -72,7 +72,6 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, bias); #else @@ -92,19 +91,18 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } } -template <> -void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, - const framework::Tensor &matrix_b, bool trans_b, - float alpha, framework::Tensor *matrix_out, float beta, - bool relu, framework::Tensor *new_scale, - framework::Tensor *new_bias, int group, float *bias) { +void MatMulWithBn(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu, + framework::Tensor *new_scale, framework::Tensor *new_bias, + int group, float *bias) { Gemm gemm; auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); PADDLE_MOBILE_ENFORCE( dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); + "The input and output of MatMul be matrix"); int M = dim_out[0]; int N = dim_out[1]; @@ -122,7 +120,7 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, new_bias->data() + group, bias); #endif } -void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, +void MatMulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, framework::Tensor *matrix_out, float *p, std::string mode, float *bias, float *bias1) { @@ -132,7 +130,7 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, auto dim_out = matrix_out->dims(); PADDLE_MOBILE_ENFORCE( dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); + "The input and output of MatMul be matrix"); int M = dim_out[0]; int N = dim_out[1]; @@ -146,7 +144,6 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, gemm.SgemmWithPRelu(M, N, K, matrix_a.data(), K, matrix_b.data(), N, matrix_out->data(), N, p, mode, bias, bias1); - #endif } diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 3b682eab2acf96ba70de563aba415a19ad4a66b6..ccc1a2b931a0f2133f25adefc2f9466c02c39fb4 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include #include "framework/tensor.h" @@ -22,37 +21,37 @@ namespace paddle_mobile { namespace operators { namespace math { -void set_constant(framework::Tensor *tensor, float value); +void SetConstant(framework::Tensor *tensor, float value); template -void matmul(const framework::Tensor &matrix_a, bool trans_a, +void MatMul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu = false, Otype *bias = nullptr); template -void matmul(const framework::Tensor &matrix_a, bool trans_a, +void MatMul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, Otype *bias, bool addOnRow); -template -void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, +void MatMulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, framework::Tensor *new_scale, framework::Tensor *new_bias, - int group, T *bias = nullptr); + int group, float *bias = nullptr); -void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, +void MatMulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, framework::Tensor *matrix_out, float *p, std::string mode, float *bias, float *bias1); -template + +template struct ClearTensor { void operator()(framework::Tensor *tensor); }; -template +template struct RowwiseAdd { void operator()(const framework::Tensor &input, const framework::Tensor &vec, framework::Tensor *output); diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp index 6b3dd3f00a33cb015891a801801964eca1c5dcf5..0595a808f0540a0fa5134e72845992e04d125873 100644 --- a/src/operators/math/math_function_int8.cpp +++ b/src/operators/math/math_function_int8.cpp @@ -22,7 +22,7 @@ namespace operators { namespace math { template <> -void matmul(const framework::Tensor &matrix_a, bool trans_a, +void MatMul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, int32_t *bias, @@ -32,7 +32,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, auto dim_out = matrix_out->dims(); PADDLE_MOBILE_ENFORCE( dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); + "The input and output of MatMul be matrix"); int32_t M = dim_out[0]; int32_t N = dim_out[1]; @@ -96,11 +96,11 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } template <> -void matmul(const framework::Tensor &matrix_a, bool trans_a, +void MatMul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, int32_t *bias) { - matmul(matrix_a, trans_a, matrix_b, trans_b, alpha, + MatMul(matrix_a, trans_a, matrix_b, trans_b, alpha, matrix_out, beta, relu, bias, false); } diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index 9c23d99e60f6c7f38f372cbe2d221ae3c1a58592..4cee62696c6cb592c2e51b13a3d5f2afc4618b6e 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -15,154 +15,131 @@ limitations under the License. */ #ifdef SOFTMAX_OP #include "operators/math/softmax.h" -#include "common/types.h" -#ifdef __ARM_NEON #include #include +#include +#include "common/types.h" #include "operators/math/math_func_neon.h" -#endif namespace paddle_mobile { namespace operators { namespace math { -using framework::DDim; -using framework::Tensor; -template -class SoftmaxFuntor { -#ifdef __ARM_NEON - void sum(float *input, float *sumptr, int inner_size, int outter_size) { - float32x4_t acc = vdupq_n_f32(0); - float sum_ = 0; - for (int i = 0; i < outter_size; ++i) { - float *input_outer_ptr = input + i * inner_size; - int nn = inner_size >> 2; - int left = inner_size - (nn << 2); - for (; nn > 0; nn--) { - float32x4_t vec_input = vld1q_f32(input_outer_ptr); - acc = vaddq_f32(acc, vec_input); - input_outer_ptr += 4; - } - float32x2_t vsum_ = vadd_f32(vget_high_f32(acc), vget_low_f32(acc)); - sum_ = vget_lane_f32(vsum_, 0) + vget_lane_f32(vsum_, 1); - for (; left > 0; left--) { - sum_ += *input_outer_ptr; - input_outer_ptr++; - } - } - for (int j = 0; j < inner_size * outter_size; ++j) { - sumptr[j] = sum_; - } + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +#ifndef __aarch64__ +inline float32_t vmaxvq_f32(const float32x4_t &r) { + float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); + return vget_lane_f32(vpmax_f32(v, v), 0); +} + +inline float32_t vaddvq_f32(const float32x4_t &r) { + float32x2_t v = vadd_f32(vget_high_f32(r), vget_low_f32(r)); + return vget_lane_f32(vpadd_f32(v, v), 0); +} +#endif // __aarch64__ +#endif // __ARM_NEON__ + +float find_max(const float *input, const int num_classes) { + int remain = num_classes; + float max = -std::numeric_limits::max(); +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int loop = num_classes >> 3; + remain = num_classes & 0x7; + float32x4_t __max = vdupq_n_f32(max); + for (int i = 0; i < loop; ++i, input += 8) { + float32x4_t x0 = vld1q_f32(input); + float32x4_t x1 = vld1q_f32(input + 4); + __max = vmaxq_f32(x0, __max); + __max = vmaxq_f32(x1, __max); + } + max = vmaxvq_f32(__max); +#endif + for (int i = 0; i < remain; ++i) { + max = std::max(max, input[i]); } + return max; +} - void SoftmaxCacl(const Tensor *X, Tensor *Y) { - const float *input = X->data(); - const DDim &dDim = X->dims(); - int axis_index = 1; - if (dDim.size() < 4) { - axis_index = 0; - } - DDim outer_ddim = - paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1); - DDim inner_ddim = - paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size()); - int out_size = paddle_mobile::framework::product(outer_ddim); - int inner_size = paddle_mobile::framework::product(inner_ddim); - auto *max_ptr = new float[inner_size * out_size]; - // max - for (int j = 0; j < out_size; ++j) { - const float *input_outer_ptr = input + j * inner_size; - float *max_outer_ptr = max_ptr + j * inner_size; - float max_ = 0; - for (int i = 0; i < inner_size; ++i) { - const float *input_inner_ptr = input_outer_ptr + i; - max_ = std::max(max_, input_inner_ptr[0]); - } - for (int k = 0; k < inner_size; ++k) { - max_outer_ptr[k] = max_; +template <> +void SoftmaxFuntor::operator()(const framework::Tensor *X, + framework::Tensor *Y) { + const framework::DDim &dims = X->dims(); + int batch_size = dims[0]; + int num_classes = dims[dims.size() - 1]; + int channels = X->numel() / batch_size / num_classes; + const float *x = X->data(); + float *y = Y->mutable_data(); + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < X->dims()[0]; ++batch) { + for (int channel = 0; channel < channels; ++channel) { + size_t offset = (batch * channels + channel) * num_classes; + const float *input = x + offset; + float *output = y + offset; + // find max + float max = find_max(input, num_classes); + + // exp(x - max) + int remain = num_classes; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int loop = num_classes >> 3; + remain = num_classes & 0x7; + float32x4_t __max = vdupq_n_f32(max); + for (int i = 0; i < loop; ++i, input += 8, output += 8) { + float32x4_t x0 = vld1q_f32(input); + float32x4_t x1 = vld1q_f32(input + 4); + x0 = vsubq_f32(x0, __max); + x1 = vsubq_f32(x1, __max); + x0 = exp_ps(x0); + x1 = exp_ps(x1); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); } - } - // exp(value - max) - float *exp_sub_max = new float[inner_size * out_size]; - float *exp_sub_max_ptr = &exp_sub_max[0]; - for (int l = 0; l < out_size; ++l) { - const float *input_outer_ptr = input + l * inner_size; - float *max_outer_ptr = max_ptr + l * inner_size; - int nn = inner_size >> 2; - int left = inner_size - (nn << 2); - for (; nn > 0; nn--) { - float32x4_t vec_input = vld1q_f32(input_outer_ptr); - float32x4_t vec_max = vld1q_f32(max_outer_ptr); - float32x4_t vec_sub = vsubq_f32(vec_input, vec_max); - float32x4_t vec_exp = exp_ps(vec_sub); - vst1q_f32(exp_sub_max_ptr, vec_exp); - input_outer_ptr += 4; - max_outer_ptr += 4; - exp_sub_max_ptr += 4; +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + output[i] = std::expf(input[i] - max); } - for (; left > 0; left--) { - *exp_sub_max_ptr = expf(*input_outer_ptr - *max_outer_ptr); - input_outer_ptr++; - max_outer_ptr++; - exp_sub_max_ptr++; + // sum(exp(x - max)) + float sum = 0.f; + output = y + offset; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __sum = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + __sum = vaddq_f32(x0, __sum); + __sum = vaddq_f32(x1, __sum); } - } - float *sumptr = new float[inner_size * out_size]; - // sum exp - sum(exp_sub_max, sumptr, inner_size, out_size); - // div - auto *out_ptr = Y->mutable_data(); - for (int l = 0; l < out_size; ++l) { - const float *input_outer_ptr = exp_sub_max + l * inner_size; - float *output_outer_ptr = out_ptr + l * inner_size; - float *sum_outer_ptr = sumptr + l * inner_size; - int nn = inner_size >> 2; - int left = inner_size - (nn << 2); - for (; nn > 0; nn--) { - float32x4_t vec_input = vld1q_f32(input_outer_ptr); - float32x4_t vec_sum = vld1q_f32(sum_outer_ptr); - float32x4_t vec_div = div_ps(vec_input, vec_sum); - vst1q_f32(output_outer_ptr, vec_div); - input_outer_ptr += 4; - output_outer_ptr += 4; - sum_outer_ptr += 4; + sum += vaddvq_f32(__sum); +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + sum += output[i]; } - for (; left > 0; left--) { - *output_outer_ptr = (*input_outer_ptr) / (*sum_outer_ptr); - input_outer_ptr++; - output_outer_ptr++; - sum_outer_ptr++; + + // exp(x - max) / sum + float inv_sum = 1.f / sum; + output = y + offset; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __inv_sum = vdupq_n_f32(inv_sum); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + x0 = vmulq_f32(x0, __inv_sum); + x1 = vmulq_f32(x1, __inv_sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x0); } - } - } -#else -#endif // ARM_NEON - - public: - void operator()(const framework::Tensor *X, framework::Tensor *Y) { - const DDim dDim = X->dims(); - int dim1 = dDim[dDim.size() - 1]; - int dim0 = X->numel() / dim1 / dDim[0]; - framework::DDim matrix_shape = {dim0, dim1}; - for (int i = 0; i < dDim[0]; ++i) { - framework::Tensor sub_X = X->Slice(i, i + 1); - framework::Tensor sub_Y = Y->Slice(i, i + 1); - sub_X.Resize(matrix_shape); - sub_Y.Resize(matrix_shape); - for (int j = 0; j < dim0; j++) { - framework::Tensor sub_x = sub_X.Slice(j, j + 1); - framework::Tensor sub_y = sub_Y.Slice(j, j + 1); -#ifdef __ARM_NEON - SoftmaxCacl(&sub_x, &sub_y); #endif + for (int i = 0; i < remain; ++i) { + output[i] *= inv_sum; } } } -}; - -template class SoftmaxFuntor; +} } // namespace math } // namespace operators } // namespace paddle_mobile -#endif + +#endif // SOFTMAX_OP diff --git a/src/operators/math/softmax.h b/src/operators/math/softmax.h index e2ca8f30b067e9262a0e87f4ba5807df07949e73..0de30a4ecaa2d58a4180203b7a27b23dc35446b5 100644 --- a/src/operators/math/softmax.h +++ b/src/operators/math/softmax.h @@ -13,17 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef SOFTMAX_OP + #pragma once + #include "framework/tensor.h" + namespace paddle_mobile { namespace operators { namespace math { -template +template class SoftmaxFuntor { public: void operator()(const framework::Tensor *X, framework::Tensor *Y); }; + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7c9be64f4f76a7dd7c6722eb0dd9cb4a93be6c07..1fb5fcf2df3bd26c094e6d79b37e7ba87d0e7475 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -261,20 +261,17 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp) target_link_libraries(test-inference-api paddle-mobile) - - # gen test log # gen test ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp) target_link_libraries(test-optimize paddle-mobile) - #gen test ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-pool-op paddle-mobile) #gen test - ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-softmax paddle-mobile) + ADD_EXECUTABLE(test-softmax-op operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-softmax-op paddle-mobile) # gen test ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp) diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 17f9b77b7039cbe5eb26645ffa8dc97a164bc808..c88a65625dc03d6e0e1e6a2575a2645e64ab1605 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -73,14 +73,14 @@ int main() { // float // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), false, nullptr); } auto time_start0 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), false, nullptr); } @@ -91,14 +91,14 @@ int main() { // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, static_cast(0)); } auto time_start1 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, static_cast(0)); } @@ -109,13 +109,13 @@ int main() { // int8_t with bias, column element wise add // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_col, false); } auto time_start2 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_col, false); } @@ -126,13 +126,13 @@ int main() { // int8_t with bias, row element wise add // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_row, true); } auto time_start3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_row, true); } @@ -143,13 +143,13 @@ int main() { // int8_t with bias&relu // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), true, bias_data_col, false); } auto time_start4 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::MatMul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), true, bias_data_col, false); } diff --git a/test/operators/test_softmax_op.cpp b/test/operators/test_softmax_op.cpp index f31bcb4e455a6b9699cf96271310681e51d4c6a7..d65cf4fea27343343d6c2a2a720a0e0ec7d45076 100644 --- a/test/operators/test_softmax_op.cpp +++ b/test/operators/test_softmax_op.cpp @@ -12,29 +12,88 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "../test_include.h" - #include "operators/softmax_op.h" -int main() { - paddle_mobile::framework::Loader loader; - auto program = loader.Load(std::string(g_mobilenet)); - if (program.originProgram == nullptr) { - DLOG << "program read file"; +namespace paddle_mobile { + +void Softmax(const framework::Tensor *X, framework::Tensor *Y) { + const framework::DDim &dims = X->dims(); + int batch_size = dims[0]; + int num_classes = dims[dims.size() - 1]; + int channels = X->numel() / batch_size / num_classes; + const float *x = X->data(); + float *y = Y->mutable_data(); + + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + size_t offset = (batch * channels + c) * num_classes; + const float *input = x + offset; + float *output = y + offset; + float max = -std::numeric_limits::max(); + for (int j = 0; j < num_classes; ++j) { + max = (input[j] > max) ? input[j] : max; + } + float sum = 0.f; + for (int j = 0; j < num_classes; ++j) { + float tmp = std::expf(input[j] - max); + sum += tmp; + output[j] = tmp; + } + for (int j = 0; j < num_classes; ++j) { + output[j] /= sum; + } + } } - Executor4Test> - executor(program, "softmax"); - paddle_mobile::framework::Tensor input; - SetupTensor(&input, {1, 1000}, static_cast(0), - static_cast(1)); - auto out_ddim = paddle_mobile::framework::make_ddim({1, 1000}); - auto output = - executor.Predict(input, "reshape_0.tmp_0", "softmax_0.tmp_0", out_ddim); - auto *output_ptr = output->data(); - for (int j = 0; j < output->numel(); ++j) { - DLOG << " value of output: " << output_ptr[j]; +} + +int TestSoftmaxOp(const std::vector input_shape) { + framework::DDim dims = framework::make_ddim(input_shape); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input"}); + outputs["Out"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, dims, -100.0, 100.0); + + auto output_var = scope.get()->Var("output"); + auto output = output_var->template Get(); + + framework::AttributeMap attrs; + auto *op = new operators::SoftmaxOp("softmax", inputs, outputs, + attrs, scope); + op->InferShape(); + op->Init(); + op->Run(); + + framework::Tensor output_cmp; + float *output_cmp_data = output_cmp.mutable_data(output->dims()); + Softmax(input, &output_cmp); + + const float *output_data = output->data(); + for (int i = 0; i < output->numel(); ++i) { + float gap = output_data[i] - output_cmp_data[i]; + if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) { + LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i] + << ", output_cmp_data[" << i + << "] = " << output_cmp_data[i]; + delete op; + exit(1); + } } + delete op; + return 0; +} + +} // namespace paddle_mobile +int main(int argc, char *argv[]) { + TestSoftmaxOp({128, 1000}); + TestSoftmaxOp({128, 10, 1000}); return 0; } diff --git a/tools/pre-commit.hooks/cpplint.hook b/tools/pre-commit.hooks/cpplint.hook index 3082a8c8595cd5f9aa9e0a3b5ff69418254ff636..78ca3cfcdda52a223be609801e6b12ec58b79323 100644 --- a/tools/pre-commit.hooks/cpplint.hook +++ b/tools/pre-commit.hooks/cpplint.hook @@ -5,7 +5,7 @@ TOTAL_ERRORS=0 # The trick to remove deleted files: https://stackoverflow.com/a/2413151 for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}' | \ grep -v ".pb.cpp" | grep -v ".pb.h" | grep -v ".pb-c.h" | grep -v ".pb-c.c" | \ - grep -v "protobuf-c.h" | grep -v "protobuf-c.c" | grep -v "paddle_mobile_jni.cpp"); do + grep -v "protobuf-c.h" | grep -v "protobuf-c.c"); do cpplint $file; TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); done