Commit 5ac44d9b authored by Ruilong Liu, committed by GitHub

Merge branch 'develop' into develop

/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "batchnorm_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void BatchNormOp<Dtype, T>::InferShape() const {
auto x_dims = param_.InputX()->dims();
param_.OutputY()->Resize(x_dims);
}
template class BatchNormOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "framework/operator.h"
#include "operators/kernel/batchnorm_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class BatchNormOp : public framework::OperatorWithKernel<DeviceType> {
public:
BatchNormOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs,
attrs, scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::BatchNormKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
BatchNormParam param_;
};
} // namespace operators
} // namespace paddle_mobile
@@ -16,68 +16,49 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
-#include "elementwise_add_op_test.h"
-#include "framework/executor.h"
-#include "io.h"
-#include "mul_op_test.h"
-#include "test_helper.h"
-//
-// template <typename T>
-// void SetupTensor(paddle::framework::LoDTensor* input,
-//                  paddle::framework::DDim dims, T lower, T upper) {
-//   static unsigned int seed = 100;
-//   std::mt19937 rng(seed++);
-//   std::uniform_real_distribution<double> uniform_dist(0, 1);
-//
-//   T* input_ptr = input->mutable_data<T>(dims, paddle::platform::CPUPlace());
-//   for (int i = 0; i < input->numel(); ++i) {
-//     input_ptr[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
-//   }
-// }
-int main() {
-  std::string data_set = "cifar10";
-  // if (data_set == "cifar10") {
-  //   SetupTensor<float>(&input, {FLAGS_batch_size, 3, 32, 32},
-  //                      static_cast<float>(0), static_cast<float>(1));
-  // } else if (data_set == "imagenet") {
-  //   SetupTensor<float>(&input, {FLAGS_batch_size, 3, 224, 224},
-  //                      static_cast<float>(0), static_cast<float>(1));
-  // } else {
-  //   LOG(FATAL) << "Only cifar10 or imagenet is supported.";
-  // }
-  paddle_mobile::Loader<paddle_mobile::CPU> loader;
-  auto program = loader.Load(std::string(
-      "../../../test/models/image_classification_resnet.inference.model"));
-  paddle_mobile::framework::Executor<paddle_mobile::CPU> executor(program);
-  paddle_mobile::framework::Tensor input;
-  SetupTensor<float>(&input, {1, 3, 32, 32}, static_cast<float>(0),
-                     static_cast<float>(1));
-  float *input_ptr = input.data<float>();
-  for (int i = 0; i < input.numel(); ++i) {
-    // std::cout << input_ptr[i] << std::endl;
-  }
-  // std::cout << "input: " << input.memory_size() << std::endl;
-  // std::cout << "input: " << input.numel() << std::endl;
-  auto output = executor.predict(input);
-  // std::cout << "output: " << output->memory_size() << std::endl;
-  // std::cout << "output: " << output->numel() << std::endl;
-  // float* output_ptr = output->data<float>();
-  // for (int j = 0; j < output->numel(); ++j) {
-  //   std::cout << " value of output: " << output_ptr[j] << std::endl;
-  // }
-  paddle_mobile::test::testElementwiseAdd();
-  paddle_mobile::test::testMul();
-  return 0;
-}
+#include "concat_op.h"
+namespace paddle_mobile {
+namespace operators {
+template <typename Dtype, typename T>
+void ConcatOp<Dtype, T>::InferShape() const {
+  auto inputs = param_.Inputs();
+  const size_t n = inputs.size();
+  std::vector<DDim> inputs_dims;
+  inputs_dims.reserve(n);
+  for (int i = 0; i < n; i++) {
+    inputs_dims.push_back(inputs[i]->dims());
+  }
+  auto axis = static_cast<size_t>(param_.Axis());
+  if (n == 1) {
+    DLOG << "Warning: concat op has only one input, may waste memory";
+  }
+  /// Sum dim[axis] over all inputs and check that every other dim matches.
+  auto out_dims = inputs_dims[0];
+  int in_zero_dims_size = out_dims.size();
+  for (size_t i = 1; i < n; i++) {
+    for (size_t j = 0; j < in_zero_dims_size; j++) {
+      if (j == axis) {
+        out_dims[axis] += inputs_dims[i][j];
+      } else {
+        assert(out_dims[j] == inputs_dims[i][j]);
+      }
+    }
+  }
+  if (out_dims[axis] < 0) {
+    out_dims[axis] = -1;
+  }
+  param_.Out()->Resize(out_dims);
+}
+template class ConcatOp<CPU, float>;
+} // namespace operators
+} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "framework/operator.h"
#include "operators/kernel/concat_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class ConcatOp : public framework::OperatorWithKernel<DeviceType> {
public:
ConcatOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs,
attrs, scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::ConcatKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
ConcatParam param_;
};
} // namespace operators
} // namespace paddle_mobile
@@ -37,8 +37,7 @@ class ElementwiseAddOp : public framework::OperatorWithKernel<DeviceType> {
      param_(inputs, outputs, attrs, *scope) {}
  void Run() const {
-    operators::ElementwiseAddKernel<DeviceType, T, ElementwiseAddParam>
-        kernel;
+    operators::ElementwiseAddKernel<DeviceType, T> kernel;
    kernel.Compute(param_);
  }
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/kernel/batchnorm_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
/// todo: test.
const Tensor *input_x = param.InputX();
auto input_x_ptr = input_x->data<float>();
const auto &x_dims = input_x->dims();
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int stride0 = C * H * W;
const int stride1 = H * W;
const int stride2 = W;
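// NCHW layout: stride0 steps between samples, stride1 between channels,
// stride2 between rows; w indexes the contiguous innermost dimension.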
Tensor *out = param.OutputY();
auto out_ptr = out->mutable_data<float>();
const float epsilon = param.Epsilon();
const Tensor *mean = param.InputMean();
const Tensor *variance = param.InputVariance();
const Tensor *scale = param.InputScale();
const Tensor *bias = param.InputBias();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
Tensor inv_std;
auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
if (C != variance->numel()) {
std::cout << "C must equal to variance.numel()" << std::endl;
}
assert(C == variance->numel());
/// std = (var + epsilon).sqrt();
/// inv_std = 1 / std;
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr = new_scale.mutable_data<float>(make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] =
bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) {
int index = n * stride0 + i * stride1 + h * stride2 + w;
out_ptr[index] = input_x_ptr[index] * new_scale_ptr[i] +
new_bias_ptr[i];
}
}
}
}
}
DLOG << "input[2,5,1,0](input[102]) ,channel 5 :";
DLOG << "input_x_ptr : " << input_x_ptr[102];
DLOG << "variance : " << variance_ptr[5];
DLOG << "inv_std_ptr : " << inv_std_ptr[5];
DLOG << "new_scale_ptr : " << new_scale_ptr[5];
DLOG << "new_bias_ptr : " << new_bias_ptr[5];
DLOG << "out_ptr : " << out_ptr[102];
}
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/kernel/concat_kernel.h"
namespace paddle_mobile {
namespace operators {
template <typename T> class ConcatFunctor {
public:
void operator()(const std::vector<framework::Tensor> &input, const int axis,
framework::Tensor *output) {
size_t num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
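// Each input is now viewed as a (rows x cols) matrix: `rows` is the
// product of the dims before `axis`, and everything from `axis` onward
// collapses into columns, so concat reduces to row-wise block copies.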
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
for (int k = 0; k < out_rows; ++k) {
T *dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const T *src_ptr = input[j].data<T>() + k * col_len;
memory::Copy(dst_ptr + col_idx, src_ptr, sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
template <typename T>
void StridedNumelCopyWithAxis(int64_t axis, T *dst,
const framework::DDim &dst_stride_numel,
const T *src,
const framework::DDim &src_stride_numel,
int64_t size) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
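// `before` is the number of contiguous slices preceding `axis`;
// `src_after`/`dst_after` count the elements from `axis` onward, i.e.
// the stride between consecutive slices in src and dst.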
///"src and dst tensor should have the same dims size."
assert(src_stride_numel.size() == dst_stride_numel.size());
// Iterate over every dim (not just those before axis) so the i == axis
// and i > axis branches below are actually reached.
for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
if (i < axis) {
/// src and dst should have the same elements
/// except the specified axis.
assert(src_stride_numel[i] / src_stride_numel[axis] ==
dst_stride_numel[i] / dst_stride_numel[axis]);
} else if (i == axis) {
continue;
} else {
/// "src and dst should have the same elements "
/// "except the specified axis."
assert(src_stride_numel[i] == dst_stride_numel[i]);
}
}
for (int64_t i = 0; i < before; ++i) {
memory::Copy(dst + i * dst_after, src + i * src_after,
sizeof(T) * size);
}
}
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
auto inputs = param.Inputs();
auto *out = param.Out();
int64_t axis = param.Axis();
out->mutable_data<float>();
/// Sometimes a direct copy is faster; this may need deeper analysis.
if (axis == 0 && inputs.size() < 10) {
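// Fast path: with axis == 0 each input occupies one contiguous block,
// so it can be copied whole at its running offset into the output.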
size_t output_offset = 0;
for (auto *in : inputs) {
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<float>(
axis, out->data<float>() + output_offset, out_stride,
in->data<float>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<framework::Tensor> inputs_concat(inputs.size());
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = *inputs[j];
}
ConcatFunctor<float> concat_functor;
concat_functor(inputs_concat, static_cast<int>(axis), out);
}
}
} // namespace operators
} // namespace paddle_mobile
@@ -24,7 +24,7 @@ template <typename T> struct AddFunctor {
};
template <>
-void ElementwiseAddKernel<CPU, float, ElementwiseAddParam>::Compute(
+void ElementwiseAddKernel<CPU, float>::Compute(
    const ElementwiseAddParam &param) const {
  const Tensor *input_x = param.InputX();
  const Tensor *input_y = param.InputY();
@@ -35,7 +35,7 @@ void ElementwiseAddKernel<CPU, float, ElementwiseAddParam>::Compute(
      AddFunctor<float>(), Out);
}
-template class ElementwiseAddKernel<CPU, float, ElementwiseAddParam>;
+template class ElementwiseAddKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "operators/kernel/lrn_kernel.h"
namespace paddle_mobile {
namespace operators {
template <> void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims();
/// data_format = NCHW
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
Tensor *out = param.Out();
out->mutable_data<float>();
const int n = param.N();
const float alpha = param.Alpha();
const float beta = param.Beta();
const float k = param.K();
LRNFunctor<float> lrnFunctor;
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
}
template class LrnKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
@@ -23,8 +23,7 @@ SOFTWARE.
namespace paddle_mobile {
namespace operators {
-template <>
-void MulKernel<CPU, float, MulParam>::Compute(const MulParam &param) const {
+template <> void MulKernel<CPU, float>::Compute(const MulParam &param) const {
  const Tensor *input_x = param.InputX();
  const Tensor *input_y = param.InputY();
  Tensor *out = param.Out();
@@ -48,7 +47,7 @@ void MulKernel<CPU, float, MulParam>::Compute(const MulParam &param) const {
  }
}
-template class MulKernel<CPU, float, MulParam>;
+template class MulKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> {
public:
void Compute(const BatchNormParam &param) const;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public:
void Compute(const ConcatParam &param) const;
};
} // namespace operators
} // namespace paddle_mobile
@@ -26,7 +26,7 @@ namespace operators {
using namespace framework;
-template <typename DeviceType, typename T, typename P>
+template <typename DeviceType, typename T>
class ElementwiseAddKernel
    : public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public:
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename T> struct LRNFunctor {
void operator()(const framework::Tensor &input, framework::Tensor *out,
int N, int C, int H, int W, int n, T k, T alpha, T beta) {
auto input_ptr = input.data<T>();
const int start = -(n - 1) / 2;
const int end = start + n;
const int stride0 = C * H * W;
const int stride1 = H * W;
const int stride2 = W;
const int stride3 = 1;
framework::Tensor sqr_buffer;
auto sqr_buffer_ptr = sqr_buffer.mutable_data<T>(input.dims());
std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), k);
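// Across-channel LRN: sqr_buffer accumulates
//   k + alpha * sum over the window [b + start, b + end) of x^2,
// and the output below is x / sqr_buffer^beta.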
for (int a = 0; a < N; a++) {
for (int b = 0; b < C; b++) {
for (int index = start; index < end; index++) {
int channel = b + index;
if (channel >= 0 && channel < C) {
for (int c = 0; c < H; c++) {
for (int d = 0; d < W; d++) {
int u =
a * stride0 + b * stride1 + c * stride2 + d;
int i = a * stride0 + channel * stride1 +
c * stride2 + d;
sqr_buffer_ptr[u] +=
alpha * input_ptr[i] * input_ptr[i];
}
}
}
}
}
}
auto out_ptr = out->data<T>();
for (int i = 0; i < input.numel(); i++) {
out_ptr[i] = input_ptr[i] / pow(sqr_buffer_ptr[i], beta);
}
}
};
template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public:
void Compute(const LrnParam &param) const;
};
} // namespace operators
} // namespace paddle_mobile
@@ -26,7 +26,7 @@ namespace operators {
using namespace framework;
-template <typename DeviceType, typename T, typename P>
+template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public:
  void Compute(const MulParam &param) const;
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "lrn_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T> void LrnOp<Dtype, T>::InferShape() const {
auto x_dims = param_.InputX()->dims();
param_.Out()->Resize(x_dims);
}
template class LrnOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "framework/operator.h"
#include "operators/kernel/lrn_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class LrnOp : public framework::OperatorWithKernel<DeviceType> {
public:
LrnOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs,
attrs, scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::LrnKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
LrnParam param_;
};
} // namespace operators
} // namespace paddle_mobile
@@ -36,7 +36,7 @@ class MulOp : public framework::OperatorWithKernel<DeviceType> {
      param_(inputs, outputs, attrs, *scope) {}
  void Run() const {
-    operators::MulKernel<DeviceType, T, MulParam> kernel;
+    operators::MulKernel<DeviceType, T> kernel;
    kernel.Compute(param_);
  }
@@ -48,10 +48,29 @@ class OpParam : PaddleMobileObject {
return GetVarValue<T>("Y", inputs, scope);
}
template <typename T>
static T *InputBiasFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Bias", inputs, scope);
}
template <typename T>
static T *InputVarianceFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetVarValue<T>("Variance", inputs, scope);
}
template <typename T>
static T *InputMeanFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Mean", inputs, scope);
}
template <typename T>
static T *InputScaleFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetVarValue<T>("Scale", inputs, scope);
}
template <typename T>
static std::vector<T *> InputMultiFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetMultiVarValue<T>("Input", inputs, scope);
return GetMultiVarValue<T>("X", inputs, scope);
}
template <typename T>
@@ -64,21 +83,31 @@ class OpParam : PaddleMobileObject {
return GetVarValue<T>("Out", outputs, scope);
}
template <typename T>
static T *OutputYFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("Y", outputs, scope);
}
template <typename T>
static T *MidOutFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("MidOut", outputs, scope);
}
template <typename T>
static T *FilterFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Filter", inputs, scope);
}
template <typename T>
-    static const T GetAttr(std::string key, const AttributeMap &map) {
+    static const T GetAttr(const std::string &key, const AttributeMap &map) {
return ((Attribute)map.at(key)).Get<T>();
}
template <typename T>
-    static T *GetVarValue(std::string key, const VariableNameMap &var_map,
-                          const Scope &scope) {
+    static T *GetVarValue(const std::string &key,
+                          const VariableNameMap &var_map, const Scope &scope) {
auto var_vec = var_map.at(key);
-        if (var_vec.size()) {
+        if (!var_vec.empty()) {
// std::cout << " get var value -- " << var_vec[0] <<
// std::endl;
auto var = scope.FindVar(var_vec[0]);
@@ -89,7 +118,7 @@ class OpParam : PaddleMobileObject {
}
template <typename T>
-    static std::vector<T *> GetMultiVarValue(std::string key,
+    static std::vector<T *> GetMultiVarValue(const std::string &key,
const VariableNameMap &var_map,
const Scope &scope) {
auto var_vecs = var_map.at(key);
@@ -222,5 +251,95 @@ class ConcatParam : public OpParam {
int axis_;
};
class LrnParam : public OpParam {
public:
LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
mid_out_ = MidOutFrom<framework::Tensor>(outputs, scope);
n_ = GetAttr<int>("n", attrs);
alpha_ = GetAttr<float>("alpha", attrs);
beta_ = GetAttr<float>("beta", attrs);
k_ = GetAttr<float>("k", attrs);
data_format_ = GetAttr<std::string>("data_format", attrs);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
Tensor *MidOut() const { return mid_out_; }
const int &N() const { return n_; }
const float &Alpha() const { return alpha_; }
const float &Beta() const { return beta_; }
const float &K() const { return k_; }
const std::string &DataFormat() const { return data_format_; }
private:
Tensor *input_x_;
Tensor *out_;
Tensor *mid_out_;
int n_;
float alpha_;
float beta_;
float k_;
std::string data_format_;
};
class BatchNormParam : OpParam {
public:
BatchNormParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
output_y_ = OutputYFrom<framework::Tensor>(outputs, scope);
input_bias_ = InputBiasFrom<framework::Tensor>(inputs, scope);
input_mean_ = InputMeanFrom<framework::Tensor>(inputs, scope);
input_scale_ = InputScaleFrom<framework::Tensor>(inputs, scope);
input_variance_ = InputVarianceFrom<framework::Tensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs);
}
const Tensor *InputX() const { return input_x_; }
Tensor *OutputY() const { return output_y_; }
const Tensor *InputBias() const { return input_bias_; }
const Tensor *InputMean() const { return input_mean_; }
const Tensor *InputScale() const { return input_scale_; }
const Tensor *InputVariance() const { return input_variance_; }
const float &Epsilon() const { return epsilon_; }
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
const std::string &DataFormat() const { return data_format_; }
private:
Tensor *input_x_;
Tensor *output_y_;
Tensor *input_bias_;
Tensor *input_mean_;
Tensor *input_scale_;
Tensor *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
std::string data_format_;
};
} // namespace operators
} // namespace paddle_mobile
# gen test
ADD_EXECUTABLE(paddle-mobile-test main.cpp test_helper.h elementwise_add_op_test.h test_include.h mul_op_test.h)
target_link_libraries(paddle-mobile-test paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h)
target_link_libraries(test-conv-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mul-op operators/test_mul_op.cpp test_helper.h test_include.h)
target_link_libraries(test-mul-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-elementwiseadd-op operators/test_elementwise_add_op.cpp test_helper.h test_include.h)
target_link_libraries(test-elementwiseadd-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h)
target_link_libraries(test-concat-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-lrn-op operators/test_lrn_op.cpp test_helper.h test_include.h)
target_link_libraries(test-lrn-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-batchnorm-op operators/test_batchnorm_op.cpp test_helper.h test_include.h)
target_link_libraries(test-batchnorm-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "../test_include.h"
#include "operators/batchnorm_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class TestBatchNormOp {
public:
explicit TestBatchNormOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "batch_norm" &&
op->Input("X")[0] == "conv2d_0.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Input Mean is : " << op->Input("Mean")[0];
DLOG << " Input Variance is : " << op->Input("Variance")[0];
DLOG << " Input Scale is : " << op->Input("Scale")[0];
DLOG << " Input Bias is : " << op->Input("Bias")[0];
DLOG << " Output Y is : " << op->Output("Y")[0];
DLOG << " epsilon : "
<< op->GetAttrMap().at("epsilon").Get<float>();
std::shared_ptr<operators::BatchNormOp<Dtype, float>> batch_norm =
    std::make_shared<operators::BatchNormOp<Dtype, float>>(
        op->Type(), op->GetInputs(), op->GetOutputs(),
        op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(batch_norm);
}
}
}
}
std::shared_ptr<Tensor> predict_bn(Tensor &t1, Tensor &t2, Tensor &t3,
Tensor &t4, Tensor &t5) {
// feed
auto scope = program_.scope;
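// Variable names follow this model's convention for batch_norm_0:
// w_0 = scale, w_1 = mean, w_2 = variance, b_0 = bias.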
Variable *x1_feed_value = scope->Var("conv2d_0.tmp_0");
auto tensor_x1 = x1_feed_value->GetMutable<Tensor>();
tensor_x1->ShareDataWith(t1);
Variable *mean_feed_value = scope->Var("batch_norm_0.w_1");
auto tensor_mean = mean_feed_value->GetMutable<Tensor>();
tensor_mean->ShareDataWith(t2);
Variable *scale_feed_value = scope->Var("batch_norm_0.w_0");
auto tensor_scale = scale_feed_value->GetMutable<Tensor>();
tensor_scale->ShareDataWith(t3);
Variable *variance_feed_value = scope->Var("batch_norm_0.w_2");
auto tensor_variance = variance_feed_value->GetMutable<Tensor>();
tensor_variance->ShareDataWith(t4);
Variable *bias_feed_value = scope->Var("batch_norm_0.b_0");
auto tensor_bias = bias_feed_value->GetMutable<Tensor>();
tensor_bias->ShareDataWith(t5);
Variable *output = scope->Var("batch_norm_0.tmp_2");
auto *output_tensor = output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({4, 10, 2, 2});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict_bn(t1, t2, t3, t4, t5, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_bn(const Tensor &t1, const Tensor &t2, const Tensor &t3,
const Tensor &t4, const Tensor &t5, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size();
++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestBatchNormOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run BatchNormOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(
"../../test/models/image_classification_resnet.inference.model"));
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {4, 10, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::Tensor mean;
SetupTensor<float>(&mean, {10}, static_cast<float>(0),
static_cast<float>(1));
auto *mean_ptr = mean.data<float>();
paddle_mobile::framework::Tensor scale;
SetupTensor<float>(&scale, {10}, static_cast<float>(0),
static_cast<float>(1));
auto *scale_ptr = scale.data<float>();
paddle_mobile::framework::Tensor variance;
SetupTensor<float>(&variance, {10}, static_cast<float>(0),
static_cast<float>(1));
auto *variance_ptr = variance.data<float>();
paddle_mobile::framework::Tensor bias;
SetupTensor<float>(&bias, {10}, static_cast<float>(0),
static_cast<float>(1));
auto *bias_ptr = bias.data<float>();
paddle_mobile::framework::TestBatchNormOp<paddle_mobile::CPU>
testBatchNormOp(program);
auto output_bn =
testBatchNormOp.predict_bn(inputx1, mean, scale, variance, bias);
auto *output_bn_ptr = output_bn->data<float>();
/// [2, 5, 1, 0]
DLOG << " (" << inputx1_ptr[102] << " - " << mean_ptr[5] << ")/(("
<< variance_ptr[5] << " + 0.00001"
<< ")^0.5)* " << scale_ptr[5] << " + " << bias_ptr[5] << " = ";
DLOG << output_bn_ptr[102];
return 0;
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "../test_include.h"
#include "operators/concat_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class TestConcatOp {
public:
explicit TestConcatOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "concat" &&
op->Input("X")[0] == "conv2d_3.tmp_1") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Output Out is : " << op->Output("Out")[0];
DLOG << " axis : "
<< op->GetAttrMap().at("axis").Get<int>();
std::shared_ptr<operators::ConcatOp<Dtype, float>> concat =
std::make_shared<operators::ConcatOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(concat);
}
}
}
}
std::shared_ptr<Tensor> predict_concat(Tensor &t1, Tensor &t2, Tensor &t3,
Tensor &t4) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("conv2d_3.tmp_1");
auto tensor_x1 = x1_feed_value->GetMutable<Tensor>();
tensor_x1->ShareDataWith(t1);
Variable *x2_feed_value = scope->Var("conv2d_5.tmp_1");
auto tensor_x2 = x2_feed_value->GetMutable<Tensor>();
tensor_x2->ShareDataWith(t2);
Variable *x3_feed_value = scope->Var("conv2d_7.tmp_1");
auto tensor_x3 = x3_feed_value->GetMutable<Tensor>();
tensor_x3->ShareDataWith(t3);
Variable *x4_feed_value = scope->Var("conv2d_8.tmp_1");
auto tensor_x4 = x4_feed_value->GetMutable<Tensor>();
tensor_x4->ShareDataWith(t4);
Variable *con_output = scope->Var("concat_0.tmp_0");
auto *output_tensor = con_output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({4, 100, 2, 2});
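// Concat along axis 1: 10 + 20 + 30 + 40 = 100 channels, matching
// the {4, 100, 2, 2} output shape above.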
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict_concat(t1, t2, t3, t4, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_concat(const Tensor &t1, const Tensor &t2, const Tensor &t3,
const Tensor &t4, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size();
++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestConcatOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run ConcatOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/googlenet"));
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {4, 10, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
/// input x (4,20,2,2)
paddle_mobile::framework::Tensor inputx2;
SetupTensor<float>(&inputx2, {4, 20, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx2_ptr = inputx2.data<float>();
/// input x (4,30,2,2)
paddle_mobile::framework::Tensor inputx3;
SetupTensor<float>(&inputx3, {4, 30, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx3_ptr = inputx3.data<float>();
/// input x (4,40,2,2)
paddle_mobile::framework::Tensor inputx4;
SetupTensor<float>(&inputx4, {4, 40, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx4_ptr = inputx4.data<float>();
paddle_mobile::framework::TestConcatOp<paddle_mobile::CPU> testConcatOp(
program);
auto output_concat =
testConcatOp.predict_concat(inputx1, inputx2, inputx3, inputx4);
auto *output_concat_ptr = output_concat->data<float>();
int input_n = 1;
int input_c = 2;
int input_h = 0;
int input_w = 1;
int stride0 = inputx3.numel() / inputx3.dims()[0];
int stride1 = inputx3.numel() / inputx3.dims()[0] / inputx3.dims()[1];
int stride2 = inputx3.dims()[3];
/// inputx1 (4,10,2,2),
/// inputx2 (4,20,2,2),
/// inputx3 (4,30,2,2),
/// inputx4 (4,40,2,2),
/// axis = 1
/// output (4,100,2,2)
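/// With axis = 1, inputx3's element [1,2,0,1] lands at output channel
/// 2 + 10 + 20 = 32, i.e. output[1,32,0,1]; both indices are computed below.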
int input_index =
input_n * stride0 + input_c * stride1 + input_h * stride2 + input_w;
int output_index =
input_n * 100 * 2 * 2 +
(input_c + inputx1.dims()[1] + inputx2.dims()[1]) * 2 * 2 +
input_h * 2 + input_w;
DLOG << " inputx3[1,2,0,1] = " << inputx3_ptr[input_index];
DLOG << " output[1,12,0,1] = " << output_concat_ptr[output_index];
return 0;
}
@@ -17,15 +17,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "../test_include.h"
#include "operators/elementwise_add_op.h"
#include "test_include.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class TestElementwiseAddOp {
public:
-    TestElementwiseAddOp(const Program<Dtype> p) : program_(p) {
+    explicit TestElementwiseAddOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
@@ -41,18 +41,6 @@ template <typename Dtype> class TestElementwiseAddOp {
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
-                // if (op->Type() == "elementwise_add") {
-                //   if (op->GetAttrMap().at("axis").Get<int>() != -1) {
-                //     DLOG << "attr: axis = "
-                //          << op->GetAttrMap().at("axis").Get<int>();
-                //   }
-                // }
-                // DLOG << "op:" << op->Type();
if (op->Type() == "elementwise_add" &&
op->Input("X")[0] == "batch_norm_2.tmp_2") {
DLOG << " elementwise_add attr size: "
@@ -89,7 +77,7 @@ template <typename Dtype> class TestElementwiseAddOp {
tensor_y->ShareDataWith(t2);
Variable *con_output = scope->Var("elementwise_add_0.tmp_0");
-        Tensor *output_tensor = con_output->GetMutable<Tensor>();
+        auto *output_tensor = con_output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({1, 3, 224, 224});
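// y (224,) is broadcast along the trailing dimension of x
// (1, 3, 224, 224), so out[i] = x[i] + y[i % 224].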
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
@@ -123,9 +111,8 @@ template <typename Dtype> class TestElementwiseAddOp {
template class TestElementwiseAddOp<CPU>;
} // namespace framework
-namespace test {
-void testElementwiseAdd() {
+} // namespace paddle_mobile
+int main() {
DLOG << "----------**********----------";
DLOG << "begin to run ElementAddOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
@@ -137,18 +124,18 @@ void testElementwiseAdd() {
paddle_mobile::framework::Tensor inputx;
SetupTensor<float>(&inputx, {1, 3, 224, 224}, static_cast<float>(0),
static_cast<float>(1));
-    float *inputx_ptr = inputx.data<float>();
+    auto *inputx_ptr = inputx.data<float>();
/// input y (224,)
paddle_mobile::framework::Tensor inputy;
SetupTensor<float>(&inputy, {224}, static_cast<float>(0),
static_cast<float>(1));
-    float *inputy_ptr = inputy.data<float>();
+    auto *inputy_ptr = inputy.data<float>();
paddle_mobile::framework::TestElementwiseAddOp<paddle_mobile::CPU>
testElementwiseAddOp(program);
auto output_add = testElementwiseAddOp.predict_add(inputx, inputy);
-    float *output_add_ptr = output_add->data<float>();
+    auto *output_add_ptr = output_add->data<float>();
// for (int j = 0; j < output_add->numel(); ++j) {
// DLOG << "value of output: " << output_add_ptr[j];
// }
@@ -159,6 +146,5 @@ void testElementwiseAdd() {
DLOG << inputx_ptr[226] << " + " << inputy_ptr[2] << " = "
<< output_add_ptr[226];
+    return 0;
}
-} // namespace test
-} // namespace paddle_mobile
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "../test_include.h"
#include "operators/lrn_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class TestLrnOp {
public:
explicit TestLrnOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "lrn" &&
op->Input("X")[0] == "pool2d_0.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Output Out is : " << op->Output("Out")[0];
DLOG << " n : " << op->GetAttrMap().at("n").Get<int>();
DLOG << " alpha : "
<< op->GetAttrMap().at("alpha").Get<float>();
DLOG << " beta : "
<< op->GetAttrMap().at("beta").Get<float>();
DLOG << " k : " << op->GetAttrMap().at("k").Get<float>();
std::shared_ptr<operators::LrnOp<Dtype, float>> lrn =
std::make_shared<operators::LrnOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(lrn);
}
}
}
}
std::shared_ptr<Tensor> predict_lrn(Tensor &t1) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("pool2d_0.tmp_0");
auto tensor_x1 = x1_feed_value->GetMutable<Tensor>();
tensor_x1->ShareDataWith(t1);
Variable *con_output = scope->Var("pool1_norm1.tmp_1");
auto *output_tensor = con_output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({3, 4, 2, 2});
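// LRN normalizes across channels without changing shape, so the
// output matches the {3, 4, 2, 2} input.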
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict_lrn(t1, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_lrn(const Tensor &t1, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size();
++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestLrnOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run LrnOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/googlenet"));
/// input x (3,4,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {3, 4, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::TestLrnOp<paddle_mobile::CPU> testLrnOp(program);
auto output_lrn = testLrnOp.predict_lrn(inputx1);
auto *output_lrn_ptr = output_lrn->data<float>();
DLOG << " LrnOp input: ";
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) {
for (int c = 0; c < 2; c++) {
for (int d = 0; d < 2; d++) {
DLOGF("%f ", inputx1_ptr[i * 16 + j * 4 + c * 2 + d]);
}
DLOGF("\n");
}
DLOGF("\n");
}
DLOGF("\n");
}
DLOG << " LrnOp output: ";
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) {
for (int c = 0; c < 2; c++) {
for (int d = 0; d < 2; d++) {
DLOGF("%f ", output_lrn_ptr[i * 16 + j * 4 + c * 2 + d]);
}
DLOGF("\n");
}
DLOGF("\n");
}
DLOGF("\n");
}
DLOG << inputx1_ptr[0] << " / ((1 + 0.00002 * ( " << inputx1_ptr[0]
<< "^2 + " << inputx1_ptr[4] << "^2 + " << inputx1_ptr[8]
<< "^2 ))^0.75) = ";
DLOG << output_lrn_ptr[0];
return 0;
}
@@ -17,15 +17,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#pragma once
#include "../test_include.h"
#include "operators/mul_op.h"
#include "test_include.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype> class TestMulOp {
public:
-    TestMulOp(const Program<Dtype> p) : program_(p) {
+    explicit TestMulOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
@@ -41,21 +41,6 @@ template <typename Dtype> class TestMulOp {
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
// if (op->Type() == "mul") {
// DLOG << "x_num_col_dims :
// "
// << op->GetAttrMap()
// .at("x_num_col_dims")
// .Get<int>();
// DLOG << "y_num_col_dims :
// "
// << op->GetAttrMap()
// .at("y_num_col_dims")
// .Get<int>();
// DLOG << " Input X is : "
// << op->Input("X")[0];
// }
// DLOG << "op:" << op->Type();
if (op->Type() == "mul" &&
op->Input("X")[0] == "pool2d_0.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
@@ -69,17 +54,17 @@ template <typename Dtype> class TestMulOp {
DLOG << "y_num_col_dims : "
<< op->GetAttrMap().at("y_num_col_dims").Get<int>();
-                    std::shared_ptr<operators::MulOp<Dtype, float>> add =
+                    std::shared_ptr<operators::MulOp<Dtype, float>> mul =
                         std::make_shared<operators::MulOp<Dtype, float>>(
                             op->Type(), op->GetInputs(), op->GetOutputs(),
                             op->GetAttrMap(), program_.scope);
-                    ops_of_block_[*block_desc.get()].push_back(add);
+                    ops_of_block_[*block_desc.get()].push_back(mul);
}
}
}
}
-    std::shared_ptr<Tensor> predict_add(Tensor &t1, Tensor &t2) {
+    std::shared_ptr<Tensor> predict_mul(Tensor &t1, Tensor &t2) {
// feed
auto scope = program_.scope;
Variable *x_feed_value = scope->Var("pool2d_0.tmp_0");
@@ -91,7 +76,7 @@ template <typename Dtype> class TestMulOp {
tensor_y->ShareDataWith(t2);
Variable *con_output = scope->Var("fc_0.tmp_0");
-        Tensor *output_tensor = con_output->GetMutable<Tensor>();
+        auto *output_tensor = con_output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({3, 3});
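// x (3, 2, 1, 1) collapses to a (3, 2) matrix (x_num_col_dims = 1 here),
// so the product with y (2, 3) fills the (3, 3) output allocated above.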
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
@@ -99,7 +84,7 @@ template <typename Dtype> class TestMulOp {
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
-        predict_add(t1, t2, 0);
+        predict_mul(t1, t2, 0);
return out_tensor;
}
@@ -111,7 +96,7 @@ template <typename Dtype> class TestMulOp {
ops_of_block_;
bool use_optimize_ = false;
-    void predict_add(const Tensor &t1, const Tensor &t2, int block_id) {
+    void predict_mul(const Tensor &t1, const Tensor &t2, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size();
@@ -125,9 +110,9 @@ template <typename Dtype> class TestMulOp {
template class TestMulOp<CPU>;
} // namespace framework
-namespace test {
-void testMul() {
+} // namespace paddle_mobile
+int main() {
DLOG << "----------**********----------";
DLOG << "begin to run MulOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
@@ -139,18 +124,18 @@ void testMul() {
paddle_mobile::framework::Tensor inputx;
SetupTensor<float>(&inputx, {3, 2, 1, 1}, static_cast<float>(0),
static_cast<float>(1));
-    float *inputx_ptr = inputx.data<float>();
+    auto *inputx_ptr = inputx.data<float>();
/// input y (2,3)
paddle_mobile::framework::Tensor inputy;
SetupTensor<float>(&inputy, {2, 3}, static_cast<float>(0),
static_cast<float>(1));
-    float *inputy_ptr = inputy.data<float>();
+    auto *inputy_ptr = inputy.data<float>();
paddle_mobile::framework::TestMulOp<paddle_mobile::CPU> testMulOp(program);
-    auto output_mul = testMulOp.predict_add(inputx, inputy);
-    float *output_mul_ptr = output_mul->data<float>();
+    auto output_mul = testMulOp.predict_mul(inputx, inputy);
+    auto *output_mul_ptr = output_mul->data<float>();
auto dimx_1 = inputx.numel() / inputx.dims()[0];
DLOG << " inputx : ";
@@ -185,6 +170,5 @@ void testMul() {
DLOG << inputx_ptr[0] << " x " << inputy_ptr[0] << " + " << inputx_ptr[1]
<< " x " << inputy_ptr[0 + 3] << " = " << output_mul_ptr[0];
+    return 0;
}
-} // namespace test
-} // namespace paddle_mobile