未验证 提交 f4fd0b72 编写于 作者: H Houjiang Chen 提交者: GitHub

Merge branch 'develop' into fix1386

...@@ -37,6 +37,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add"; ...@@ -37,6 +37,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
const char *G_OP_TYPE_LRN = "lrn"; const char *G_OP_TYPE_LRN = "lrn";
const char *G_OP_TYPE_MUL = "mul"; const char *G_OP_TYPE_MUL = "mul";
const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms"; const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms";
const char *G_OP_TYPE_NORM = "norm";
const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform"; const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform";
const char *G_OP_TYPE_POOL2D = "pool2d"; const char *G_OP_TYPE_POOL2D = "pool2d";
const char *G_OP_TYPE_PRIOR_BOX = "prior_box"; const char *G_OP_TYPE_PRIOR_BOX = "prior_box";
...@@ -169,5 +170,6 @@ std::unordered_map< ...@@ -169,5 +170,6 @@ std::unordered_map<
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}}; {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}},
{G_OP_TYPE_NORM, {{"X"}, {"Out", "Norm"}}}};
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -332,8 +332,8 @@ void expand_conv_arg(ConvArgs *arg) { ...@@ -332,8 +332,8 @@ void expand_conv_arg(ConvArgs *arg) {
auto image_win_cnt = block_len; auto image_win_cnt = block_len;
auto image_win_cnt_last = block_last; auto image_win_cnt_last = block_last;
auto res_row_data_align4_pad = res_amount_per_row_pad / 8; auto res_row_data_align4_pad = res_amount_per_row_pad / 8;
auto prog_full_cnt = 2048 / (filter_amount_all / 16 * 2) - 1; auto prog_full_cnt = 1024 / (filter_amount_all / 16 * 2) - 1;
if (prog_full_cnt == 1023) { if (prog_full_cnt == 511) {
prog_full_cnt--; prog_full_cnt--;
} }
auto post_prog_full_cnt = auto post_prog_full_cnt =
......
...@@ -22,26 +22,97 @@ namespace paddle_mobile { ...@@ -22,26 +22,97 @@ namespace paddle_mobile {
namespace fpga { namespace fpga {
int16_t fp32_2_fp16(float fp32_num) { int16_t fp32_2_fp16(float fp32_num) {
unsigned long tmp = *(unsigned long *)(&fp32_num); // NOLINT int32_t tmp = *(reinterpret_cast<int32_t *>(&fp32_num));
auto t = (int16_t)(((tmp & 0x007fffff) >> 13) | ((tmp & 0x80000000) >> 16) | int16_t se_fp32 = (tmp >> 23) & 0x1ff;
(((tmp & 0x7f800000) >> 13) - (112 << 10))); int32_t m_fp32 = tmp & 0x007fffff;
if (tmp & 0x1000) { int16_t se_fp16 = 0;
t++; // roundoff int16_t m_fp16 = 0;
if (se_fp32 < 103) {
se_fp16 = 0x0000;
m_fp16 = m_fp32 >> 24;
} else if (se_fp32 < 113) {
se_fp16 = (0x0400 >> (113 - se_fp32));
m_fp16 = m_fp32 >> (126 - se_fp32);
} else if (se_fp32 <= 142) {
se_fp16 = (se_fp32 - 112) << 10;
m_fp16 = m_fp32 >> 13;
} else if (se_fp32 < 255) {
se_fp16 = 0x7C00;
m_fp16 = m_fp32 >> 24;
} else if (se_fp32 == 255) {
se_fp16 = 0x7C00;
m_fp16 = m_fp32 >> 13;
} else if (se_fp32 < 359) {
se_fp16 = 0x8000;
m_fp16 = m_fp32 >> 24;
} else if (se_fp32 < 369) {
se_fp16 = (0x0400 >> (369 - se_fp32)) | 0x8000;
m_fp16 = m_fp32 >> (382 - se_fp32);
} else if (se_fp32 <= 398) {
se_fp16 = ((se_fp32 - 368) << 10) | 0x8000;
m_fp16 = m_fp32 >> 13;
} else if (se_fp32 < 511) {
se_fp16 = 0x7C00;
m_fp16 = m_fp32 >> 24;
} else {
se_fp16 = 0x7C00;
m_fp16 = m_fp32 >> 13;
}
int16_t result = se_fp16 + m_fp16;
return result;
}
int32_t convertmantissa(int32_t i) {
int32_t m = i << 13;
int32_t e = 0;
while (!(m & 0x00800000)) {
e -= 0x00800000;
m <<= 1;
} }
return t; m &= ~0x00800000;
e += 0x38800000;
return m | e;
} }
float fp16_2_fp32(int16_t fp16_num) { float fp16_2_fp32(int16_t fp16_num) {
if (0 == fp16_num) { int16_t se_fp16 = fp16_num >> 10;
return 0; int16_t m_fp16 = fp16_num & 0x3ff;
int32_t e_fp32 = 0;
int16_t offset = 0;
int32_t m_fp32 = 0;
if (se_fp16 == 0) {
e_fp32 = 0;
offset = 0;
} else if (se_fp16 < 31) {
e_fp32 = se_fp16 << 23;
offset = 1024;
} else if (se_fp16 == 31) {
e_fp32 = 0x47800000;
offset = 1024;
} else if (se_fp16 == 32) {
e_fp32 = 0x80000000;
offset = 0;
} else if (se_fp16 < 63) {
e_fp32 = 0x80000000 + (se_fp16 - 32) << 23;
offset = 1024;
} else { // se_fp16 == 63
e_fp32 = 0xC7800000;
offset = 1024;
} }
int frac = (fp16_num & 0x3ff); int16_t a = offset + m_fp16;
int exp = ((fp16_num & 0x7c00) >> 10) + 112; if (a == 0) {
int s = fp16_num & 0x8000; m_fp32 = 0;
int tmp = 0; } else if (a < 1024) {
float fp32_num; int32_t tmp = a;
tmp = s << 16 | exp << 23 | frac << 13; m_fp32 = convertmantissa(tmp);
fp32_num = *(float *)&tmp; // NOLINT } else {
int32_t tmp = a - 1024;
m_fp32 = 0x38000000 + (tmp << 13);
}
int32_t tmp = e_fp32 + m_fp32;
float fp32_num = *(reinterpret_cast<float *>(&tmp));
return fp32_num; return fp32_num;
} }
...@@ -126,6 +197,5 @@ uint64_t vaddr_to_paddr(void *address) { ...@@ -126,6 +197,5 @@ uint64_t vaddr_to_paddr(void *address) {
return 0; return 0;
#endif #endif
} }
} // namespace fpga } // namespace fpga
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -256,6 +256,6 @@ int fpga_invalidate(void* address, size_t size); ...@@ -256,6 +256,6 @@ int fpga_invalidate(void* address, size_t size);
uint64_t vaddr_to_paddr(void* address); uint64_t vaddr_to_paddr(void* address);
void expand_conv_arg(ConvArgs* arg); void expand_conv_arg(ConvArgs* arg);
void expand_EW_arg(EWAddArgs* arg); void expand_EW_arg(EWAddArgs* arg);
inline int32_t convertmantissa(int32_t i);
} // namespace fpga } // namespace fpga
} // namespace paddle_mobile } // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NORM_OP
#include "operators/kernel/norm_kernel.h"
#include "operators/kernel/central-arm-func/norm_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool NormKernel<CPU, float>::Init(NormParam<CPU> *param) {
return true;
}
template <>
void NormKernel<CPU, float>::Compute(const NormParam<CPU> &param) {
NormCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NORM_OP
#pragma once
#include <cmath>
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
inline void GetDims(const framework::DDim &dim, int axis, int *pre, int *n,
int *post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
template <typename P>
void NormCompute(const NormParam<CPU> &param) {
const float epsilon = param.Epsilon();
int axis = param.Axis();
const framework::Tensor *input = param.InputX();
framework::Tensor *norm = param.OutputNorm();
framework::Tensor *out = param.Out();
auto x_dims = input->dims();
if (axis < 0) {
axis += x_dims.size();
}
int pre, n, post;
GetDims(x_dims, axis, &pre, &n, &post);
const float *input_ptr = input->data<float>();
float *norm_ptr = norm->mutable_data<float>();
float *out_ptr = out->mutable_data<float>();
for (int p = 0; p < pre; ++p) {
const float *in_tmp = input_ptr + p * n * post;
float *norm_tmp = norm_ptr + p * post;
// in_ch = 0; norm = epsilon + x * x
for (int i = 0; i < post; ++i) {
*norm_tmp = epsilon;
*norm_tmp += (*in_tmp) * (*in_tmp);
norm_tmp++;
in_tmp++;
}
// in_ch >= 1; norm += x * x
for (int c = 1; c < n; ++c) {
norm_tmp = norm_ptr + p * post;
for (int i = 0; i < post; ++i) {
*norm_tmp += (*in_tmp) * (*in_tmp);
norm_tmp++;
in_tmp++;
}
}
// norm = sqart(norm)
norm_tmp = norm_ptr + p * post;
for (int i = 0; i < post; ++i) {
*norm_tmp = sqrtf(*norm_tmp);
norm_tmp++;
}
// out = input / norm
in_tmp = input_ptr + p * n * post;
float *out_tmp = out_ptr + p * n * post;
for (int c = 0; c < n; ++c) {
norm_tmp = norm_ptr + p * post;
for (int j = 0; j < post; ++j) {
*out_tmp = *in_tmp / *norm_tmp;
in_tmp++;
norm_tmp++;
out_tmp++;
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NORM_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class NormKernel
: public framework::OpKernelBase<DeviceType, NormParam<DeviceType>> {
public:
void Compute(const NormParam<DeviceType> &param);
bool Init(NormParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NORM_OP
#include "operators/norm_op.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void NormOp<Dtype, T>::InferShape() const {
auto x_dims = this->param_.InputX()->dims();
this->param_.Out()->Resize(x_dims);
int axis = this->param_.Axis();
if (axis < 0) {
axis += x_dims.size();
}
x_dims[axis] = 1;
this->param_.OutputNorm()->Resize(x_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(norm, ops::NormOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NORM_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/norm_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class NormOp
: public framework::OperatorWithKernel<DeviceType, NormParam<DeviceType>,
NormKernel<DeviceType, T>> {
public:
NormOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, NormParam<DeviceType>,
NormKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -280,6 +280,11 @@ class OpParam { ...@@ -280,6 +280,11 @@ class OpParam {
return GetVarValue<T>("OutputBox", outputs, scope); return GetVarValue<T>("OutputBox", outputs, scope);
} }
template <typename T>
static T *OutputNormFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("Norm", outputs, scope);
}
template <typename T> template <typename T>
static T *OutputVariancesFrom(const VariableNameMap &outputs, static T *OutputVariancesFrom(const VariableNameMap &outputs,
const Scope &scope) { const Scope &scope) {
...@@ -733,6 +738,41 @@ class LrnParam : public OpParam { ...@@ -733,6 +738,41 @@ class LrnParam : public OpParam {
}; };
#endif #endif
#ifdef NORM_OP
template <typename Dtype>
class NormParam : OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
NormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
output_norm_ = OutputNormFrom<GType>(outputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
axis_ = GetAttr<int>("axis", attrs);
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
RType *OutputNorm() const { return output_norm_; }
const float &Epsilon() const { return epsilon_; }
const int &Axis() const { return axis_; }
private:
RType *input_x_;
RType *out_;
RType *output_norm_;
float epsilon_;
int axis_;
};
#endif
#ifdef BATCHNORM_OP #ifdef BATCHNORM_OP
template <typename Dtype> template <typename Dtype>
class BatchNormParam : OpParam { class BatchNormParam : OpParam {
......
...@@ -235,6 +235,15 @@ if (NOT FOUND_MATCH) ...@@ -235,6 +235,15 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu-op paddle-mobile) target_link_libraries(test-relu-op paddle-mobile)
ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu6-op paddle-mobile)
ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h)
target_link_libraries(test-topk-op paddle-mobile)
ADD_EXECUTABLE(test-cast-op operators/test_cast_op.cpp test_helper.h test_include.h)
target_link_libraries(test-cast-op paddle-mobile)
# gen test # gen test
ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fc-op paddle-mobile) target_link_libraries(test-fc-op paddle-mobile)
...@@ -394,4 +403,9 @@ if (NOT FOUND_MATCH) ...@@ -394,4 +403,9 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-softmax paddle-mobile) target_link_libraries(test-sequence-softmax paddle-mobile)
# gen test
ADD_EXECUTABLE(test-vgg16ssd net/test_vgg16ssd.cpp test_helper.h test_include.h)
target_link_libraries(test-vgg16ssd paddle-mobile)
endif () endif ()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(1);
auto time1 = paddle_mobile::time();
auto isok =
paddle_mobile.Load(std::string(g_vgg16_ssd_combined) + "/model",
std::string(g_vgg16_ssd_combined) + "/params", false);
if (isok) {
auto time2 = paddle_mobile::time();
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time1) << "ms"
<< std::endl;
std::vector<int64_t> dims{1, 3, 300, 300};
Tensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 3, 300, 300}, static_cast<float>(0),
static_cast<float>(1));
std::vector<float> input(input_tensor.data<float>(),
input_tensor.data<float>() + input_tensor.numel());
auto vec_result = paddle_mobile.Predict(input, dims);
DLOG << vec_result;
}
return 0;
}
...@@ -17,157 +17,106 @@ limitations under the License. */ ...@@ -17,157 +17,106 @@ limitations under the License. */
#include "operators/batchnorm_op.h" #include "operators/batchnorm_op.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestBatchNormOp {
public:
explicit TestBatchNormOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks = void BatchNorm(const framework::Tensor *X, const framework::Tensor *Mean,
to_predict_program_->Blocks(); const framework::Tensor *Var, const framework::Tensor *Scale,
// DLOG << " **block size " << blocks.size(); const framework::Tensor *Bias, const float eps,
for (int i = 0; i < blocks.size(); ++i) { framework::Tensor *Y) {
std::shared_ptr<BlockDesc> block_desc = blocks[i]; const float *x = X->data<float>();
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops(); const float *m = Mean->data<float>();
// DLOG << " ops " << ops.size(); const float *v = Var->data<float>();
for (int j = 0; j < ops.size(); ++j) { const float *s = Scale->data<float>();
std::shared_ptr<OpDesc> op = ops[j]; const float *b = Bias->data<float>();
if (op->Type() == "batch_norm" && float *y = Y->mutable_data<float>();
op->Input("X")[0] == "conv2d_5.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size(); int batch_size = X->dims()[0];
DLOG << " inputs size: " << op->GetInputs().size(); int channel = X->dims()[1];
DLOG << " outputs size: " << op->GetOutputs().size(); int hw = X->dims()[2] * X->dims()[3];
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Input Mean is : " << op->Input("Mean")[0]; for (int batch = 0; batch < batch_size; ++batch) {
DLOG << " Input Variance is : " << op->Input("Variance")[0]; for (int c = 0; c < channel; ++c) {
DLOG << " Input Scale is : " << op->Input("Scale")[0]; float mean = m[c];
DLOG << " Input Bias is : " << op->Input("Bias")[0]; float inv_var = 1.f / std::sqrt(v[c] + eps);
DLOG << " Output Y is : " << op->Output("Y")[0]; float scale = s[c];
DLOG << " epsilon : " << op->GetAttrMap().at("epsilon").Get<float>(); float bias = b[c];
std::shared_ptr<operators::BatchNormOp<Dtype, float>> lrn = const float *input = x + (batch * channel + c) * hw;
std::make_shared<operators::BatchNormOp<Dtype, float>>( float *output = y + (batch * channel + c) * hw;
op->Type(), op->GetInputs(), op->GetOutputs(), for (int j = 0; j < hw; ++j) {
op->GetAttrMap(), program_.scope); output[j] = scale * ((input[j] - mean) * inv_var) + bias;
ops_of_block_[*block_desc.get()].push_back(lrn);
}
} }
} }
} }
}
std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2, int TestBatchNormOp(const std::vector<int> input_shape) {
const Tensor &t3, const Tensor &t4, framework::DDim dims = framework::make_ddim(input_shape);
const Tensor &t5) { VariableNameMap inputs;
// feed VariableNameMap outputs;
auto scope = program_.scope; auto scope = std::make_shared<framework::Scope>();
Variable *x1_feed_value = scope->Var("conv2d_5.tmp_0"); inputs["X"] = std::vector<std::string>({"input"});
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>(); inputs["Mean"] = std::vector<std::string>({"mean"});
tensor_x1->ShareDataWith(t1); inputs["Variance"] = std::vector<std::string>({"variance"});
inputs["Scale"] = std::vector<std::string>({"scale"});
Variable *mean_feed_value = scope->Var("batch_norm_10.w_1"); inputs["Bias"] = std::vector<std::string>({"bias"});
auto tensor_mean = mean_feed_value->GetMutable<LoDTensor>(); outputs["Y"] = std::vector<std::string>({"output"});
tensor_mean->ShareDataWith(t2);
auto input_var = scope.get()->Var("input");
Variable *scale_feed_value = scope->Var("batch_norm_10.w_0"); auto input = input_var->template GetMutable<framework::LoDTensor>();
auto tensor_scale = scale_feed_value->GetMutable<LoDTensor>(); SetupTensor<float>(input, dims, -100.0, 100.0);
tensor_scale->ShareDataWith(t3);
auto mean_var = scope.get()->Var("mean");
Variable *variance_feed_value = scope->Var("batch_norm_10.w_2"); auto mean = mean_var->template GetMutable<framework::LoDTensor>();
auto tensor_variance = variance_feed_value->GetMutable<LoDTensor>(); SetupTensor<float>(mean, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
tensor_variance->ShareDataWith(t4);
auto vari_var = scope.get()->Var("variance");
Variable *bias_feed_value = scope->Var("batch_norm_10.b_0"); auto vari = vari_var->template GetMutable<framework::LoDTensor>();
auto tensor_bias = bias_feed_value->GetMutable<LoDTensor>(); SetupTensor<float>(vari, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
tensor_bias->ShareDataWith(t5);
auto scale_var = scope.get()->Var("scale");
Variable *output = scope->Var("batch_norm_10.tmp_2"); auto scale = scale_var->template GetMutable<framework::LoDTensor>();
auto *output_tensor = output->GetMutable<LoDTensor>(); SetupTensor<float>(scale, framework::make_ddim({input_shape[1]}), -10.0,
output_tensor->mutable_data<float>({1, 256, 38, 38}); 10.0);
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims(); auto bias_var = scope.get()->Var("bias");
auto bias = bias_var->template GetMutable<framework::LoDTensor>();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>(); SetupTensor<float>(bias, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
out_tensor.reset(output_tensor);
auto output_var = scope.get()->Var("output");
predict_bn(t1, t2, t3, t4, t5, 0);
return out_tensor; float eps = 1e-6;
} framework::AttributeMap attrs;
attrs["epsilon"].Set<float>(eps);
private: attrs["momentum"].Set<float>(0.f);
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_; auto *op = new operators::BatchNormOp<CPU, float>("batch_norm", inputs,
std::map<framework::BlockDesc, outputs, attrs, scope);
std::vector<std::shared_ptr<OperatorBase<Dtype>>>> op->InferShape();
ops_of_block_; op->Init();
bool use_optimize_ = false; op->Run();
void predict_bn(const Tensor &t1, const Tensor &t2, const Tensor &t3, auto output = output_var->template Get<framework::LoDTensor>();
const Tensor &t4, const Tensor &t5, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block = framework::Tensor output_cmp;
to_predict_program_->Block(block_id); float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { BatchNorm(input, mean, vari, scale, bias, eps, &output_cmp);
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()"; const float *output_data = output->data<float>();
op->Run(); for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
} }
} }
}; }
template class TestBatchNormOp<CPU>;
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
int main() { int main() {
DLOG << "----------**********----------"; TestBatchNormOp({1, 1, 10, 10});
DLOG << "begin to run BatchNormOp Test"; TestBatchNormOp({1, 32, 100, 100});
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_mobilenet_ssd));
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {1, 256, 38, 38}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::Tensor mean;
SetupTensor<float>(&mean, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *mean_ptr = mean.data<float>();
paddle_mobile::framework::Tensor scale;
SetupTensor<float>(&scale, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *scale_ptr = scale.data<float>();
paddle_mobile::framework::Tensor variance;
SetupTensor<float>(&variance, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *variance_ptr = variance.data<float>();
paddle_mobile::framework::Tensor bias;
SetupTensor<float>(&bias, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *bias_ptr = bias.data<float>();
paddle_mobile::framework::TestBatchNormOp<paddle_mobile::CPU> testBatchNormOp(
program);
auto output_bn =
testBatchNormOp.predict_bn(inputx1, mean, scale, variance, bias);
auto *output_bn_ptr = output_bn->data<float>();
DLOG << " (" << inputx1_ptr[0] << " - " << mean_ptr[0] << ")/(("
<< variance_ptr[0] << " + 0.00001"
<< ")^0.5)* " << scale_ptr[0] << " + " << bias_ptr[0] << " = ";
DLOG << output_bn_ptr[0];
DLOG << "input_ptr 0 : " << inputx1_ptr[0];
DLOG << "output_ptr 0 : " << output_bn_ptr[0];
return 0; return 0;
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/cast_op.h"
namespace paddle_mobile {
template <typename Itype, typename Otype>
void Cast(const framework::Tensor *X, framework::Tensor *Y) {
const Itype *x = X->data<Itype>();
Otype *y = Y->mutable_data<Otype>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = static_cast<Otype>(x[i]);
}
}
template <typename T>
int TypeInt() {}
template <>
int TypeInt<bool>() {
return 0;
}
template <>
int TypeInt<int>() {
return 2;
}
template <>
int TypeInt<int64_t>() {
return 3;
}
template <>
int TypeInt<float>() {
return 5;
}
template <>
int TypeInt<double>() {
return 6;
}
template <>
int TypeInt<size_t>() {
return 19;
}
template <>
int TypeInt<uint8_t>() {
return 20;
}
template <>
int TypeInt<int8_t>() {
return 21;
}
template <typename Itype, typename Otype>
int TestCastOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, dims, static_cast<Itype>(-100),
static_cast<Itype>(100));
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["in_dtype"].Set<int>(TypeInt<Itype>());
attrs["out_dtype"].Set<int>(TypeInt<Otype>());
auto *op =
new operators::CastOp<CPU, float>("cast", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
Otype *output_cmp_data = output_cmp.mutable_data<Otype>(output->dims());
Cast<Itype, Otype>(input, &output_cmp);
const Otype *output_data = output->data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestCastOp<float, int>({1, 100});
TestCastOp<float, int>({128, 100});
TestCastOp<float, int64_t>({1, 100});
TestCastOp<float, int64_t>({128, 100});
TestCastOp<int, float>({1, 100});
TestCastOp<int, float>({128, 100});
TestCastOp<int64_t, float>({1, 100});
TestCastOp<int64_t, float>({128, 100});
return 0;
}
...@@ -103,20 +103,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -103,20 +103,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
} }
} // namespace paddle_mobile } // namespace paddle_mobile
int main(int argc, char *argv[]) { int Test(const int in_channels, const int in_height, const int in_width) {
if (argc < 4) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-pool-op in_channels in_height in_width \n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=0, stride=1"; << "float, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
...@@ -169,55 +156,75 @@ int main(int argc, char *argv[]) { ...@@ -169,55 +156,75 @@ int main(int argc, char *argv[]) {
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2"; << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=1"; << "float, pooling_type=max, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=1"; << "float, pooling_type=max, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=1"; << "float, pooling_type=max, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=1"; << "float, pooling_type=max, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=1"; << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=1"; << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=1"; << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=1"; << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=2"; << "float, pooling_type=max, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=2"; << "float, pooling_type=max, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=2"; << "float, pooling_type=max, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=2"; << "float, pooling_type=max, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=2"; << "float, pooling_type=avg, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=2"; << "float, pooling_type=avg, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=2"; << "float, pooling_type=avg, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=2"; << "float, pooling_type=avg, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
}
int main(int argc, char *argv[]) {
// if (argc < 4) {
// LOG(paddle_mobile::kLOG_INFO)
// << "Usage:\n"
// << " ./test-pool-op in_channels in_height in_width \n"
// << " params:\n"
// << " -in_channels: int, input image's channels\n"
// << " -in_height: int, input image's height\n"
// << " -in_width: int, input image's width\n";
// return 1;
// }
// int in_channels = atoi(argv[1]);
// int in_height = atoi(argv[2]);
// int in_width = atoi(argv[3]);
Test(1, 10, 10);
Test(1, 50, 50);
Test(32, 10, 10);
Test(32, 50, 50);
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/relu_op.h"
namespace paddle_mobile {
void Relu6(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
float q = x[i];
y[i] = std::min(std::max(0.f, q), 6.f);
}
}
int TestRelu6Op(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::Relu6Op<CPU, float>("relu6", inputs, outputs, attrs,
scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Relu6(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestRelu6Op({1, 1, 2, 3});
paddle_mobile::TestRelu6Op({1, 3, 11, 22});
paddle_mobile::TestRelu6Op({1, 32, 112, 112});
std::cout << "test relu6 op pass." << std::endl;
return 0;
}
...@@ -12,46 +12,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,46 +12,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h" #include "../test_include.h"
#include "operators/relu_op.h" #include "operators/relu_op.h"
int main() { namespace paddle_mobile {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_resnet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::CPU, void Relu(const framework::Tensor *X, framework::Tensor *Y) {
paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>> const float *x = X->data<float>();
executor(program, "relu"); float *y = Y->mutable_data<float>();
// 1. input_tensors; for (int i = 0; i < X->numel(); ++i) {
vector<Tensor> input_tensors; float q = x[i];
y[i] = std::max(0.f, q);
}
}
Tensor input1; int TestReluOp(const std::vector<int> input_shape) {
auto input1_data = CreateInput<float>(&input1, {1, 2, 3, 4}, -1, 1); framework::DDim dims = framework::make_ddim(input_shape);
input_tensors.push_back(input1); VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
// 2. input_names auto input_var = scope.get()->Var("input");
vector<string> input_names({ auto input = input_var->template GetMutable<framework::LoDTensor>();
"batch_norm_0.tmp_2", SetupTensor<float>(input, dims, -100.0, 100.0);
});
// 3. output_names auto output_var = scope.get()->Var("output");
vector<string> output_names({"batch_norm_0.tmp_3"});
// 4. out_dims; framework::AttributeMap attrs;
vector<DDim> out_ddims; auto *op =
auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4}); new operators::ReluOp<CPU, float>("relu", inputs, outputs, attrs, scope);
out_ddims.push_back(out_ddim); op->InferShape();
op->Init();
op->Run();
auto output = executor.Predict<LoDTensor>(input_tensors, input_names, auto output = output_var->template Get<framework::LoDTensor>();
output_names, out_ddims);
auto output0_data = output[0]->data<float>(); framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Relu(input, &output_cmp);
for (int j = 0; j < output[0]->numel(); ++j) { const float *output_data = output->data<float>();
DLOG << " value of output: " << output0_data[j]; for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
} }
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestReluOp({1, 1, 2, 3});
paddle_mobile::TestReluOp({1, 3, 11, 22});
paddle_mobile::TestReluOp({1, 32, 112, 112});
std::cout << "test relu op pass." << std::endl;
return 0; return 0;
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include "../test_include.h"
#include "operators/top_k_op.h"
namespace paddle_mobile {
void TopK(const framework::Tensor *X, framework::Tensor *Y,
framework::Tensor *Indices, const int K) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
int64_t *indices = Indices->mutable_data<int64_t>();
int dim_size = X->dims().size();
int row = 1;
int col = X->dims()[dim_size - 1];
for (int i = 0; i < dim_size - 1; ++i) {
row *= X->dims()[i];
}
std::vector<float> vec(col);
for (int i = 0; i < row; ++i) {
for (int j = 0; j < col; ++j) {
vec[j] = x[i * col + j];
}
for (int k = 0; k < K; ++k) {
float max = vec[0];
int index = 0;
for (int j = 1; j < col; ++j) {
if (vec[j] > max) {
max = vec[j];
index = j;
}
}
y[i * K + k] = max;
indices[i * K + k] = index;
vec[index] = -std::numeric_limits<float>::max();
}
}
}
int TestTopKOp(const std::vector<int> input_shape, const int K) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
outputs["Indices"] = std::vector<std::string>({"indices"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto indices_var = scope.get()->Var("indices");
framework::AttributeMap attrs;
attrs["k"].Set<int>(K);
auto *op =
new operators::TopKOp<CPU, float>("top_k", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
auto indices = indices_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp, indices_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
int64_t *indices_cmp_data =
indices_cmp.mutable_data<int64_t>(indices->dims());
TopK(input, &output_cmp, &indices_cmp, K);
// sort output
float *output_data = const_cast<float *>(output->data<float>());
int64_t *indices_data = const_cast<int64_t *>(indices->data<int64_t>());
// std::vector<std::pair<float, size_t>> vec(K);
// for (int i = 0; i < output->numel() / K; ++i) {
// for (int j = 0; j < K; ++j) {
// vec[j] = std::move(std::make_pair(output_data[i * K + j],
// indices_data[i * K + j]));
// }
// std::sort(vec.begin(), vec.end(),
// [](const std::pair<float, size_t> &l,
// const std::pair<float, size_t> &r) {
// return l.first > r.first; });
// for (int j = 0; j < K; ++j) {
// output_data[i * K + j] = vec[j].first;
// indices_data[i * K + j] = vec[j].second;
// }
// }
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
for (int i = 0; i < indices->numel(); ++i) {
if (indices_data[i] != indices_cmp_data[i]) {
LOG(kLOG_INFO) << "indices_data[" << i << "] = " << indices_data[i]
<< ", indices_cmp_data[" << i
<< "] = " << indices_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestTopKOp({1, 100}, 1);
TestTopKOp({128, 100}, 10);
TestTopKOp({128, 2, 100}, 10);
return 0;
}
...@@ -50,6 +50,7 @@ static const char *g_yolo = "../models/yolo"; ...@@ -50,6 +50,7 @@ static const char *g_yolo = "../models/yolo";
static const char *g_yolo_combined = "../models/yolo_combined"; static const char *g_yolo_combined = "../models/yolo_combined";
static const char *g_yolo_mul = "../models/d"; static const char *g_yolo_mul = "../models/d";
static const char *g_fluid_fssd_new = "../models/fluid_fssd_new"; static const char *g_fluid_fssd_new = "../models/fluid_fssd_new";
static const char *g_vgg16_ssd_combined = "../models/vgg16_ssd_combined";
static const char *g_test_image_1x3x224x224 = static const char *g_test_image_1x3x224x224 =
"../images/test_image_1x3x224x224_float"; "../images/test_image_1x3x224x224_float";
static const char *g_test_image_1x3x224x224_banana = static const char *g_test_image_1x3x224x224_banana =
......
...@@ -146,6 +146,7 @@ if (NOT DEFINED CMAKE_IOS_DEVELOPER_ROOT) ...@@ -146,6 +146,7 @@ if (NOT DEFINED CMAKE_IOS_DEVELOPER_ROOT)
endif (NOT DEFINED CMAKE_IOS_DEVELOPER_ROOT) endif (NOT DEFINED CMAKE_IOS_DEVELOPER_ROOT)
set (CMAKE_IOS_DEVELOPER_ROOT ${CMAKE_IOS_DEVELOPER_ROOT} CACHE PATH "Location of iOS Platform") set (CMAKE_IOS_DEVELOPER_ROOT ${CMAKE_IOS_DEVELOPER_ROOT} CACHE PATH "Location of iOS Platform")
set(CMAKE_IOS_SDK_ROOT "/Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk")
# Find and use the most recent iOS sdk unless specified manually with CMAKE_IOS_SDK_ROOT # Find and use the most recent iOS sdk unless specified manually with CMAKE_IOS_SDK_ROOT
if (NOT DEFINED CMAKE_IOS_SDK_ROOT) if (NOT DEFINED CMAKE_IOS_SDK_ROOT)
file (GLOB _CMAKE_IOS_SDKS "${CMAKE_IOS_DEVELOPER_ROOT}/SDKs/*") file (GLOB _CMAKE_IOS_SDKS "${CMAKE_IOS_DEVELOPER_ROOT}/SDKs/*")
......
...@@ -215,6 +215,7 @@ endif() ...@@ -215,6 +215,7 @@ endif()
if(NOT FOUND_MATCH) if(NOT FOUND_MATCH)
message("--default--") message("--default--")
set(NORM_OP ON)
set(BATCHNORM_OP ON) set(BATCHNORM_OP ON)
set(CONV_TRANSPOSE_OP ON) set(CONV_TRANSPOSE_OP ON)
set(BOXCODER_OP ON) set(BOXCODER_OP ON)
...@@ -302,6 +303,9 @@ endif() ...@@ -302,6 +303,9 @@ endif()
# option(TRANSPOSE2_OP "" ON) # option(TRANSPOSE2_OP "" ON)
# endif () # endif ()
if (NORM_OP)
add_definitions(-DNORM_OP)
endif()
if (BATCHNORM_OP) if (BATCHNORM_OP)
add_definitions(-DBATCHNORM_OP) add_definitions(-DBATCHNORM_OP)
endif() endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册