diff --git a/src/operators/kernel/arm/dropout_kernel.cpp b/src/operators/kernel/arm/dropout_kernel.cpp index 348b74cd26c79f5d0f2e65012a5760ab7eb58707..4578ac6607d87c316853f6201f02f8204bc41de1 100644 --- a/src/operators/kernel/arm/dropout_kernel.cpp +++ b/src/operators/kernel/arm/dropout_kernel.cpp @@ -27,7 +27,11 @@ bool DropoutKernel::Init(DropoutParam *para) { template struct DropoutFunctor { - inline T operator()(T in) const { return in; } + DropoutFunctor(T drop_pro) : dropout_pro_(drop_pro) {} + inline T operator()(T in) const { return (1 - dropout_pro_) * in; } + + private: + T dropout_pro_; }; template <> @@ -36,8 +40,8 @@ void DropoutKernel::Compute(const DropoutParam ¶m) const { auto *input_x_ptr = input_x->data(); auto *out = param.Out(); auto *out_ptr = out->mutable_data(); - - DropoutFunctor func_; + const float dropoutProb = param.DropoutProb(); + DropoutFunctor func_(dropoutProb); math::Transform trans; trans(input_x_ptr, input_x_ptr + input_x->numel(), out_ptr, func_); } diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 10fe5c2494bb3e5ddbb6876525db8017fe0c910c..6142d9b584d868db6cf9f55c9f6f50f1317a98da 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -2136,15 +2136,20 @@ class DropoutParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_x_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); + + dropout_prob_ = GetAttr("dropout_prob", attrs); } const RType *InputX() const { return input_x_; } RType *Out() const { return out_; } + float DropoutProb() const { return dropout_prob_; } + private: RType *input_x_; RType *out_; + float dropout_prob_; }; #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d96c4fbabbfac9589209fb3319c337f3b6d49897..cf61a754ad6be3994307512004b8158e9c729da2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -207,10 +207,20 @@ else () ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h) target_link_libraries(test-gru-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-inceptionv4 net/test_inceptionv4.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-inceptionv4 paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-alexnet net/test_alexnet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-alexnet paddle-mobile) + #add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp) + + endif() # if(FPGA) diff --git a/test/net/test_alexnet.cpp b/test/net/test_alexnet.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50053fe82f95177fd786c1c8f8f5c9b7a521b888 --- /dev/null +++ b/test/net/test_alexnet.cpp @@ -0,0 +1,59 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main() { + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(4); + auto time1 = time(); + // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", + // std::string(g_mobilenet_detect) + "/params", true); + + auto isok = paddle_mobile.Load(g_alexnet, true); + if (isok) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; + + std::vector input; + std::vector dims{1, 3, 224, 224}; + GetInput(g_test_image_1x3x224x224_banana, &input, dims); + + auto vec_result = paddle_mobile.Predict(input, dims); + std::vector::iterator biggest = + std::max_element(std::begin(vec_result), std::end(vec_result)); + std::cout << " Max element is " << *biggest << " at position " + << std::distance(std::begin(vec_result), biggest) << std::endl; + + // 预热十次 + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + auto time3 = time(); + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + DLOG << vec_result; + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" + << std::endl; + } + + std::cout << "如果结果Nan请查看: test/images/g_test_image_1x3x224x224_banana " + "是否存在?" + << std::endl; + return 0; +} diff --git a/test/net/test_inceptionv4.cpp b/test/net/test_inceptionv4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fbbc9dd39e64f7a8ea745cf7489e46f00ffe1413 --- /dev/null +++ b/test/net/test_inceptionv4.cpp @@ -0,0 +1,59 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main() { + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(4); + auto time1 = time(); + // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", + // std::string(g_mobilenet_detect) + "/params", true); + + auto isok = paddle_mobile.Load(g_inceptionv4, true); + if (isok) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; + + std::vector input; + std::vector dims{1, 3, 224, 224}; + GetInput(g_test_image_1x3x224x224_banana, &input, dims); + + auto vec_result = paddle_mobile.Predict(input, dims); + std::vector::iterator biggest = + std::max_element(std::begin(vec_result), std::end(vec_result)); + std::cout << " Max element is " << *biggest << " at position " + << std::distance(std::begin(vec_result), biggest) << std::endl; + + // 预热十次 + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + auto time3 = time(); + for (int i = 0; i < 10; ++i) { + auto vec_result = paddle_mobile.Predict(input, dims); + } + // DLOG << vec_result; + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" + << std::endl; + } + + std::cout << "如果结果Nan请查看: test/images/g_test_image_1x3x224x224_banana " + "是否存在?" + << std::endl; + return 0; +} diff --git a/test/test_helper.h b/test/test_helper.h index 9c592cf1032a8f4e5c08ef7f4c7e738b6cf0b122..e93025a702ef2ece2c589c42ff05cef720d7e8dc 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -33,6 +33,8 @@ static const char *g_mobilenet_detect = "../models/mobilenet-detect"; static const char *g_squeezenet = "../models/squeezenet"; static const char *g_googlenet = "../models/googlenet"; static const char *g_mobilenet = "../models/mobilenet"; +static const char *g_alexnet = "../models/alexnet"; +static const char *g_inceptionv4 = "../models/inceptionv4"; static const char *g_nlp = "../models/nlp"; static const char *g_resnet_50 = "../models/resnet_50"; static const char *g_resnet = "../models/resnet";