未验证 提交 c6e349d7 编写于 作者: H Houjiang Chen 提交者: GitHub

Merge pull request #1387 from hjchen2/ocr_attention

Add/Refine unit tests
......@@ -235,6 +235,15 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu-op paddle-mobile)
ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu6-op paddle-mobile)
ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h)
target_link_libraries(test-topk-op paddle-mobile)
ADD_EXECUTABLE(test-cast-op operators/test_cast_op.cpp test_helper.h test_include.h)
target_link_libraries(test-cast-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fc-op paddle-mobile)
......
......@@ -17,157 +17,106 @@ limitations under the License. */
#include "operators/batchnorm_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestBatchNormOp {
public:
explicit TestBatchNormOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "batch_norm" &&
op->Input("X")[0] == "conv2d_5.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Input Mean is : " << op->Input("Mean")[0];
DLOG << " Input Variance is : " << op->Input("Variance")[0];
DLOG << " Input Scale is : " << op->Input("Scale")[0];
DLOG << " Input Bias is : " << op->Input("Bias")[0];
DLOG << " Output Y is : " << op->Output("Y")[0];
DLOG << " epsilon : " << op->GetAttrMap().at("epsilon").Get<float>();
std::shared_ptr<operators::BatchNormOp<Dtype, float>> lrn =
std::make_shared<operators::BatchNormOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(lrn);
}
void BatchNorm(const framework::Tensor *X, const framework::Tensor *Mean,
const framework::Tensor *Var, const framework::Tensor *Scale,
const framework::Tensor *Bias, const float eps,
framework::Tensor *Y) {
const float *x = X->data<float>();
const float *m = Mean->data<float>();
const float *v = Var->data<float>();
const float *s = Scale->data<float>();
const float *b = Bias->data<float>();
float *y = Y->mutable_data<float>();
int batch_size = X->dims()[0];
int channel = X->dims()[1];
int hw = X->dims()[2] * X->dims()[3];
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channel; ++c) {
float mean = m[c];
float inv_var = 1.f / std::sqrt(v[c] + eps);
float scale = s[c];
float bias = b[c];
const float *input = x + (batch * channel + c) * hw;
float *output = y + (batch * channel + c) * hw;
for (int j = 0; j < hw; ++j) {
output[j] = scale * ((input[j] - mean) * inv_var) + bias;
}
}
}
}
std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2,
const Tensor &t3, const Tensor &t4,
const Tensor &t5) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("conv2d_5.tmp_0");
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
tensor_x1->ShareDataWith(t1);
Variable *mean_feed_value = scope->Var("batch_norm_10.w_1");
auto tensor_mean = mean_feed_value->GetMutable<LoDTensor>();
tensor_mean->ShareDataWith(t2);
Variable *scale_feed_value = scope->Var("batch_norm_10.w_0");
auto tensor_scale = scale_feed_value->GetMutable<LoDTensor>();
tensor_scale->ShareDataWith(t3);
Variable *variance_feed_value = scope->Var("batch_norm_10.w_2");
auto tensor_variance = variance_feed_value->GetMutable<LoDTensor>();
tensor_variance->ShareDataWith(t4);
Variable *bias_feed_value = scope->Var("batch_norm_10.b_0");
auto tensor_bias = bias_feed_value->GetMutable<LoDTensor>();
tensor_bias->ShareDataWith(t5);
Variable *output = scope->Var("batch_norm_10.tmp_2");
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({1, 256, 38, 38});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict_bn(t1, t2, t3, t4, t5, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_bn(const Tensor &t1, const Tensor &t2, const Tensor &t3,
const Tensor &t4, const Tensor &t5, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
int TestBatchNormOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
inputs["Mean"] = std::vector<std::string>({"mean"});
inputs["Variance"] = std::vector<std::string>({"variance"});
inputs["Scale"] = std::vector<std::string>({"scale"});
inputs["Bias"] = std::vector<std::string>({"bias"});
outputs["Y"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto mean_var = scope.get()->Var("mean");
auto mean = mean_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(mean, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
auto vari_var = scope.get()->Var("variance");
auto vari = vari_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(vari, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
auto scale_var = scope.get()->Var("scale");
auto scale = scale_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(scale, framework::make_ddim({input_shape[1]}), -10.0,
10.0);
auto bias_var = scope.get()->Var("bias");
auto bias = bias_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(bias, framework::make_ddim({input_shape[1]}), -10.0, 10.0);
auto output_var = scope.get()->Var("output");
float eps = 1e-6;
framework::AttributeMap attrs;
attrs["epsilon"].Set<float>(eps);
attrs["momentum"].Set<float>(0.f);
auto *op = new operators::BatchNormOp<CPU, float>("batch_norm", inputs,
outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
BatchNorm(input, mean, vari, scale, bias, eps, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
};
}
template class TestBatchNormOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run BatchNormOp Test";
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_mobilenet_ssd));
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {1, 256, 38, 38}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::Tensor mean;
SetupTensor<float>(&mean, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *mean_ptr = mean.data<float>();
paddle_mobile::framework::Tensor scale;
SetupTensor<float>(&scale, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *scale_ptr = scale.data<float>();
paddle_mobile::framework::Tensor variance;
SetupTensor<float>(&variance, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *variance_ptr = variance.data<float>();
paddle_mobile::framework::Tensor bias;
SetupTensor<float>(&bias, {256}, static_cast<float>(0),
static_cast<float>(1));
auto *bias_ptr = bias.data<float>();
paddle_mobile::framework::TestBatchNormOp<paddle_mobile::CPU> testBatchNormOp(
program);
auto output_bn =
testBatchNormOp.predict_bn(inputx1, mean, scale, variance, bias);
auto *output_bn_ptr = output_bn->data<float>();
DLOG << " (" << inputx1_ptr[0] << " - " << mean_ptr[0] << ")/(("
<< variance_ptr[0] << " + 0.00001"
<< ")^0.5)* " << scale_ptr[0] << " + " << bias_ptr[0] << " = ";
DLOG << output_bn_ptr[0];
DLOG << "input_ptr 0 : " << inputx1_ptr[0];
DLOG << "output_ptr 0 : " << output_bn_ptr[0];
TestBatchNormOp({1, 1, 10, 10});
TestBatchNormOp({1, 32, 100, 100});
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/cast_op.h"
namespace paddle_mobile {
template <typename Itype, typename Otype>
void Cast(const framework::Tensor *X, framework::Tensor *Y) {
const Itype *x = X->data<Itype>();
Otype *y = Y->mutable_data<Otype>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = static_cast<Otype>(x[i]);
}
}
template <typename T>
int TypeInt() {}
template <>
int TypeInt<bool>() {
return 0;
}
template <>
int TypeInt<int>() {
return 2;
}
template <>
int TypeInt<int64_t>() {
return 3;
}
template <>
int TypeInt<float>() {
return 5;
}
template <>
int TypeInt<double>() {
return 6;
}
template <>
int TypeInt<size_t>() {
return 19;
}
template <>
int TypeInt<uint8_t>() {
return 20;
}
template <>
int TypeInt<int8_t>() {
return 21;
}
template <typename Itype, typename Otype>
int TestCastOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, dims, static_cast<Itype>(-100),
static_cast<Itype>(100));
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["in_dtype"].Set<int>(TypeInt<Itype>());
attrs["out_dtype"].Set<int>(TypeInt<Otype>());
auto *op =
new operators::CastOp<CPU, float>("cast", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
Otype *output_cmp_data = output_cmp.mutable_data<Otype>(output->dims());
Cast<Itype, Otype>(input, &output_cmp);
const Otype *output_data = output->data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestCastOp<float, int>({1, 100});
TestCastOp<float, int>({128, 100});
TestCastOp<float, int64_t>({1, 100});
TestCastOp<float, int64_t>({128, 100});
TestCastOp<int, float>({1, 100});
TestCastOp<int, float>({128, 100});
TestCastOp<int64_t, float>({1, 100});
TestCastOp<int64_t, float>({128, 100});
return 0;
}
......@@ -103,20 +103,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
if (argc < 4) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-pool-op in_channels in_height in_width \n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
int Test(const int in_channels, const int in_height, const int in_width) {
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
......@@ -169,55 +156,75 @@ int main(int argc, char *argv[]) {
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=0, stride=1";
paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=1, stride=1";
paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=2, stride=1";
paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=5, stride=1";
paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=0, stride=1";
paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=1, stride=1";
paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=2, stride=1";
paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=5, stride=1";
paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=0, stride=2";
paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=1, stride=2";
paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=2, stride=2";
paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=5, stride=2";
paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=0, stride=2";
paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=1, stride=2";
paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=2, stride=2";
paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
}
int main(int argc, char *argv[]) {
// if (argc < 4) {
// LOG(paddle_mobile::kLOG_INFO)
// << "Usage:\n"
// << " ./test-pool-op in_channels in_height in_width \n"
// << " params:\n"
// << " -in_channels: int, input image's channels\n"
// << " -in_height: int, input image's height\n"
// << " -in_width: int, input image's width\n";
// return 1;
// }
// int in_channels = atoi(argv[1]);
// int in_height = atoi(argv[2]);
// int in_width = atoi(argv[3]);
Test(1, 10, 10);
Test(1, 50, 50);
Test(32, 10, 10);
Test(32, 50, 50);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/relu_op.h"
namespace paddle_mobile {
void Relu6(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
float q = x[i];
y[i] = std::min(std::max(0.f, q), 6.f);
}
}
int TestRelu6Op(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::Relu6Op<CPU, float>("relu6", inputs, outputs, attrs,
scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Relu6(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestRelu6Op({1, 1, 2, 3});
paddle_mobile::TestRelu6Op({1, 3, 11, 22});
paddle_mobile::TestRelu6Op({1, 32, 112, 112});
std::cout << "test relu6 op pass." << std::endl;
return 0;
}
......@@ -12,46 +12,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/relu_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_resnet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
namespace paddle_mobile {
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>>
executor(program, "relu");
void Relu(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
// 1. input_tensors;
vector<Tensor> input_tensors;
for (int i = 0; i < X->numel(); ++i) {
float q = x[i];
y[i] = std::max(0.f, q);
}
}
Tensor input1;
auto input1_data = CreateInput<float>(&input1, {1, 2, 3, 4}, -1, 1);
input_tensors.push_back(input1);
int TestReluOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
// 2. input_names
vector<string> input_names({
"batch_norm_0.tmp_2",
});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
// 3. output_names
vector<string> output_names({"batch_norm_0.tmp_3"});
auto output_var = scope.get()->Var("output");
// 4. out_dims;
vector<DDim> out_ddims;
auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4});
out_ddims.push_back(out_ddim);
framework::AttributeMap attrs;
auto *op =
new operators::ReluOp<CPU, float>("relu", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
output_names, out_ddims);
auto output = output_var->template Get<framework::LoDTensor>();
auto output0_data = output[0]->data<float>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Relu(input, &output_cmp);
for (int j = 0; j < output[0]->numel(); ++j) {
DLOG << " value of output: " << output0_data[j];
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestReluOp({1, 1, 2, 3});
paddle_mobile::TestReluOp({1, 3, 11, 22});
paddle_mobile::TestReluOp({1, 32, 112, 112});
std::cout << "test relu op pass." << std::endl;
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include "../test_include.h"
#include "operators/top_k_op.h"
namespace paddle_mobile {
void TopK(const framework::Tensor *X, framework::Tensor *Y,
framework::Tensor *Indices, const int K) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
int64_t *indices = Indices->mutable_data<int64_t>();
int dim_size = X->dims().size();
int row = 1;
int col = X->dims()[dim_size - 1];
for (int i = 0; i < dim_size - 1; ++i) {
row *= X->dims()[i];
}
std::vector<float> vec(col);
for (int i = 0; i < row; ++i) {
for (int j = 0; j < col; ++j) {
vec[j] = x[i * col + j];
}
for (int k = 0; k < K; ++k) {
float max = vec[0];
int index = 0;
for (int j = 1; j < col; ++j) {
if (vec[j] > max) {
max = vec[j];
index = j;
}
}
y[i * K + k] = max;
indices[i * K + k] = index;
vec[index] = -std::numeric_limits<float>::max();
}
}
}
int TestTopKOp(const std::vector<int> input_shape, const int K) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
outputs["Indices"] = std::vector<std::string>({"indices"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto indices_var = scope.get()->Var("indices");
framework::AttributeMap attrs;
attrs["k"].Set<int>(K);
auto *op =
new operators::TopKOp<CPU, float>("top_k", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
auto indices = indices_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp, indices_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
int64_t *indices_cmp_data =
indices_cmp.mutable_data<int64_t>(indices->dims());
TopK(input, &output_cmp, &indices_cmp, K);
// sort output
float *output_data = const_cast<float *>(output->data<float>());
int64_t *indices_data = const_cast<int64_t *>(indices->data<int64_t>());
// std::vector<std::pair<float, size_t>> vec(K);
// for (int i = 0; i < output->numel() / K; ++i) {
// for (int j = 0; j < K; ++j) {
// vec[j] = std::move(std::make_pair(output_data[i * K + j],
// indices_data[i * K + j]));
// }
// std::sort(vec.begin(), vec.end(),
// [](const std::pair<float, size_t> &l,
// const std::pair<float, size_t> &r) {
// return l.first > r.first; });
// for (int j = 0; j < K; ++j) {
// output_data[i * K + j] = vec[j].first;
// indices_data[i * K + j] = vec[j].second;
// }
// }
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
for (int i = 0; i < indices->numel(); ++i) {
if (indices_data[i] != indices_cmp_data[i]) {
LOG(kLOG_INFO) << "indices_data[" << i << "] = " << indices_data[i]
<< ", indices_cmp_data[" << i
<< "] = " << indices_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestTopKOp({1, 100}, 1);
TestTopKOp({128, 100}, 10);
TestTopKOp({128, 2, 100}, 10);
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册