diff --git a/src/framework/op_info.h b/src/framework/op_info.h index fb1666e2b102ec9c892293a5f08d10894190a8f0..7475d155232e31cf00dab6273200f5bc4671f2e9 100644 --- a/src/framework/op_info.h +++ b/src/framework/op_info.h @@ -32,21 +32,15 @@ struct OpInfo { } }; -template -class OpInfoMap; - -template -static OpInfoMap *g_op_info_map = nullptr; - template class OpInfoMap { public: - static OpInfoMap &Instance() { - LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug"; - if (g_op_info_map == nullptr) { - g_op_info_map = new OpInfoMap(); + static OpInfoMap *Instance() { + static OpInfoMap *s_instance = nullptr; + if (s_instance == nullptr) { + s_instance = new OpInfoMap(); } - return *g_op_info_map; + return s_instance; } bool Has(const std::string &op_type) const { diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h index 7f5a1558705662ff821ef1a5f7855215a8eaf303..233de642be76297706b497a35fa871fd45ca5dfa 100644 --- a/src/framework/op_registry.h +++ b/src/framework/op_registry.h @@ -35,7 +35,7 @@ class OperatorRegistrarRecursive; template struct OperatorRegistrar : public Registrar { explicit OperatorRegistrar(const std::string& op_type) { - if (OpInfoMap::Instance().Has(op_type)) { + if (OpInfoMap::Instance()->Has(op_type)) { LOG(paddle_mobile::kLOG_DEBUG1) << op_type << " is registered more than once."; return; @@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar { } OpInfo info; OperatorRegistrarRecursive(op_type, &info); - OpInfoMap::Instance().Insert(op_type, info); + OpInfoMap::Instance()->Insert(op_type, info); } }; @@ -95,10 +95,10 @@ class OpRegistry { LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size(); LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size(); LOG(paddle_mobile::kLOG_DEBUG1) - << " OpInfoMap size: " << OpInfoMap::Instance().map().size(); + << " OpInfoMap size: " << OpInfoMap::Instance()->map().size(); LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " " - << OpInfoMap::Instance().Has(type); - auto& info = OpInfoMap::Instance().Get(type); + << OpInfoMap::Instance()->Has(type); + auto& info = OpInfoMap::Instance()->Get(type); auto op = info.Creator()(type, inputs, outputs, attrs, scope); return std::shared_ptr>(op); } diff --git a/src/framework/tensor.h b/src/framework/tensor.h index c76d4db03ed95decc6c4a5ae07d38c4c260a925e..203cf24e5f12fc2a9917e246db2364389c8e20b4 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -132,13 +132,6 @@ class Tensor { reinterpret_cast(holder_->ptr()) + offset_); } - inline void *mutable_data() { - // PADDLE_ENFORCE(this->holder_ != nullptr, - // "Cannot invoke mutable data if current hold - // nothing."); - return mutable_data(holder_->type()); - } - /** * @brief Return a pointer to mutable memory block. * diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 03558141f9e4c45daadd9e8bdd0068ca24eeee62..c8ac141f9ca47ad5dc71aef5308503ccdb75fcb7 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -35,14 +35,9 @@ void ConvKernel::Compute(const ConvParam ¶m) const { LOG(kLOG_DEBUG) << param; const Tensor *input = param.Input(); - - // The filter will be reshaped in the calculations, - // so here use an assignment operation, - // that avoids modifying the variable in the Scope. Tensor filter = *param.Filter(); - Tensor *output = param.Output(); - // output->mutable_data(context.GetPlace()); + output->mutable_data(); int groups = param.Groups(); std::vector strides = param.Strides(); @@ -53,17 +48,9 @@ void ConvKernel::Compute(const ConvParam ¶m) const { const int batch_size = static_cast(input->dims()[0]); - // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, - // k_w} std::vector filter_shape_vec(framework::vectorize(filter.dims())); - // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, - // o_w} std::vector output_shape_vec(framework::vectorize(output->dims())); - // use col_shape in the im2col calculation - // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, - // k_w, o_d, - // o_h, o_w} size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); col_shape_vec[0] = input->dims()[1] / groups; @@ -73,24 +60,19 @@ void ConvKernel::Compute(const ConvParam ¶m) const { } framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - // use col_matrix_shape in the gemm calculation - // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, - // o_d * - // o_h * o_w) framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. Tensor col_matrix; if (is_expand) { col.mutable_data(col_shape); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } + DLOG << " col_shape = " << col_shape; + DLOG << " col_matrix_shape = " << col_matrix_shape; framework::DDim input_shape = framework::slice_ddim( input->dims(), 1, static_cast(input->dims().size())); @@ -98,6 +80,7 @@ void ConvKernel::Compute(const ConvParam ¶m) const { framework::DDim filter_matrix_shape = {filter.dims()[0], filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); + DLOG << " filter.deims() = " << filter.dims(); framework::DDim output_matrix_shape = { output->dims()[1], @@ -110,8 +93,6 @@ void ConvKernel::Compute(const ConvParam ¶m) const { math::Vol2ColFunctor vol2col; math::Im2ColFunctor im2col; - // auto& dev_ctx = context.template - // device_context(); for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); @@ -137,6 +118,9 @@ void ConvKernel::Compute(const ConvParam ¶m) const { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + DLOG << " out_slice " << out_slice.dims(); + DLOG << " filter_slice " << filter_slice.dims(); + DLOG << " col_matrix " << col_matrix.dims(); math::matmul(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, static_cast(0)); diff --git a/src/operators/kernel/arm/sigmoid_kernel.cpp b/src/operators/kernel/arm/sigmoid_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74bc29878019dfe52de94f6fef966a416e04cc72 --- /dev/null +++ b/src/operators/kernel/arm/sigmoid_kernel.cpp @@ -0,0 +1,95 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../sigmoid_kernel.h" +#if __ARM_NEON +#include "../../math/math_func_neon.h" +#endif + +namespace paddle_mobile { +namespace operators { + +using framework::DDim; +using framework::Tensor; + +void sigmoid(const Tensor *X, Tensor *Y) { +#if __ARM_NEON + DLOG << "step1"; + const float *input = X->data(); + DLOG << "step11"; + + float *output = Y->mutable_data(); + DLOG << "step2"; + + const DDim &dDim = X->dims(); + DLOG << "step3"; + + int axis_index = 1; + if (dDim.size() < 4) { + axis_index = 0; + } + DLOG << "step4"; + + DDim outer_ddim = + paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1); + DDim inner_ddim = + paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size()); + DLOG << "step5"; + + int out_size = paddle_mobile::framework::product(outer_ddim); + int inner_size = paddle_mobile::framework::product(inner_ddim); + DLOG << "step6"; + +#pragma omp parallel for + DLOG << "outsize=" << out_size; + DLOG << "innersize=" << inner_size; + for (int i = 0; i < out_size; ++i) { + const float *input_outer_ptr = input + i * inner_size; + float *output_outer_ptr = output + i * inner_size; + int nn = inner_size >> 2; + int remain = inner_size - (nn << 2); + float32x4_t _one = vdupq_n_f32(1.f); + for (; nn > 0; nn--) { + float32x4_t data = vld1q_f32(input_outer_ptr); + data = vnegq_f32(data); + data = exp_ps(data); + data = vaddq_f32(data, _one); + float32x4_t out_data = vrecpeq_f32(data); + out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data); + vst1q_f32(output_outer_ptr, out_data); + + input_outer_ptr += 4; + output_outer_ptr += 4; + } + for (; remain > 0; remain--) { + *output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr)); + output_outer_ptr++; + input_outer_ptr++; + } + } +#endif +} + +template <> +void SigmoidKernel::Compute(const SigmoidParam ¶m) const { + const Tensor *in_x = param.InputX(); + Tensor *out = param.Out(); + auto x_dims = in_x->dims(); + out->Resize(x_dims); + sigmoid(in_x, out); +} + +template class SigmoidKernel; +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/sigmoid_kernel.h b/src/operators/kernel/sigmoid_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8f5c787f3ff009ed1e334e61657d00454d6e4c0b --- /dev/null +++ b/src/operators/kernel/sigmoid_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" +namespace paddle_mobile { +namespace operators { +using framework::OpKernelBase; +void sigmoid(const Tensor* X, Tensor* Y); +template +class SigmoidKernel : public OpKernelBase { + public: + void Compute(const SigmoidParam& param) const override; +}; +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/softmax_kernel.h b/src/operators/kernel/softmax_kernel.h index 8ffdb1a3e74145cde6b5dd0568e8bed26fe256e0..5bdae46d288adef3c07c6b2735bdfe5e6ec0c1c3 100644 --- a/src/operators/kernel/softmax_kernel.h +++ b/src/operators/kernel/softmax_kernel.h @@ -21,6 +21,8 @@ namespace paddle_mobile { namespace operators { using framework::OpKernelBase; +void simoid(Tensor *X, Tensor *Y); + template class SoftmaxKernel : public OpKernelBase { public: diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index eb442e634c482ce300bb87ef6d9070c9d3ff415d..6eaeb6e256148598b460f1fe4e1f0cdf451f186c 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "operators/math/softmax.h" #include "common/types.h" #if __ARM_NEON #include +#include #include "operators/math/math_func_neon.h" #endif @@ -108,7 +108,7 @@ class SoftmaxFuntor { // sum exp sum(exp_sub_max, sumptr, inner_size, out_size); // div - auto *out_ptr = static_cast(Y->mutable_data()); + auto *out_ptr = Y->mutable_data(); for (int l = 0; l < out_size; ++l) { const float *input_outer_ptr = exp_sub_max + l * inner_size; float *output_outer_ptr = out_ptr + l * inner_size; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index c2f3a5e7cff010a752bb44133c02ebfe489c29a1..72b729b572093a3ae58751fc0dd7f4a05e938cf6 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -542,6 +542,22 @@ class SoftmaxParam : public OpParam { Tensor *input_x_; Tensor *out_; }; + +class SigmoidParam : public OpParam { + public: + SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + const framework::Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + } + const Tensor *InputX() const { return input_x_; } + Tensor *Out() const { return out_; } + + private: + Tensor *input_x_; + Tensor *out_; +}; class MultiClassNMSParam : public OpParam { public: MultiClassNMSParam(const VariableNameMap &inputs, diff --git a/src/operators/sigmoid_op.cpp b/src/operators/sigmoid_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6bff80a35aa019a7b05f6e9b58c49e13fb8f1bc8 --- /dev/null +++ b/src/operators/sigmoid_op.cpp @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/sigmoid_op.h" + +namespace paddle_mobile { +namespace operators { +template +void SigmoidOp::InferShape() const { + param_.Out()->Resize(param_.InputX()->dims()); +} +template class SigmoidOp; +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +USE_OP(sigmoid); +REGISTER_OPERATOR(sigmoid, ops::SigmoidOp); diff --git a/src/operators/sigmoid_op.h b/src/operators/sigmoid_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ba5d3d0299fe5de3f94284546b9fc7d81ca6d524 --- /dev/null +++ b/src/operators/sigmoid_op.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "operators/kernel/sigmoid_kernel.h" + +namespace paddle_mobile { +namespace operators { +template +class SigmoidOp : public framework::OperatorWithKernel { + public: + SigmoidOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel(type, inputs, outputs, attrs, + scope), + param_(inputs, outputs, attrs, *scope) {} + + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape() const override; + + void Run() const { + operators::SigmoidKernel kernel; + kernel.Compute(param_); + this->ClearVariables({"X"}); + } + + private: + SigmoidParam param_; +}; +} // namespace operators +} // namespace paddle_mobile diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8451842499b5d4cbf2fe458a3b0c7be6454f00c6..57ffd92c91f3f045745967bdfee52fc70317a328 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -82,3 +82,7 @@ target_link_libraries(test-enforce paddle-mobile) # gen test ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-googlenet paddle-mobile) + +# gen test +ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h) +target_link_libraries(test-sigmoid paddle-mobile) diff --git a/test/executor_for_test.h b/test/executor_for_test.h index 33bd62cd2430594d7f7110c8e47d060c73a2af3c..1b373e8806b996d481ac9528060a93e87ef33dd6 100644 --- a/test/executor_for_test.h +++ b/test/executor_for_test.h @@ -17,11 +17,14 @@ limitations under the License. */ #include #include +#include "./io.h" #include "common/log.h" -#include "io.h" +#include "framework/op_registry.h" #include "operators/conv_op.h" #include "operators/pool_op.h" +#include "operators/relu_op.h" #include "operators/reshape_op.h" +#include "operators/sigmoid_op.h" #include "operators/softmax_op.h" #include "operators/transpose_op.h" @@ -57,9 +60,13 @@ class Executor4Test : public Executor { for (std::shared_ptr op : ops) { if (op->Type() == op_type) { - std::shared_ptr op_ptr = std::make_shared( - op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(), - this->program_.scope); + /// test first meeting op in program + std::shared_ptr> + op_ptr = paddle_mobile::framework::OpRegistry< + paddle_mobile::CPU>::CreateOp(op->Type(), op->GetInputs(), + op->GetOutputs(), + op->GetAttrMap(), + this->program_.scope); this->ops_of_block_[*block_desc.get()].push_back(op_ptr); break; } diff --git a/test/operators/test_cov_op.cpp b/test/operators/test_cov_op.cpp index 260cdfa04c496c0d9671708a7e4dacabea135f9d..2fe7f3577bef42d26c349e9a24313518c05b9d2b 100644 --- a/test/operators/test_cov_op.cpp +++ b/test/operators/test_cov_op.cpp @@ -29,6 +29,9 @@ int main() { paddle_mobile::framework::Tensor input; GetInput(g_test_image_1x3x224x224, &input, {1, 3, 224, 224}); + // // use SetupTensor if not has local input image . + // SetupTensor(&input, {1, 3, 224, 224}, static_cast(0), + // static_cast(1)); auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112}); auto output = executor.predict(input, "data", "conv2d_0.tmp_0", out_ddim); diff --git a/test/operators/test_elementwise_add_op.cpp b/test/operators/test_elementwise_add_op.cpp index 309f86b22d46306158d67260305cbf8c87a2668a..eeb642a3f486c81a93452b8a3a26354793c8eff1 100644 --- a/test/operators/test_elementwise_add_op.cpp +++ b/test/operators/test_elementwise_add_op.cpp @@ -111,7 +111,7 @@ int main() { DLOG << "begin to run ElementAddOp Test"; paddle_mobile::Loader loader; auto program = - loader.Load(std::string("../../test/models/" + loader.Load(std::string("../models/" "image_classification_resnet.inference.model")); /// input x (1,3,224,224) diff --git a/test/operators/test_relu_op.cpp b/test/operators/test_relu_op.cpp index 6c2084f8c89cb7026e717bba522d8223dbea9e95..6fefb0368bef48c5ad699b530deabff961e9c5d0 100644 --- a/test/operators/test_relu_op.cpp +++ b/test/operators/test_relu_op.cpp @@ -12,108 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once +#include "../executor_for_test.h" #include "../test_include.h" -#include "operators/relu_op.h" - -namespace paddle_mobile { -namespace framework { - -template -class TestReluOp { - public: - explicit TestReluOp(const Program p) : program_(p) { - if (use_optimize_) { - to_predict_program_ = program_.optimizeProgram; - } else { - to_predict_program_ = program_.originProgram; - } - - const std::vector> blocks = - to_predict_program_->Blocks(); - // DLOG << " **block size " << blocks.size(); - for (auto block_desc : blocks) { - std::vector> ops = block_desc->Ops(); - // DLOG << " ops " << ops.size(); - for (auto op : ops) { - if (op->Type() == "relu" && - op->Input("X")[0] == "batch_norm_34.tmp_2") { - DLOG << "in"; - std::shared_ptr> test_op = - std::make_shared>( - op->Type(), op->GetInputs(), op->GetOutputs(), - op->GetAttrMap(), program_.scope); - ops_of_block_[*block_desc.get()].push_back(test_op); - } - } - } - } - - std::shared_ptr predict(const Tensor &t1) { - // feed - auto scope = program_.scope; - Variable *x1_feed_value = scope->Var("batch_norm_34.tmp_2"); - auto tensor_x1 = x1_feed_value->GetMutable(); - tensor_x1->ShareDataWith(t1); - - Variable *output = scope->Var("batch_norm_34.tmp_3"); - auto *output_tensor = output->GetMutable(); - output_tensor->mutable_data({1, 2, 3, 4}); - - // DLOG << typeid(output_tensor).name(); - // DLOG << "output_tensor dims: " << output_tensor->dims(); - - std::shared_ptr out_tensor = std::make_shared(); - out_tensor.reset(output_tensor); - - predict(t1, 0); - - return out_tensor; - // return outvars_tensor; - } - - private: - const framework::Program program_; - std::shared_ptr to_predict_program_; - std::map>>> - ops_of_block_; - bool use_optimize_ = false; - - void predict(const Tensor &t1, int block_id) { - std::shared_ptr to_predict_block = - to_predict_program_->Block(block_id); - for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { - auto op = ops_of_block_[*to_predict_block.get()][j]; - DLOG << "op -> run()"; - op->Run(); - } - } -}; - -template class TestReluOp; -} // namespace framework -} // namespace paddle_mobile int main() { - DLOG << "----------**********----------"; - DLOG << "begin to run Relu Test"; paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + // ../models/image_classification_resnet.inference.model + auto program = loader.Load(g_mobilenet_ssd); - /// input x (1,3,300,300) - paddle_mobile::framework::Tensor inputx1; - SetupTensor(&inputx1, {1, 2, 3, 4}, static_cast(-1), - static_cast(1)); - auto *inputx1_ptr = inputx1.data(); + PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, + "program file read fail"); - paddle_mobile::framework::TestReluOp testReluOp(program); + Executor4Test> + executor(program, "relu"); + + paddle_mobile::framework::Tensor input; + SetupTensor(&input, {1, 2, 3, 4}, static_cast(-1), + static_cast(1)); - auto output = testReluOp.predict(inputx1); - auto *output_ptr = output->data(); + auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4}); + auto output = executor.predict(input, "batch_norm_0.tmp_2", + "batch_norm_0.tmp_3", out_ddim); - for (int i = 0; i < output->numel(); i++) { - DLOG << output_ptr[i]; + auto output_ptr = output->data(); + for (int j = 0; j < output->numel(); ++j) { + DLOG << " value of output: " << output_ptr[j]; } return 0; } diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e053ca1e904db2fdd9642eeaaaefd590d3c5624a --- /dev/null +++ b/test/operators/test_sigmoid_op.cpp @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../../src/operators/kernel/sigmoid_kernel.h" +#include "../test_helper.h" +#include "./io.h" + +int main() { + paddle_mobile::framework::Tensor input; + paddle_mobile::framework::Tensor output; + DLOG << 1; + SetupTensor(&input, {1, 4, 60, 60}, static_cast(0), + static_cast(1)); + DLOG << 2; + + auto out_ddim = paddle_mobile::framework::make_ddim({1, 4, 60, 60}); + output.Resize(out_ddim); + DLOG << 3; + paddle_mobile::operators::sigmoid(&input, &output); + DLOG << 4; + auto *output_ptr = output.data(); + for (int j = 0; j < output.numel(); ++j) { + DLOG << " value of output: " << output_ptr[j]; + } + DLOG << 5; + return 0; +} diff --git a/test/test_helper.h b/test/test_helper.h index 58a7507f1b5e3fb600c583dee65cd228527f464d..e2d6a183cb7b4caf812a11e5e6b7ada8dbb3e747 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -23,7 +23,7 @@ limitations under the License. */ static const std::string g_googlenet = "../models/googlenet"; static const std::string g_mobilenet = "../models/mobilenet"; -static const std::string g_mobilenet_ssd = "../models/mobilenet"; +static const std::string g_mobilenet_ssd = "../models/mobilenet+ssd"; static const std::string g_squeezenet = "../models/squeezenet"; static const std::string g_resnet = "../models/image_classification_resnet.inference.model";