diff --git a/src/operators/kernel/arm/dropout_kernel.cpp b/src/operators/kernel/arm/dropout_kernel.cpp index 348b74cd26c79f5d0f2e65012a5760ab7eb58707..4578ac6607d87c316853f6201f02f8204bc41de1 100644 --- a/src/operators/kernel/arm/dropout_kernel.cpp +++ b/src/operators/kernel/arm/dropout_kernel.cpp @@ -27,7 +27,11 @@ bool DropoutKernel::Init(DropoutParam *para) { template struct DropoutFunctor { - inline T operator()(T in) const { return in; } + DropoutFunctor(T drop_pro) : dropout_pro_(drop_pro) {} + inline T operator()(T in) const { return (1 - dropout_pro_) * in; } + + private: + T dropout_pro_; }; template <> @@ -36,8 +40,8 @@ void DropoutKernel::Compute(const DropoutParam ¶m) const { auto *input_x_ptr = input_x->data(); auto *out = param.Out(); auto *out_ptr = out->mutable_data(); - - DropoutFunctor func_; + const float dropoutProb = param.DropoutProb(); + DropoutFunctor func_(dropoutProb); math::Transform trans; trans(input_x_ptr, input_x_ptr + input_x->numel(), out_ptr, func_); } diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 69b4eb34e2e072344a7f624aad17e9d516e90565..4f1fc897252de0578fca587da0d1529e604bfae8 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -2141,15 +2141,20 @@ class DropoutParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_x_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); + + dropout_prob_ = GetAttr("dropout_prob", attrs); } const RType *InputX() const { return input_x_; } RType *Out() const { return out_; } + float DropoutProb() const { return dropout_prob_; } + private: RType *input_x_; RType *out_; + float dropout_prob_; }; #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4c6cd6f55bbc3db749e4e78de200bbe9b779968a..d9dd2634770fbcfce22f1c35790b0b81ac4fa346 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -208,6 +208,14 @@ else () target_link_libraries(test-gru-op paddle-mobile) # gen test + + ADD_EXECUTABLE(test-inceptionv4 net/test_inceptionv4.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-inceptionv4 paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-alexnet net/test_alexnet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-alexnet paddle-mobile) + ADD_EXECUTABLE(test-googlenetv1 net/test_googlenetv1_combine.cpp test_helper.h test_include.h) target_link_libraries(test-googlenetv1 paddle-mobile) @@ -215,10 +223,13 @@ else () ADD_EXECUTABLE(test-fssd net/test_mobilenet_025_fssd.cpp test_helper.h test_include.h) target_link_libraries(test-fssd paddle-mobile) + #add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp) + + endif() # if(FPGA) diff --git a/test/net/test_alexnet.cpp b/test/net/test_alexnet.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50053fe82f95177fd786c1c8f8f5c9b7a521b888 --- /dev/null +++ b/test/net/test_alexnet.cpp @@ -0,0 +1,59 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main() { + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(4); + auto time1 = time(); + // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", + // std::string(g_mobilenet_detect) + "/params", true); + + auto isok = paddle_mobile.Load(g_alexnet, true); + if (isok) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; + + std::vector input; + std::vector dims{1, 3, 224, 224}; + GetInput(g_test_image_1x3x224x224_banana, &input, dims); + + auto vec_result = paddle_mobile.Predict(input, dims); + std::vector::iterator biggest = + std::max_element(std::begin(vec_result), std::end(vec_result)); + std::cout << " Max element is " << *biggest << " at position " + << std::distance(std::begin(vec_result), biggest) << std::endl; + + // 预热十次 + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + auto time3 = time(); + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + DLOG << vec_result; + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" + << std::endl; + } + + std::cout << "如果结果Nan请查看: test/images/g_test_image_1x3x224x224_banana " + "是否存在?" + << std::endl; + return 0; +} diff --git a/test/net/test_inceptionv4.cpp b/test/net/test_inceptionv4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fbbc9dd39e64f7a8ea745cf7489e46f00ffe1413 --- /dev/null +++ b/test/net/test_inceptionv4.cpp @@ -0,0 +1,59 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main() { + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(4); + auto time1 = time(); + // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", + // std::string(g_mobilenet_detect) + "/params", true); + + auto isok = paddle_mobile.Load(g_inceptionv4, true); + if (isok) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; + + std::vector input; + std::vector dims{1, 3, 224, 224}; + GetInput(g_test_image_1x3x224x224_banana, &input, dims); + + auto vec_result = paddle_mobile.Predict(input, dims); + std::vector::iterator biggest = + std::max_element(std::begin(vec_result), std::end(vec_result)); + std::cout << " Max element is " << *biggest << " at position " + << std::distance(std::begin(vec_result), biggest) << std::endl; + + // 预热十次 + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + auto time3 = time(); + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + // DLOG << vec_result; + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" + << std::endl; + } + + std::cout << "如果结果Nan请查看: test/images/g_test_image_1x3x224x224_banana " + "是否存在?" + << std::endl; + return 0; +} diff --git a/test/test_helper.h b/test/test_helper.h index b6ebdbc04f1212c16a905732084a02f01e540d3c..7581405c3d9f14e7e997e73be91cb624ad6d9798 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -34,6 +34,8 @@ static const char *g_mobilenet_detect = "../models/mobilenet-detect"; static const char *g_squeezenet = "../models/squeezenet"; static const char *g_googlenet = "../models/googlenet"; static const char *g_mobilenet = "../models/mobilenet"; +static const char *g_alexnet = "../models/alexnet"; +static const char *g_inceptionv4 = "../models/inceptionv4"; static const char *g_nlp = "../models/nlp"; static const char *g_resnet_50 = "../models/resnet_50"; static const char *g_resnet = "../models/resnet";