diff --git a/src/common/types.cpp b/src/common/types.cpp
index c25c5db30c7183b6685db03386ca9a9355ca6958..444789237f573f8da3eaf915abf61493967aabf8 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -37,6 +37,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
 const char *G_OP_TYPE_LRN = "lrn";
 const char *G_OP_TYPE_MUL = "mul";
 const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms";
+const char *G_OP_TYPE_NORM = "norm";
 const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform";
 const char *G_OP_TYPE_POOL2D = "pool2d";
 const char *G_OP_TYPE_PRIOR_BOX = "prior_box";
@@ -169,5 +170,6 @@ std::unordered_map<
     {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}},
     {G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}},
-    {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}};
+    {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}},
+    {G_OP_TYPE_NORM, {{"X"}, {"Out", "Norm"}}}};
 }  // namespace paddle_mobile
diff --git a/src/fpga/V1/api.cpp b/src/fpga/V1/api.cpp
index f17e79ffaacccbd9c24d3e0937274bce890aa357..137ac73512b9d88716ab585ba315f26aa3b14ea8 100644
--- a/src/fpga/V1/api.cpp
+++ b/src/fpga/V1/api.cpp
@@ -332,8 +332,8 @@ void expand_conv_arg(ConvArgs *arg) {
   auto image_win_cnt = block_len;
   auto image_win_cnt_last = block_last;
   auto res_row_data_align4_pad = res_amount_per_row_pad / 8;
-  auto prog_full_cnt = 2048 / (filter_amount_all / 16 * 2) - 1;
-  if (prog_full_cnt == 1023) {
+  auto prog_full_cnt = 1024 / (filter_amount_all / 16 * 2) - 1;
+  if (prog_full_cnt == 511) {
     prog_full_cnt--;
   }
   auto post_prog_full_cnt =
diff --git a/src/fpga/common/fpga_common.cpp b/src/fpga/common/fpga_common.cpp
old mode 100755
new mode 100644
index 1495e6e12ce4568c04db35b8d241cf9c7a40e9e0..0a1787aa3f211a247d95cd7124879ce14af980a9
--- a/src/fpga/common/fpga_common.cpp
+++ b/src/fpga/common/fpga_common.cpp
@@ -22,26 +22,97 @@ namespace paddle_mobile {
 namespace fpga {
 
 int16_t fp32_2_fp16(float fp32_num) {
-  unsigned long tmp = *(unsigned long *)(&fp32_num);  // NOLINT
-  auto t = (int16_t)(((tmp & 0x007fffff) >> 13) | ((tmp & 0x80000000) >> 16) |
-                     (((tmp & 0x7f800000) >> 13) - (112 << 10)));
-  if (tmp & 0x1000) {
-    t++;  // roundoff
+  int32_t tmp = *(reinterpret_cast<int32_t *>(&fp32_num));
+  int16_t se_fp32 = (tmp >> 23) & 0x1ff;
+  int32_t m_fp32 = tmp & 0x007fffff;
+  int16_t se_fp16 = 0;
+  int16_t m_fp16 = 0;
+
+  if (se_fp32 < 103) {
+    se_fp16 = 0x0000;
+    m_fp16 = m_fp32 >> 24;
+  } else if (se_fp32 < 113) {
+    se_fp16 = (0x0400 >> (113 - se_fp32));
+    m_fp16 = m_fp32 >> (126 - se_fp32);
+  } else if (se_fp32 <= 142) {
+    se_fp16 = (se_fp32 - 112) << 10;
+    m_fp16 = m_fp32 >> 13;
+  } else if (se_fp32 < 255) {
+    se_fp16 = 0x7C00;
+    m_fp16 = m_fp32 >> 24;
+  } else if (se_fp32 == 255) {
+    se_fp16 = 0x7C00;
+    m_fp16 = m_fp32 >> 13;
+  } else if (se_fp32 < 359) {
+    se_fp16 = 0x8000;
+    m_fp16 = m_fp32 >> 24;
+  } else if (se_fp32 < 369) {
+    se_fp16 = (0x0400 >> (369 - se_fp32)) | 0x8000;
+    m_fp16 = m_fp32 >> (382 - se_fp32);
+  } else if (se_fp32 <= 398) {
+    se_fp16 = ((se_fp32 - 368) << 10) | 0x8000;
+    m_fp16 = m_fp32 >> 13;
+  } else if (se_fp32 < 511) {
+    se_fp16 = 0x7C00;
+    m_fp16 = m_fp32 >> 24;
+  } else {
+    se_fp16 = 0x7C00;
+    m_fp16 = m_fp32 >> 13;
+  }
+  int16_t result = se_fp16 + m_fp16;
+  return result;
+}
+
+int32_t convertmantissa(int32_t i) {
+  int32_t m = i << 13;
+  int32_t e = 0;
+  while (!(m & 0x00800000)) {
+    e -= 0x00800000;
+    m <<= 1;
   }
-  return t;
+  m &= ~0x00800000;
+  e += 0x38800000;
+  return m | e;
 }
 float fp16_2_fp32(int16_t fp16_num) {
-  if (0 == fp16_num) {
-    return 0;
+  int16_t se_fp16 = fp16_num >> 10;
+  int16_t m_fp16 = fp16_num & 0x3ff;
+  int32_t e_fp32 = 0;
+  int16_t offset = 0;
+  int32_t m_fp32 = 0;
+  if (se_fp16 == 0) {
+    e_fp32 = 0;
+    offset = 0;
+  } else if (se_fp16 < 31) {
+    e_fp32 = se_fp16 << 23;
+    offset = 1024;
+  } else if (se_fp16 == 31) {
+    e_fp32 = 0x47800000;
+    offset = 1024;
+  } else if (se_fp16 == 32) {
+    e_fp32 = 0x80000000;
+    offset = 0;
+  } else if (se_fp16 < 63) {
+    e_fp32 = 0x80000000 + ((se_fp16 - 32) << 23);
+    offset = 1024;
+  } else {  // se_fp16 == 63
+    e_fp32 = 0xC7800000;
+    offset = 1024;
   }
-  int frac = (fp16_num & 0x3ff);
-  int exp = ((fp16_num & 0x7c00) >> 10) + 112;
-  int s = fp16_num & 0x8000;
-  int tmp = 0;
-  float fp32_num;
-  tmp = s << 16 | exp << 23 | frac << 13;
-  fp32_num = *(float *)&tmp;  // NOLINT
+  int16_t a = offset + m_fp16;
+  if (a == 0) {
+    m_fp32 = 0;
+  } else if (a < 1024) {
+    int32_t tmp = a;
+    m_fp32 = convertmantissa(tmp);
+  } else {
+    int32_t tmp = a - 1024;
+    m_fp32 = 0x38000000 + (tmp << 13);
+  }
+
+  int32_t tmp = e_fp32 + m_fp32;
+  float fp32_num = *(reinterpret_cast<float *>(&tmp));
   return fp32_num;
 }
 
@@ -126,6 +197,5 @@ uint64_t vaddr_to_paddr(void *address) {
   return 0;
 #endif
 }
-
 }  // namespace fpga
 }  // namespace paddle_mobile
diff --git a/src/fpga/common/fpga_common.h b/src/fpga/common/fpga_common.h
index 9bf67ba8292df66c3184b4b76ad0a45daa188537..c9519071fba94ad1e2b526d9e4d5cd96a1bcdbac 100755
--- a/src/fpga/common/fpga_common.h
+++ b/src/fpga/common/fpga_common.h
@@ -256,6 +256,6 @@ int fpga_invalidate(void* address, size_t size);
 uint64_t vaddr_to_paddr(void* address);
 void expand_conv_arg(ConvArgs* arg);
 void expand_EW_arg(EWAddArgs* arg);
-
+inline int32_t convertmantissa(int32_t i);
 }  // namespace fpga
 }  // namespace paddle_mobile
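Note: a quick way to sanity-check the new bit-level converters is a round-trip smoke test. This is a sketch, not part of the patch; it assumes src/ is on the include path so fpga_common.h resolves, and that the two functions above are linked in.

    // Hypothetical round-trip check for fp32_2_fp16 / fp16_2_fp32 above.
    #include <cstdint>
    #include <cstdio>
    #include "fpga/common/fpga_common.h"

    int main() {
      const float samples[] = {0.0f, 1.0f, -2.5f, 0.333f, 65504.0f};  // 65504 = fp16 max
      for (float v : samples) {
        int16_t h = paddle_mobile::fpga::fp32_2_fp16(v);
        float back = paddle_mobile::fpga::fp16_2_fp32(h);
        // Expect |v - back| to be within half-precision rounding error.
        std::printf("%g -> 0x%04x -> %g\n", v,
                    static_cast<unsigned>(static_cast<uint16_t>(h)), back);
      }
      return 0;
    }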
diff --git a/src/operators/kernel/arm/norm_kernel.cpp b/src/operators/kernel/arm/norm_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..32617992cb1a60b44265343092f15316ea087df1
--- /dev/null
+++ b/src/operators/kernel/arm/norm_kernel.cpp
@@ -0,0 +1,36 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef NORM_OP
+
+#include "operators/kernel/norm_kernel.h"
+#include "operators/kernel/central-arm-func/norm_arm_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool NormKernel<CPU, float>::Init(NormParam<CPU> *param) {
+  return true;
+}
+
+template <>
+void NormKernel<CPU, float>::Compute(const NormParam<CPU> &param) {
+  NormCompute<float>(param);
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/central-arm-func/norm_arm_func.h b/src/operators/kernel/central-arm-func/norm_arm_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..71b4c5515e9493def7c8d824e61917dfc8d1b985
--- /dev/null
+++ b/src/operators/kernel/central-arm-func/norm_arm_func.h
@@ -0,0 +1,106 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef NORM_OP
+
+#pragma once
+
+#include <cmath>
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+inline void GetDims(const framework::DDim &dim, int axis, int *pre, int *n,
+                    int *post) {
+  *pre = 1;
+  *post = 1;
+  *n = dim[axis];
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= dim[i];
+  }
+  for (int i = axis + 1; i < dim.size(); ++i) {
+    (*post) *= dim[i];
+  }
+}
+
+template <typename P>
+void NormCompute(const NormParam<CPU> &param) {
+  const float epsilon = param.Epsilon();
+  int axis = param.Axis();
+
+  const framework::Tensor *input = param.InputX();
+  framework::Tensor *norm = param.OutputNorm();
+  framework::Tensor *out = param.Out();
+
+  auto x_dims = input->dims();
+  if (axis < 0) {
+    axis += x_dims.size();
+  }
+
+  int pre, n, post;
+  GetDims(x_dims, axis, &pre, &n, &post);
+
+  const float *input_ptr = input->data<float>();
+  float *norm_ptr = norm->mutable_data<float>();
+  float *out_ptr = out->mutable_data<float>();
+
+  for (int p = 0; p < pre; ++p) {
+    const float *in_tmp = input_ptr + p * n * post;
+    float *norm_tmp = norm_ptr + p * post;
+
+    // in_ch = 0; norm = epsilon + x * x
+    for (int i = 0; i < post; ++i) {
+      *norm_tmp = epsilon;
+      *norm_tmp += (*in_tmp) * (*in_tmp);
+      norm_tmp++;
+      in_tmp++;
+    }
+
+    // in_ch >= 1; norm += x * x
+    for (int c = 1; c < n; ++c) {
+      norm_tmp = norm_ptr + p * post;
+      for (int i = 0; i < post; ++i) {
+        *norm_tmp += (*in_tmp) * (*in_tmp);
+        norm_tmp++;
+        in_tmp++;
+      }
+    }
+
+    // norm = sqrt(norm)
+    norm_tmp = norm_ptr + p * post;
+    for (int i = 0; i < post; ++i) {
+      *norm_tmp = sqrtf(*norm_tmp);
+      norm_tmp++;
+    }
+
+    // out = input / norm
+    in_tmp = input_ptr + p * n * post;
+    float *out_tmp = out_ptr + p * n * post;
+    for (int c = 0; c < n; ++c) {
+      norm_tmp = norm_ptr + p * post;
+      for (int j = 0; j < post; ++j) {
+        *out_tmp = *in_tmp / *norm_tmp;
+        in_tmp++;
+        norm_tmp++;
+        out_tmp++;
+      }
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
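Note: the kernel normalizes along one axis by flattening the shape into pre x n x post, where n is the normalized axis. A standalone sketch of that decomposition on a tiny 1x2x2 input, axis = 1 (illustrative values, not part of the patch):

    // Minimal model of NormCompute's pre/n/post indexing scheme.
    #include <cmath>
    #include <cstdio>

    int main() {
      const int pre = 1, n = 2, post = 2;       // dims {1, 2, 2}, axis = 1
      const float x[4] = {3.f, 0.f, 4.f, 1.f};  // contiguous [p][c][i] layout
      float out[4], norm[2];
      for (int p = 0; p < pre; ++p) {
        for (int i = 0; i < post; ++i) {
          float sum = 1e-10f;  // epsilon
          for (int c = 0; c < n; ++c) {
            float v = x[(p * n + c) * post + i];
            sum += v * v;
          }
          norm[p * post + i] = std::sqrt(sum);
        }
        for (int c = 0; c < n; ++c)
          for (int i = 0; i < post; ++i)
            out[(p * n + c) * post + i] =
                x[(p * n + c) * post + i] / norm[p * post + i];
      }
      std::printf("norms: %f %f\n", norm[0], norm[1]);  // 5 and 1
      return 0;
    }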
diff --git a/src/operators/kernel/norm_kernel.h b/src/operators/kernel/norm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f945bdb8b03a3952dd362df8b36a1db26f3fd93
--- /dev/null
+++ b/src/operators/kernel/norm_kernel.h
@@ -0,0 +1,36 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef NORM_OP
+
+#pragma once
+
+#include "framework/operator.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename DeviceType, typename T>
+class NormKernel
+    : public framework::OpKernelBase<DeviceType, NormParam<DeviceType>> {
+ public:
+  void Compute(const NormParam<DeviceType> &param);
+  bool Init(NormParam<DeviceType> *param);
+};
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/norm_op.cpp b/src/operators/norm_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..deed9f69d1cf40ee70a211b0c9a84e4afeef6623
--- /dev/null
+++ b/src/operators/norm_op.cpp
@@ -0,0 +1,52 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef NORM_OP
+
+#include "operators/norm_op.h"
+#include "framework/op_proto_maker.h"
+#include "framework/op_registry.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename DeviceType, typename T>
+void NormOp<DeviceType, T>::InferShape() const {
+  auto x_dims = this->param_.InputX()->dims();
+  this->param_.Out()->Resize(x_dims);
+
+  int axis = this->param_.Axis();
+  if (axis < 0) {
+    axis += x_dims.size();
+  }
+  x_dims[axis] = 1;
+  this->param_.OutputNorm()->Resize(x_dims);
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(norm, ops::NormOp);
+#endif
+#ifdef PADDLE_MOBILE_MALI_GPU
+#endif
+#ifdef PADDLE_MOBILE_FPGA
+#endif
+
+#ifdef PADDLE_MOBILE_CL
+#endif
+
+#endif
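Note: to make the shape contract concrete, here is a standalone sketch (illustrative, not part of the patch) of what InferShape produces for an SSD-style input: "Out" keeps X's shape and "Norm" collapses the normalized axis to 1.

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<long long> x_dims = {1, 3, 300, 300};  // "X" input dims
      int axis = 1;                                      // "axis" attribute
      if (axis < 0) axis += static_cast<int>(x_dims.size());
      std::vector<long long> norm_dims = x_dims;
      norm_dims[axis] = 1;  // Norm holds one value per normalized column
      std::printf("Out:  {%lld, %lld, %lld, %lld}\n", x_dims[0], x_dims[1],
                  x_dims[2], x_dims[3]);
      std::printf("Norm: {%lld, %lld, %lld, %lld}\n", norm_dims[0],
                  norm_dims[1], norm_dims[2], norm_dims[3]);
      return 0;
    }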
diff --git a/src/operators/norm_op.h b/src/operators/norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..5bd6924af1d4aca125795be879cf67c09832965e
--- /dev/null
+++ b/src/operators/norm_op.h
@@ -0,0 +1,47 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef NORM_OP
+
+#pragma once
+
+#include <string>
+#include "framework/operator.h"
+#include "operators/kernel/norm_kernel.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+using std::string;
+template <typename DeviceType, typename T>
+class NormOp
+    : public framework::OperatorWithKernel<DeviceType, NormParam<DeviceType>,
+                                           NormKernel<DeviceType, T>> {
+ public:
+  NormOp(const string &type, const VariableNameMap &inputs,
+         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
+         std::shared_ptr<framework::Scope> scope)
+      : framework::OperatorWithKernel<DeviceType, NormParam<DeviceType>,
+                                      NormKernel<DeviceType, T>>(
+            type, inputs, outputs, attrs, scope) {}
+
+  void InferShape() const override;
+
+ protected:
+};
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 8976d8be8e0722fe6915af4786d55809a9f8ca7c..385be3b72f233c9a0951f17b8c61c23adc346988 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -280,6 +280,11 @@ class OpParam {
     return GetVarValue<T>("OutputBox", outputs, scope);
   }
 
+  template <typename T>
+  static T *OutputNormFrom(const VariableNameMap &outputs, const Scope &scope) {
+    return GetVarValue<T>("Norm", outputs, scope);
+  }
+
   template <typename T>
   static T *OutputVariancesFrom(const VariableNameMap &outputs,
                                 const Scope &scope) {
@@ -733,6 +738,41 @@ class LrnParam : public OpParam {
 };
 #endif
 
+#ifdef NORM_OP
+template <typename Dtype>
+class NormParam : OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  NormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+            const AttributeMap &attrs, const Scope &scope) {
+    input_x_ = InputXFrom<GType>(inputs, scope);
+    out_ = OutFrom<GType>(outputs, scope);
+    output_norm_ = OutputNormFrom<GType>(outputs, scope);
+    epsilon_ = GetAttr<float>("epsilon", attrs);
+    axis_ = GetAttr<int>("axis", attrs);
+  }
+
+  const RType *InputX() const { return input_x_; }
+
+  RType *Out() const { return out_; }
+
+  RType *OutputNorm() const { return output_norm_; }
+
+  const float &Epsilon() const { return epsilon_; }
+
+  const int &Axis() const { return axis_; }
+
+ private:
+  RType *input_x_;
+  RType *out_;
+  RType *output_norm_;
+  float epsilon_;
+  int axis_;
+};
+#endif
+
 #ifdef BATCHNORM_OP
 template <typename Dtype>
 class BatchNormParam : OpParam {
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 6cab082d98f6247e5254b8bfa4c6a208c50fb42c..e90fe2c904868d0145a889033720e4948a1f3c33 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -235,6 +235,15 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-relu-op paddle-mobile)
 
+    ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-relu6-op paddle-mobile)
+
+    ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-topk-op paddle-mobile)
+
+    ADD_EXECUTABLE(test-cast-op operators/test_cast_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-cast-op paddle-mobile)
+
     # gen test
     ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-fc-op paddle-mobile)
@@ -394,4 +403,9 @@ if (NOT FOUND_MATCH)
 
     ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-sequence-softmax paddle-mobile)
+
+    # gen test
+    ADD_EXECUTABLE(test-vgg16ssd net/test_vgg16ssd.cpp test_helper.h test_include.h)
+    target_link_libraries(test-vgg16ssd paddle-mobile)
+
 endif ()
diff --git a/test/net/test_vgg16ssd.cpp b/test/net/test_vgg16ssd.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..387d6f38ea9185d0563b39defbed928bda0186bf
--- /dev/null
+++ b/test/net/test_vgg16ssd.cpp
@@ -0,0 +1,46 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <iostream>
+#include "../test_helper.h"
+#include "../test_include.h"
+
+int main() {
+  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
+  paddle_mobile.SetThreadNum(1);
+  auto time1 = paddle_mobile::time();
+
+  auto isok =
+      paddle_mobile.Load(std::string(g_vgg16_ssd_combined) + "/model",
+                         std::string(g_vgg16_ssd_combined) + "/params", false);
+  if (isok) {
+    auto time2 = paddle_mobile::time();
+    std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
+              << std::endl;
+
+    std::vector<int64_t> dims{1, 3, 300, 300};
+    Tensor input_tensor;
+    SetupTensor<float>(&input_tensor, {1, 3, 300, 300}, static_cast<float>(0),
+                       static_cast<float>(1));
+
+    std::vector<float> input(input_tensor.data<float>(),
+                             input_tensor.data<float>() + input_tensor.numel());
+
+    auto vec_result = paddle_mobile.Predict(input, dims);
+
+    DLOG << vec_result;
+  }
+
+  return 0;
+}
diff --git a/test/operators/test_batchnorm_op.cpp b/test/operators/test_batchnorm_op.cpp
index c027d4bd31d5ff41f42e9cd333618f8630aad5d9..f78aa4061205586e9d1540b65a8a3dbc32de6757 100644
--- a/test/operators/test_batchnorm_op.cpp
+++ b/test/operators/test_batchnorm_op.cpp
@@ -17,157 +17,106 @@ limitations under the License. */
 #include "operators/batchnorm_op.h"
 
 namespace paddle_mobile {
-namespace framework {
-
-template <typename Dtype>
-class TestBatchNormOp {
- public:
-  explicit TestBatchNormOp(const Program<Dtype> p) : program_(p) {
-    if (use_optimize_) {
-      to_predict_program_ = program_.optimizeProgram;
-    } else {
-      to_predict_program_ = program_.originProgram;
-    }
-    const std::vector<std::shared_ptr<BlockDesc>> blocks =
-        to_predict_program_->Blocks();
-    //  DLOG << " **block size " << blocks.size();
-    for (int i = 0; i < blocks.size(); ++i) {
-      std::shared_ptr<BlockDesc> block_desc = blocks[i];
-      std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
-      //  DLOG << " ops " << ops.size();
-      for (int j = 0; j < ops.size(); ++j) {
-        std::shared_ptr<OpDesc> op = ops[j];
-        if (op->Type() == "batch_norm" &&
-            op->Input("X")[0] == "conv2d_5.tmp_0") {
-          DLOG << " mul attr size: " << op->GetAttrMap().size();
-          DLOG << " inputs size: " << op->GetInputs().size();
-          DLOG << " outputs size: " << op->GetOutputs().size();
-          DLOG << " Input X is : " << op->Input("X")[0];
-          DLOG << " Input Mean is : " << op->Input("Mean")[0];
-          DLOG << " Input Variance is : " << op->Input("Variance")[0];
-          DLOG << " Input Scale is : " << op->Input("Scale")[0];
-          DLOG << " Input Bias is : " << op->Input("Bias")[0];
-          DLOG << " Output Y is : " << op->Output("Y")[0];
-          DLOG << " epsilon : " << op->GetAttrMap().at("epsilon").Get<float>();
-          std::shared_ptr<operators::BatchNormOp<Dtype, float>> lrn =
-              std::make_shared<operators::BatchNormOp<Dtype, float>>(
-                  op->Type(), op->GetInputs(), op->GetOutputs(),
-                  op->GetAttrMap(), program_.scope);
-          ops_of_block_[*block_desc.get()].push_back(lrn);
-        }
+void BatchNorm(const framework::Tensor *X, const framework::Tensor *Mean,
+               const framework::Tensor *Var, const framework::Tensor *Scale,
+               const framework::Tensor *Bias, const float eps,
+               framework::Tensor *Y) {
+  const float *x = X->data<float>();
+  const float *m = Mean->data<float>();
+  const float *v = Var->data<float>();
+  const float *s = Scale->data<float>();
+  const float *b = Bias->data<float>();
+  float *y = Y->mutable_data<float>();
+
+  int batch_size = X->dims()[0];
+  int channel = X->dims()[1];
+  int hw = X->dims()[2] * X->dims()[3];
+
+  for (int batch = 0; batch < batch_size; ++batch) {
+    for (int c = 0; c < channel; ++c) {
+      float mean = m[c];
+      float inv_var = 1.f / std::sqrt(v[c] + eps);
+      float scale = s[c];
+      float bias = b[c];
+      const float *input = x + (batch * channel + c) * hw;
+      float *output = y + (batch * channel + c) * hw;
+      for (int j = 0; j < hw; ++j) {
+        output[j] = scale * ((input[j] - mean) * inv_var) + bias;
       }
     }
   }
+}
 
-  std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2,
-                                     const Tensor &t3, const Tensor &t4,
-                                     const Tensor &t5) {
-    // feed
-    auto scope = program_.scope;
-    Variable *x1_feed_value = scope->Var("conv2d_5.tmp_0");
-    auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
-    tensor_x1->ShareDataWith(t1);
-
-    Variable *mean_feed_value = scope->Var("batch_norm_10.w_1");
-    auto tensor_mean = mean_feed_value->GetMutable<LoDTensor>();
-    tensor_mean->ShareDataWith(t2);
-
-    Variable *scale_feed_value = scope->Var("batch_norm_10.w_0");
-    auto tensor_scale = scale_feed_value->GetMutable<LoDTensor>();
-    tensor_scale->ShareDataWith(t3);
-
-    Variable *variance_feed_value = scope->Var("batch_norm_10.w_2");
-    auto tensor_variance = variance_feed_value->GetMutable<LoDTensor>();
-    tensor_variance->ShareDataWith(t4);
-
-    Variable *bias_feed_value = scope->Var("batch_norm_10.b_0");
-    auto tensor_bias = bias_feed_value->GetMutable<LoDTensor>();
-    tensor_bias->ShareDataWith(t5);
-
-    Variable *output = scope->Var("batch_norm_10.tmp_2");
-    auto *output_tensor = output->GetMutable<LoDTensor>();
-    output_tensor->mutable_data<float>({1, 256, 38, 38});
-    //  DLOG << typeid(output_tensor).name();
-    //  DLOG << "output_tensor dims: " << output_tensor->dims();
-
-    std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
-    out_tensor.reset(output_tensor);
-
-    predict_bn(t1, t2, t3, t4, t5, 0);
-    return out_tensor;
-  }
-
- private:
-  const framework::Program<Dtype> program_;
-  std::shared_ptr<ProgramDesc> to_predict_program_;
-  std::map<framework::BlockDesc,
-           std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
-      ops_of_block_;
-  bool use_optimize_ = false;
-
-  void predict_bn(const Tensor &t1, const Tensor &t2, const Tensor &t3,
-                  const Tensor &t4, const Tensor &t5, int block_id) {
-    std::shared_ptr<BlockDesc> to_predict_block =
-        to_predict_program_->Block(block_id);
-    for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
-      auto op = ops_of_block_[*to_predict_block.get()][j];
-      DLOG << "op -> run()";
-      op->Run();
+int TestBatchNormOp(const std::vector<int> input_shape) {
+  framework::DDim dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"input"});
+  inputs["Mean"] = std::vector<std::string>({"mean"});
+  inputs["Variance"] = std::vector<std::string>({"variance"});
+  inputs["Scale"] = std::vector<std::string>({"scale"});
+  inputs["Bias"] = std::vector<std::string>({"bias"});
+  outputs["Y"] = std::vector<std::string>({"output"});
+
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input, dims, -100.0, 100.0);
+
+  auto mean_var = scope.get()->Var("mean");
+  auto mean = mean_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(mean, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
+
+  auto vari_var = scope.get()->Var("variance");
+  auto vari = vari_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(vari, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
+
+  auto scale_var = scope.get()->Var("scale");
+  auto scale = scale_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(scale, framework::make_ddim({input_shape[1]}), -10.0,
+                     10.0);
+
+  auto bias_var = scope.get()->Var("bias");
+  auto bias = bias_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(bias, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
+
+  auto output_var = scope.get()->Var("output");
+
+  float eps = 1e-6;
+  framework::AttributeMap attrs;
+  attrs["epsilon"].Set<float>(eps);
+  attrs["momentum"].Set<float>(0.f);
+
+  auto *op = new operators::BatchNormOp<CPU, float>("batch_norm", inputs,
+                                                    outputs, attrs, scope);
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+
+  framework::Tensor output_cmp;
+  float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
+  BatchNorm(input, mean, vari, scale, bias, eps, &output_cmp);
+
+  const float *output_data = output->data<float>();
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
     }
   }
-};
+  delete op;
+  return 0;
+}
 
-template class TestBatchNormOp<paddle_mobile::CPU>;
-}  // namespace framework
 }  // namespace paddle_mobile
 
 int main() {
-  DLOG << "----------**********----------";
-  DLOG << "begin to run BatchNormOp Test";
-  paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
-  auto program = loader.Load(std::string(g_mobilenet_ssd));
-
-  /// input x (4,10,2,2)
-  paddle_mobile::framework::Tensor inputx1;
-  SetupTensor<float>(&inputx1, {1, 256, 38, 38}, static_cast<float>(0),
-                     static_cast<float>(1));
-  auto *inputx1_ptr = inputx1.data<float>();
-
-  paddle_mobile::framework::Tensor mean;
-  SetupTensor<float>(&mean, {256}, static_cast<float>(0),
-                     static_cast<float>(1));
-  auto *mean_ptr = mean.data<float>();
-
-  paddle_mobile::framework::Tensor scale;
-  SetupTensor<float>(&scale, {256}, static_cast<float>(0),
-                     static_cast<float>(1));
-  auto *scale_ptr = scale.data<float>();
-
-  paddle_mobile::framework::Tensor variance;
-  SetupTensor<float>(&variance, {256}, static_cast<float>(0),
-                     static_cast<float>(1));
-  auto *variance_ptr = variance.data<float>();
-
-  paddle_mobile::framework::Tensor bias;
-  SetupTensor<float>(&bias, {256}, static_cast<float>(0),
-                     static_cast<float>(1));
-  auto *bias_ptr = bias.data<float>();
-
-  paddle_mobile::framework::TestBatchNormOp<paddle_mobile::CPU>
-      testBatchNormOp(program);
-
-  auto output_bn =
-      testBatchNormOp.predict_bn(inputx1, mean, scale, variance, bias);
-  auto *output_bn_ptr = output_bn->data<float>();
-
-  DLOG << " (" << inputx1_ptr[0] << " - " << mean_ptr[0] << ")/(("
-       << variance_ptr[0] << " + 0.00001"
-       << ")^0.5)* " << scale_ptr[0] << " + " << bias_ptr[0] << " = ";
-  DLOG << output_bn_ptr[0];
-
-  DLOG << "input_ptr 0 : " << inputx1_ptr[0];
-  DLOG << "output_ptr 0 : " << output_bn_ptr[0];
-
+  paddle_mobile::TestBatchNormOp({1, 1, 10, 10});
+  paddle_mobile::TestBatchNormOp({1, 32, 100, 100});
   return 0;
 }
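Note: the per-element formula the reference BatchNorm applies is y = scale * (x - mean) / sqrt(var + eps) + bias. A scalar walk-through with illustrative numbers (not taken from the test):

    #include <cmath>
    #include <cstdio>

    int main() {
      const float x = 2.f, mean = 1.f, var = 3.f, scale = 0.5f, bias = -1.f;
      const float eps = 1e-6f;
      float y = scale * ((x - mean) / std::sqrt(var + eps)) + bias;
      std::printf("y = %f\n", y);  // 0.5 * 1 / sqrt(3) - 1 ~= -0.7113
      return 0;
    }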
diff --git a/test/operators/test_cast_op.cpp b/test/operators/test_cast_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..df6fab705bfb639f2b2d0b6d8c30bb86512b84d0
--- /dev/null
+++ b/test/operators/test_cast_op.cpp
@@ -0,0 +1,126 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "../test_include.h"
+#include "operators/cast_op.h"
+
+namespace paddle_mobile {
+
+template <typename Itype, typename Otype>
+void Cast(const framework::Tensor *X, framework::Tensor *Y) {
+  const Itype *x = X->data<Itype>();
+  Otype *y = Y->mutable_data<Otype>();
+
+  for (int i = 0; i < X->numel(); ++i) {
+    y[i] = static_cast<Otype>(x[i]);
+  }
+}
+
+template <typename T>
+int TypeInt() {}
+template <>
+int TypeInt<bool>() {
+  return 0;
+}
+template <>
+int TypeInt<int>() {
+  return 2;
+}
+template <>
+int TypeInt<int64_t>() {
+  return 3;
+}
+template <>
+int TypeInt<float>() {
+  return 5;
+}
+template <>
+int TypeInt<double>() {
+  return 6;
+}
+template <>
+int TypeInt<size_t>() {
+  return 19;
+}
+template <>
+int TypeInt<uint8_t>() {
+  return 20;
+}
+template <>
+int TypeInt<int8_t>() {
+  return 21;
+}
+
+template <typename Itype, typename Otype>
+int TestCastOp(const std::vector<int> input_shape) {
+  framework::DDim dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"input"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<Itype>(input, dims, static_cast<Itype>(-100),
+                     static_cast<Itype>(100));
+
+  auto output_var = scope.get()->Var("output");
+
+  framework::AttributeMap attrs;
+  attrs["in_dtype"].Set<int>(TypeInt<Itype>());
+  attrs["out_dtype"].Set<int>(TypeInt<Otype>());
+  auto *op =
+      new operators::CastOp<CPU, float>("cast", inputs, outputs, attrs, scope);
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+
+  framework::Tensor output_cmp;
+  Otype *output_cmp_data = output_cmp.mutable_data<Otype>(output->dims());
+  Cast<Itype, Otype>(input, &output_cmp);
+
+  const Otype *output_data = output->data<Otype>();
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+  delete op;
+  return 0;
+}
+
+}  // namespace paddle_mobile
+
+int main(int argc, char *argv[]) {
+  paddle_mobile::TestCastOp<int, float>({1, 100});
+  paddle_mobile::TestCastOp<int, float>({128, 100});
+
+  paddle_mobile::TestCastOp<float, int>({1, 100});
+  paddle_mobile::TestCastOp<float, int>({128, 100});
+
+  paddle_mobile::TestCastOp<int64_t, float>({1, 100});
+  paddle_mobile::TestCastOp<int64_t, float>({128, 100});
+
+  paddle_mobile::TestCastOp<float, int64_t>({1, 100});
+  paddle_mobile::TestCastOp<float, int64_t>({128, 100});
  return 0;
}
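Note: the integer codes returned by TypeInt appear to follow fluid's framework.proto VarType.Type values — this is my reading of the constants and worth verifying against the proto in use. A compact sketch of the presumed mapping:

    // Presumed meaning of the TypeInt codes above (assumption, not from the patch).
    enum class VarTypeCode : int {
      BOOL = 0,
      INT32 = 2,
      INT64 = 3,
      FP32 = 5,
      FP64 = 6,
      SIZE_T = 19,
      UINT8 = 20,
      INT8 = 21,
    };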
diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp
index acbf0eaf34c8cb7b35a94fd4e8a4a3867a7c1dff..c7590512f92e2166ea082986fb97bed771eb2b15 100644
--- a/test/operators/test_pool_op.cpp
+++ b/test/operators/test_pool_op.cpp
@@ -103,20 +103,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
 }
 }  // namespace paddle_mobile
 
-int main(int argc, char *argv[]) {
-  if (argc < 4) {
-    LOG(paddle_mobile::kLOG_INFO)
-        << "Usage:\n"
-        << "  ./test-pool-op in_channels in_height in_width \n"
-        << "  params:\n"
-        << "   -in_channels: int, input image's channels\n"
-        << "   -in_height: int, input image's height\n"
-        << "   -in_width: int, input image's width\n";
-    return 1;
-  }
-  int in_channels = atoi(argv[1]);
-  int in_height = atoi(argv[2]);
-  int in_width = atoi(argv[3]);
+int Test(const int in_channels, const int in_height, const int in_width) {
   LOG(paddle_mobile::kLOG_INFO)
       << "float, pooling_type=max, kernel=3, pad=0, stride=1";
   paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
@@ -169,55 +156,75 @@ int main(int argc, char *argv[]) {
       << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
   paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
 
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=0, stride=1";
-  //  paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=1, stride=1";
-  //  paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=2, stride=1";
-  //  paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=5, stride=1";
-  //  paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
-  //
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
-  //  paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
-  //  paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
-  //  paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
-  //  paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
-  //
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=0, stride=2";
-  //  paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=1, stride=2";
-  //  paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=2, stride=2";
-  //  paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=max, kernel=2, pad=5, stride=2";
-  //  paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
-  //
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=0, stride=2";
-  //  paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=1, stride=2";
-  //  paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=2, stride=2";
-  //  paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
-  //  LOG(paddle_mobile::kLOG_INFO)
-  //      << "float, pooling_type=avg, kernel=2, pad=5, stride=2";
-  //  paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
pad=0, stride=1"; + paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=1, stride=1"; + paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=2, stride=1"; + paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=5, stride=1"; + paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width); + + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=2, pad=0, stride=2"; + paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=2, pad=1, stride=2"; + paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=2, pad=2, stride=2"; + paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=2, pad=5, stride=2"; + paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width); + + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=0, stride=2"; + paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=1, stride=2"; + paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=2, stride=2"; + paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=2, pad=5, stride=2"; + paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width); +} + +int main(int argc, char *argv[]) { + // if (argc < 4) { + // LOG(paddle_mobile::kLOG_INFO) + // << "Usage:\n" + // << " ./test-pool-op in_channels in_height in_width \n" + // << " params:\n" + // << " -in_channels: int, input image's channels\n" + // << " -in_height: int, input image's height\n" + // << " -in_width: int, input image's width\n"; + // return 1; + // } + // int in_channels = atoi(argv[1]); + // int in_height = atoi(argv[2]); + // int in_width = atoi(argv[3]); + Test(1, 10, 10); + Test(1, 50, 50); + Test(32, 10, 10); + Test(32, 50, 50); } diff --git a/test/operators/test_relu6_op.cpp b/test/operators/test_relu6_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d2885f3ea94f522adbccefe5010662920a367a4 --- /dev/null +++ b/test/operators/test_relu6_op.cpp @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/test/operators/test_relu6_op.cpp b/test/operators/test_relu6_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d2885f3ea94f522adbccefe5010662920a367a4
--- /dev/null
+++ b/test/operators/test_relu6_op.cpp
@@ -0,0 +1,82 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <iostream>
+#include "../test_include.h"
+#include "operators/relu_op.h"
+
+namespace paddle_mobile {
+
+void Relu6(const framework::Tensor *X, framework::Tensor *Y) {
+  const float *x = X->data<float>();
+  float *y = Y->mutable_data<float>();
+
+  for (int i = 0; i < X->numel(); ++i) {
+    float q = x[i];
+    y[i] = std::min(std::max(0.f, q), 6.f);
+  }
+}
+
+int TestRelu6Op(const std::vector<int> input_shape) {
+  framework::DDim dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"input"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input, dims, -100.0, 100.0);
+
+  auto output_var = scope.get()->Var("output");
+
+  framework::AttributeMap attrs;
+  auto *op = new operators::Relu6Op<CPU, float>("relu6", inputs, outputs,
+                                                attrs, scope);
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+
+  framework::Tensor output_cmp;
+  float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
+  Relu6(input, &output_cmp);
+
+  const float *output_data = output->data<float>();
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+  delete op;
+  return 0;
+}
+
+}  // namespace paddle_mobile
+
+int main() {
+  paddle_mobile::TestRelu6Op({1, 1, 2, 3});
+  paddle_mobile::TestRelu6Op({1, 3, 11, 22});
+  paddle_mobile::TestRelu6Op({1, 32, 112, 112});
+  std::cout << "test relu6 op pass." << std::endl;
+  return 0;
+}
diff --git a/test/operators/test_relu_op.cpp b/test/operators/test_relu_op.cpp
index 542d3d18f6a383c1e03962ba845b39c04a51631b..c38a16e684029e22a3f1b489c1a83776e91382c2 100644
--- a/test/operators/test_relu_op.cpp
+++ b/test/operators/test_relu_op.cpp
@@ -12,46 +12,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
+#include <iostream>
 #include "../test_include.h"
 #include "operators/relu_op.h"
 
-int main() {
-  paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
-  auto program = loader.Load(g_resnet);
-  PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
-                        "program file read fail");
+namespace paddle_mobile {
 
-  Executor4Test<paddle_mobile::CPU,
-                paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>>
-      executor(program, "relu");
+void Relu(const framework::Tensor *X, framework::Tensor *Y) {
+  const float *x = X->data<float>();
+  float *y = Y->mutable_data<float>();
 
-  // 1. input_tensors;
-  vector<Tensor> input_tensors;
+  for (int i = 0; i < X->numel(); ++i) {
+    float q = x[i];
+    y[i] = std::max(0.f, q);
+  }
+}
 
-  Tensor input1;
-  auto input1_data = CreateInput<float>(&input1, {1, 2, 3, 4}, -1, 1);
-  input_tensors.push_back(input1);
+int TestReluOp(const std::vector<int> input_shape) {
+  framework::DDim dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"input"});
+  outputs["Out"] = std::vector<std::string>({"output"});
 
-  // 2. input_names
-  vector<string> input_names({
-      "batch_norm_0.tmp_2",
-  });
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input, dims, -100.0, 100.0);
 
-  // 3. output_names
-  vector<string> output_names({"batch_norm_0.tmp_3"});
+  auto output_var = scope.get()->Var("output");
 
-  // 4. out_dims;
-  vector<DDim> out_ddims;
-  auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4});
-  out_ddims.push_back(out_ddim);
+  framework::AttributeMap attrs;
+  auto *op =
+      new operators::ReluOp<CPU, float>("relu", inputs, outputs, attrs, scope);
+  op->InferShape();
+  op->Init();
+  op->Run();
 
-  auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
-                                            output_names, out_ddims);
+  auto output = output_var->template Get<framework::LoDTensor>();
 
-  auto output0_data = output[0]->data<float>();
+  framework::Tensor output_cmp;
+  float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
+  Relu(input, &output_cmp);
 
-  for (int j = 0; j < output[0]->numel(); ++j) {
-    DLOG << " value of output: " << output0_data[j];
+  const float *output_data = output->data<float>();
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
   }
+  delete op;
+  return 0;
+}
+
+}  // namespace paddle_mobile
+
+int main() {
+  paddle_mobile::TestReluOp({1, 1, 2, 3});
+  paddle_mobile::TestReluOp({1, 3, 11, 22});
+  paddle_mobile::TestReluOp({1, 32, 112, 112});
+  std::cout << "test relu op pass." << std::endl;
   return 0;
 }
diff --git a/test/operators/test_topk_op.cpp b/test/operators/test_topk_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7244232d0737fb9fe77448331c0bdf2477b4f8e5
--- /dev/null
+++ b/test/operators/test_topk_op.cpp
@@ -0,0 +1,139 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <limits>
+#include "../test_include.h"
+#include "operators/top_k_op.h"
+
+namespace paddle_mobile {
+
+void TopK(const framework::Tensor *X, framework::Tensor *Y,
+          framework::Tensor *Indices, const int K) {
+  const float *x = X->data<float>();
+  float *y = Y->mutable_data<float>();
+  int64_t *indices = Indices->mutable_data<int64_t>();
+
+  int dim_size = X->dims().size();
+  int row = 1;
+  int col = X->dims()[dim_size - 1];
+  for (int i = 0; i < dim_size - 1; ++i) {
+    row *= X->dims()[i];
+  }
+
+  std::vector<float> vec(col);
+  for (int i = 0; i < row; ++i) {
+    for (int j = 0; j < col; ++j) {
+      vec[j] = x[i * col + j];
+    }
+    for (int k = 0; k < K; ++k) {
+      float max = vec[0];
+      int index = 0;
+      for (int j = 1; j < col; ++j) {
+        if (vec[j] > max) {
+          max = vec[j];
+          index = j;
+        }
+      }
+      y[i * K + k] = max;
+      indices[i * K + k] = index;
+      vec[index] = -std::numeric_limits<float>::max();
+    }
+  }
+}
+
+int TestTopKOp(const std::vector<int> input_shape, const int K) {
+  framework::DDim dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"input"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+  outputs["Indices"] = std::vector<std::string>({"indices"});
+
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input, dims, -100.0, 100.0);
+
+  auto output_var = scope.get()->Var("output");
+  auto indices_var = scope.get()->Var("indices");
+
+  framework::AttributeMap attrs;
+  attrs["k"].Set<int>(K);
+  auto *op =
+      new operators::TopKOp<CPU, float>("top_k", inputs, outputs, attrs, scope);
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+  auto indices = indices_var->template Get<framework::LoDTensor>();
+
+  framework::Tensor output_cmp, indices_cmp;
+  float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
+  int64_t *indices_cmp_data =
+      indices_cmp.mutable_data<int64_t>(indices->dims());
+  TopK(input, &output_cmp, &indices_cmp, K);
+
+  // sort output
+  float *output_data = const_cast<float *>(output->data<float>());
+  int64_t *indices_data = const_cast<int64_t *>(indices->data<int64_t>());
+  // std::vector<std::pair<float, int64_t>> vec(K);
+  // for (int i = 0; i < output->numel() / K; ++i) {
+  //   for (int j = 0; j < K; ++j) {
+  //     vec[j] = std::move(std::make_pair(output_data[i * K + j],
+  //                                       indices_data[i * K + j]));
+  //   }
+  //   std::sort(vec.begin(), vec.end(),
+  //             [](const std::pair<float, int64_t> &l,
+  //                const std::pair<float, int64_t> &r) {
+  //               return l.first > r.first; });
+  //   for (int j = 0; j < K; ++j) {
+  //     output_data[i * K + j] = vec[j].first;
+  //     indices_data[i * K + j] = vec[j].second;
+  //   }
+  // }
+
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+
+  for (int i = 0; i < indices->numel(); ++i) {
+    if (indices_data[i] != indices_cmp_data[i]) {
+      LOG(kLOG_INFO) << "indices_data[" << i << "] = " << indices_data[i]
+                     << ", indices_cmp_data[" << i
+                     << "] = " << indices_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+  delete op;
+  return 0;
+}
+
+}  // namespace paddle_mobile
+
+int main(int argc, char *argv[]) {
+  paddle_mobile::TestTopKOp({1, 100}, 1);
+  paddle_mobile::TestTopKOp({128, 100}, 10);
+  paddle_mobile::TestTopKOp({128, 2, 100}, 10);
  return 0;
}
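Note: the commented-out "sort output" block above guards against ordering differences between the operator and the reference when values tie. If it ever needs to be revived, a standalone sketch of the same re-sort (assumptions: row-major output of shape rows x K, descending order by value):

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Re-sorts each row of a top-k result so value/index pairs are in
    // descending value order before comparison.
    void SortTopK(float *values, int64_t *indices, int rows, int K) {
      std::vector<std::pair<float, int64_t>> vec(K);
      for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < K; ++j)
          vec[j] = std::make_pair(values[i * K + j], indices[i * K + j]);
        std::sort(vec.begin(), vec.end(),
                  [](const std::pair<float, int64_t> &l,
                     const std::pair<float, int64_t> &r) {
                    return l.first > r.first;
                  });
        for (int j = 0; j < K; ++j) {
          values[i * K + j] = vec[j].first;
          indices[i * K + j] = vec[j].second;
        }
      }
    }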
diff --git a/test/test_helper.h b/test/test_helper.h
index 60e907fa6fb301b2e944a6d938f662225a743d41..7d677217a0fddecc23fb354b2b8b5c9a652746cf 100644
--- a/test/test_helper.h
+++ b/test/test_helper.h
@@ -50,6 +50,7 @@ static const char *g_yolo = "../models/yolo";
 static const char *g_yolo_combined = "../models/yolo_combined";
 static const char *g_yolo_mul = "../models/d";
 static const char *g_fluid_fssd_new = "../models/fluid_fssd_new";
+static const char *g_vgg16_ssd_combined = "../models/vgg16_ssd_combined";
 static const char *g_test_image_1x3x224x224 =
     "../images/test_image_1x3x224x224_float";
 static const char *g_test_image_1x3x224x224_banana =
diff --git a/tools/ios-cmake/ios.toolchain.cmake b/tools/ios-cmake/ios.toolchain.cmake
index 6000f7a8e5dffcd8693b56539f4519840ddd8be8..12dd1721d488cd8ba776b8f302f137ad2d60fe73 100644
--- a/tools/ios-cmake/ios.toolchain.cmake
+++ b/tools/ios-cmake/ios.toolchain.cmake
@@ -146,6 +146,7 @@ if (NOT DEFINED CMAKE_IOS_DEVELOPER_ROOT)
 endif (NOT DEFINED CMAKE_IOS_DEVELOPER_ROOT)
 set (CMAKE_IOS_DEVELOPER_ROOT ${CMAKE_IOS_DEVELOPER_ROOT} CACHE PATH "Location of iOS Platform")
 
+set(CMAKE_IOS_SDK_ROOT "/Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk")
 # Find and use the most recent iOS sdk unless specified manually with CMAKE_IOS_SDK_ROOT
 if (NOT DEFINED CMAKE_IOS_SDK_ROOT)
 file (GLOB _CMAKE_IOS_SDKS "${CMAKE_IOS_DEVELOPER_ROOT}/SDKs/*")
diff --git a/tools/op.cmake b/tools/op.cmake
index 3563834e77fa5eba363fc64ba337dd30c95a3820..e98e95e4ed29ced5044315dd4fb59a1866964a82 100644
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -215,6 +215,7 @@ endif()
 
 if(NOT FOUND_MATCH)
     message("--default--")
+    set(NORM_OP ON)
     set(BATCHNORM_OP ON)
     set(CONV_TRANSPOSE_OP ON)
     set(BOXCODER_OP ON)
@@ -302,6 +303,9 @@ endif()
 #    option(TRANSPOSE2_OP "" ON)
 # endif ()
 
+if (NORM_OP)
+    add_definitions(-DNORM_OP)
+endif()
 if (BATCHNORM_OP)
     add_definitions(-DBATCHNORM_OP)
 endif()