提交 08b98472 编写于 作者: H Houjiang Chen 提交者: GitHub

Merge pull request #1387 from hjchen2/ocr_attention

Add/Refine unit tests
...@@ -235,6 +235,15 @@ if (NOT FOUND_MATCH) ...@@ -235,6 +235,15 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu-op paddle-mobile) target_link_libraries(test-relu-op paddle-mobile)
ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu6-op paddle-mobile)
ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h)
target_link_libraries(test-topk-op paddle-mobile)
ADD_EXECUTABLE(test-cast-op operators/test_cast_op.cpp test_helper.h test_include.h)
target_link_libraries(test-cast-op paddle-mobile)
# gen test # gen test
ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fc-op paddle-mobile) target_link_libraries(test-fc-op paddle-mobile)
......
...@@ -17,157 +17,106 @@ limitations under the License. */ ...@@ -17,157 +17,106 @@ limitations under the License. */
#include "operators/batchnorm_op.h" #include "operators/batchnorm_op.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestBatchNormOp {
public:
explicit TestBatchNormOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks = void BatchNorm(const framework::Tensor *X, const framework::Tensor *Mean,
to_predict_program_->Blocks(); const framework::Tensor *Var, const framework::Tensor *Scale,
// DLOG << " **block size " << blocks.size(); const framework::Tensor *Bias, const float eps,
for (int i = 0; i < blocks.size(); ++i) { framework::Tensor *Y) {
std::shared_ptr<BlockDesc> block_desc = blocks[i]; const float *x = X->data<float>();
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops(); const float *m = Mean->data<float>();
// DLOG << " ops " << ops.size(); const float *v = Var->data<float>();
for (int j = 0; j < ops.size(); ++j) { const float *s = Scale->data<float>();
std::shared_ptr<OpDesc> op = ops[j]; const float *b = Bias->data<float>();
if (op->Type() == "batch_norm" && float *y = Y->mutable_data<float>();
op->Input("X")[0] == "conv2d_5.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size(); int batch_size = X->dims()[0];
DLOG << " inputs size: " << op->GetInputs().size(); int channel = X->dims()[1];
DLOG << " outputs size: " << op->GetOutputs().size(); int hw = X->dims()[2] * X->dims()[3];
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Input Mean is : " << op->Input("Mean")[0]; for (int batch = 0; batch < batch_size; ++batch) {
DLOG << " Input Variance is : " << op->Input("Variance")[0]; for (int c = 0; c < channel; ++c) {
DLOG << " Input Scale is : " << op->Input("Scale")[0]; float mean = m[c];
DLOG << " Input Bias is : " << op->Input("Bias")[0]; float inv_var = 1.f / std::sqrt(v[c] + eps);
DLOG << " Output Y is : " << op->Output("Y")[0]; float scale = s[c];
DLOG << " epsilon : " << op->GetAttrMap().at("epsilon").Get<float>(); float bias = b[c];
std::shared_ptr<operators::BatchNormOp<Dtype, float>> lrn = const float *input = x + (batch * channel + c) * hw;
std::make_shared<operators::BatchNormOp<Dtype, float>>( float *output = y + (batch * channel + c) * hw;
op->Type(), op->GetInputs(), op->GetOutputs(), for (int j = 0; j < hw; ++j) {
op->GetAttrMap(), program_.scope); output[j] = scale * ((input[j] - mean) * inv_var) + bias;
ops_of_block_[*block_desc.get()].push_back(lrn);
}
} }
} }
} }
}
std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2, int TestBatchNormOp(const std::vector<int> input_shape) {
const Tensor &t3, const Tensor &t4, framework::DDim dims = framework::make_ddim(input_shape);
const Tensor &t5) { VariableNameMap inputs;
// feed VariableNameMap outputs;
auto scope = program_.scope; auto scope = std::make_shared<framework::Scope>();
Variable *x1_feed_value = scope->Var("conv2d_5.tmp_0"); inputs["X"] = std::vector<std::string>({"input"});
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>(); inputs["Mean"] = std::vector<std::string>({"mean"});
tensor_x1->ShareDataWith(t1); inputs["Variance"] = std::vector<std::string>({"variance"});
inputs["Scale"] = std::vector<std::string>({"scale"});
Variable *mean_feed_value = scope->Var("batch_norm_10.w_1"); inputs["Bias"] = std::vector<std::string>({"bias"});
auto tensor_mean = mean_feed_value->GetMutable<LoDTensor>(); outputs["Y"] = std::vector<std::string>({"output"});
tensor_mean->ShareDataWith(t2);
auto input_var = scope.get()->Var("input");
Variable *scale_feed_value = scope->Var("batch_norm_10.w_0"); auto input = input_var->template GetMutable<framework::LoDTensor>();
auto tensor_scale = scale_feed_value->GetMutable<LoDTensor>(); SetupTensor<float>(input, dims, -100.0, 100.0);
tensor_scale->ShareDataWith(t3);
auto mean_var = scope.get()->Var("mean");
Variable *variance_feed_value = scope->Var("batch_norm_10.w_2"); auto mean = mean_var->template GetMutable<framework::LoDTensor>();
auto tensor_variance = variance_feed_value->GetMutable<LoDTensor>(); SetupTensor<float>(mean, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
tensor_variance->ShareDataWith(t4);
auto vari_var = scope.get()->Var("variance");
Variable *bias_feed_value = scope->Var("batch_norm_10.b_0"); auto vari = vari_var->template GetMutable<framework::LoDTensor>();
auto tensor_bias = bias_feed_value->GetMutable<LoDTensor>(); SetupTensor<float>(vari, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
tensor_bias->ShareDataWith(t5);
auto scale_var = scope.get()->Var("scale");
Variable *output = scope->Var("batch_norm_10.tmp_2"); auto scale = scale_var->template GetMutable<framework::LoDTensor>();
auto *output_tensor = output->GetMutable<LoDTensor>(); SetupTensor<float>(scale, framework::make_ddim({input_shape[1]}), -10.0,
output_tensor->mutable_data<float>({1, 256, 38, 38}); 10.0);
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims(); auto bias_var = scope.get()->Var("bias");
auto bias = bias_var->template GetMutable<framework::LoDTensor>();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>(); SetupTensor<float>(bias, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
out_tensor.reset(output_tensor);
auto output_var = scope.get()->Var("output");
predict_bn(t1, t2, t3, t4, t5, 0);
return out_tensor; float eps = 1e-6;
} framework::AttributeMap attrs;
attrs["epsilon"].Set<float>(eps);
private: attrs["momentum"].Set<float>(0.f);
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_; auto *op = new operators::BatchNormOp<CPU, float>("batch_norm", inputs,
std::map<framework::BlockDesc, outputs, attrs, scope);
std::vector<std::shared_ptr<OperatorBase<Dtype>>>> op->InferShape();
ops_of_block_; op->Init();
bool use_optimize_ = false;
void predict_bn(const Tensor &t1, const Tensor &t2, const Tensor &t3,
const Tensor &t4, const Tensor &t5, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run(); op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
BatchNorm(input, mean, vari, scale, bias, eps, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
} }
} }
}; }
template class TestBatchNormOp<CPU>;
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
int main() { int main() {
DLOG << "----------**********----------"; TestBatchNormOp({1, 1, 10, 10});
DLOG << "begin to run BatchNormOp Test"; TestBatchNormOp({1, 32, 100, 100});
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_mobilenet_ssd));
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {1, 256, 38, 38}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::Tensor mean;
SetupTensor<float>(&mean, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *mean_ptr = mean.data<float>();
paddle_mobile::framework::Tensor scale;
SetupTensor<float>(&scale, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *scale_ptr = scale.data<float>();
paddle_mobile::framework::Tensor variance;
SetupTensor<float>(&variance, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *variance_ptr = variance.data<float>();
paddle_mobile::framework::Tensor bias;
SetupTensor<float>(&bias, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *bias_ptr = bias.data<float>();
paddle_mobile::framework::TestBatchNormOp<paddle_mobile::CPU> testBatchNormOp(
program);
auto output_bn =
testBatchNormOp.predict_bn(inputx1, mean, scale, variance, bias);
auto *output_bn_ptr = output_bn->data<float>();
DLOG << " (" << inputx1_ptr[0] << " - " << mean_ptr[0] << ")/(("
<< variance_ptr[0] << " + 0.00001"
<< ")^0.5)* " << scale_ptr[0] << " + " << bias_ptr[0] << " = ";
DLOG << output_bn_ptr[0];
DLOG << "input_ptr 0 : " << inputx1_ptr[0];
DLOG << "output_ptr 0 : " << output_bn_ptr[0];
return 0; return 0;
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/cast_op.h"
namespace paddle_mobile {
template <typename Itype, typename Otype>
void Cast(const framework::Tensor *X, framework::Tensor *Y) {
const Itype *x = X->data<Itype>();
Otype *y = Y->mutable_data<Otype>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = static_cast<Otype>(x[i]);
}
}
template <typename T>
int TypeInt() {}
template <>
int TypeInt<bool>() {
return 0;
}
template <>
int TypeInt<int>() {
return 2;
}
template <>
int TypeInt<int64_t>() {
return 3;
}
template <>
int TypeInt<float>() {
return 5;
}
template <>
int TypeInt<double>() {
return 6;
}
template <>
int TypeInt<size_t>() {
return 19;
}
template <>
int TypeInt<uint8_t>() {
return 20;
}
template <>
int TypeInt<int8_t>() {
return 21;
}
template <typename Itype, typename Otype>
int TestCastOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, dims, static_cast<Itype>(-100),
static_cast<Itype>(100));
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["in_dtype"].Set<int>(TypeInt<Itype>());
attrs["out_dtype"].Set<int>(TypeInt<Otype>());
auto *op =
new operators::CastOp<CPU, float>("cast", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
Otype *output_cmp_data = output_cmp.mutable_data<Otype>(output->dims());
Cast<Itype, Otype>(input, &output_cmp);
const Otype *output_data = output->data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestCastOp<float, int>({1, 100});
TestCastOp<float, int>({128, 100});
TestCastOp<float, int64_t>({1, 100});
TestCastOp<float, int64_t>({128, 100});
TestCastOp<int, float>({1, 100});
TestCastOp<int, float>({128, 100});
TestCastOp<int64_t, float>({1, 100});
TestCastOp<int64_t, float>({128, 100});
return 0;
}
...@@ -103,20 +103,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -103,20 +103,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
} }
} // namespace paddle_mobile } // namespace paddle_mobile
int main(int argc, char *argv[]) { int Test(const int in_channels, const int in_height, const int in_width) {
if (argc < 4) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-pool-op in_channels in_height in_width \n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=0, stride=1"; << "float, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
...@@ -169,55 +156,75 @@ int main(int argc, char *argv[]) { ...@@ -169,55 +156,75 @@ int main(int argc, char *argv[]) {
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2"; << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=1"; << "float, pooling_type=max, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=1"; << "float, pooling_type=max, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=1"; << "float, pooling_type=max, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=1"; << "float, pooling_type=max, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=1"; << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=1"; << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=1"; << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=1"; << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=2"; << "float, pooling_type=max, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=2"; << "float, pooling_type=max, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=2"; << "float, pooling_type=max, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=2"; << "float, pooling_type=max, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=2"; << "float, pooling_type=avg, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=2"; << "float, pooling_type=avg, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=2"; << "float, pooling_type=avg, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=2"; << "float, pooling_type=avg, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
}
int main(int argc, char *argv[]) {
// if (argc < 4) {
// LOG(paddle_mobile::kLOG_INFO)
// << "Usage:\n"
// << " ./test-pool-op in_channels in_height in_width \n"
// << " params:\n"
// << " -in_channels: int, input image's channels\n"
// << " -in_height: int, input image's height\n"
// << " -in_width: int, input image's width\n";
// return 1;
// }
// int in_channels = atoi(argv[1]);
// int in_height = atoi(argv[2]);
// int in_width = atoi(argv[3]);
Test(1, 10, 10);
Test(1, 50, 50);
Test(32, 10, 10);
Test(32, 50, 50);
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/relu_op.h"
namespace paddle_mobile {
void Relu6(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
float q = x[i];
y[i] = std::min(std::max(0.f, q), 6.f);
}
}
int TestRelu6Op(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::Relu6Op<CPU, float>("relu6", inputs, outputs, attrs,
scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Relu6(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestRelu6Op({1, 1, 2, 3});
paddle_mobile::TestRelu6Op({1, 3, 11, 22});
paddle_mobile::TestRelu6Op({1, 32, 112, 112});
std::cout << "test relu6 op pass." << std::endl;
return 0;
}
...@@ -12,46 +12,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,46 +12,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h" #include "../test_include.h"
#include "operators/relu_op.h" #include "operators/relu_op.h"
int main() { namespace paddle_mobile {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_resnet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::CPU, void Relu(const framework::Tensor *X, framework::Tensor *Y) {
paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>> const float *x = X->data<float>();
executor(program, "relu"); float *y = Y->mutable_data<float>();
// 1. input_tensors; for (int i = 0; i < X->numel(); ++i) {
vector<Tensor> input_tensors; float q = x[i];
y[i] = std::max(0.f, q);
}
}
Tensor input1; int TestReluOp(const std::vector<int> input_shape) {
auto input1_data = CreateInput<float>(&input1, {1, 2, 3, 4}, -1, 1); framework::DDim dims = framework::make_ddim(input_shape);
input_tensors.push_back(input1); VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
// 2. input_names auto input_var = scope.get()->Var("input");
vector<string> input_names({ auto input = input_var->template GetMutable<framework::LoDTensor>();
"batch_norm_0.tmp_2", SetupTensor<float>(input, dims, -100.0, 100.0);
});
// 3. output_names auto output_var = scope.get()->Var("output");
vector<string> output_names({"batch_norm_0.tmp_3"});
// 4. out_dims; framework::AttributeMap attrs;
vector<DDim> out_ddims; auto *op =
auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4}); new operators::ReluOp<CPU, float>("relu", inputs, outputs, attrs, scope);
out_ddims.push_back(out_ddim); op->InferShape();
op->Init();
op->Run();
auto output = executor.Predict<LoDTensor>(input_tensors, input_names, auto output = output_var->template Get<framework::LoDTensor>();
output_names, out_ddims);
auto output0_data = output[0]->data<float>(); framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Relu(input, &output_cmp);
for (int j = 0; j < output[0]->numel(); ++j) { const float *output_data = output->data<float>();
DLOG << " value of output: " << output0_data[j]; for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
} }
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestReluOp({1, 1, 2, 3});
paddle_mobile::TestReluOp({1, 3, 11, 22});
paddle_mobile::TestReluOp({1, 32, 112, 112});
std::cout << "test relu op pass." << std::endl;
return 0; return 0;
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include "../test_include.h"
#include "operators/top_k_op.h"
namespace paddle_mobile {
void TopK(const framework::Tensor *X, framework::Tensor *Y,
framework::Tensor *Indices, const int K) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
int64_t *indices = Indices->mutable_data<int64_t>();
int dim_size = X->dims().size();
int row = 1;
int col = X->dims()[dim_size - 1];
for (int i = 0; i < dim_size - 1; ++i) {
row *= X->dims()[i];
}
std::vector<float> vec(col);
for (int i = 0; i < row; ++i) {
for (int j = 0; j < col; ++j) {
vec[j] = x[i * col + j];
}
for (int k = 0; k < K; ++k) {
float max = vec[0];
int index = 0;
for (int j = 1; j < col; ++j) {
if (vec[j] > max) {
max = vec[j];
index = j;
}
}
y[i * K + k] = max;
indices[i * K + k] = index;
vec[index] = -std::numeric_limits<float>::max();
}
}
}
int TestTopKOp(const std::vector<int> input_shape, const int K) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
outputs["Indices"] = std::vector<std::string>({"indices"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto indices_var = scope.get()->Var("indices");
framework::AttributeMap attrs;
attrs["k"].Set<int>(K);
auto *op =
new operators::TopKOp<CPU, float>("top_k", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
auto indices = indices_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp, indices_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
int64_t *indices_cmp_data =
indices_cmp.mutable_data<int64_t>(indices->dims());
TopK(input, &output_cmp, &indices_cmp, K);
// sort output
float *output_data = const_cast<float *>(output->data<float>());
int64_t *indices_data = const_cast<int64_t *>(indices->data<int64_t>());
// std::vector<std::pair<float, size_t>> vec(K);
// for (int i = 0; i < output->numel() / K; ++i) {
// for (int j = 0; j < K; ++j) {
// vec[j] = std::move(std::make_pair(output_data[i * K + j],
// indices_data[i * K + j]));
// }
// std::sort(vec.begin(), vec.end(),
// [](const std::pair<float, size_t> &l,
// const std::pair<float, size_t> &r) {
// return l.first > r.first; });
// for (int j = 0; j < K; ++j) {
// output_data[i * K + j] = vec[j].first;
// indices_data[i * K + j] = vec[j].second;
// }
// }
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
for (int i = 0; i < indices->numel(); ++i) {
if (indices_data[i] != indices_cmp_data[i]) {
LOG(kLOG_INFO) << "indices_data[" << i << "] = " << indices_data[i]
<< ", indices_cmp_data[" << i
<< "] = " << indices_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestTopKOp({1, 100}, 1);
TestTopKOp({128, 100}, 10);
TestTopKOp({128, 2, 100}, 10);
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册