Commit 2a0b6988 authored by Ray Liu, committed by GitHub

Merge pull request #1151 from kk12333/develop

Add elementwise sub and im2sequence unittest
@@ -206,5 +206,8 @@ LOAD_OP2(slice, CPU, MALI_GPU);
LOAD_OP2(fusion_conv_bn, CPU, FPGA);
LOAD_FUSION_MATCHER(fusion_conv_bn);
#endif
#ifdef ELEMENTWISESUB_OP
LOAD_OP1(elementwise_sub, CPU);
#endif
LOAD_OP1(quantize, CPU);
LOAD_OP1(dequantize, CPU);
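
For readers unfamiliar with this file: each LOAD_OP line is compiled only when the matching macro is defined by the CMake switches shown in the op.cmake hunk at the end of this diff, so unused ops never reach the binary. As a reading aid, here is a generic sketch of the static-registration idiom that LOAD_OP-style macros typically rely on; it is illustrative only and is not paddle-mobile's actual macro expansion:

// Hypothetical illustration -- NOT paddle-mobile's real LOAD_OP1 macro.
// A static initializer drops a factory into a global table; referencing the
// macro-generated symbol keeps the linker from discarding the object file,
// so the registration runs before main().
#include <functional>
#include <map>
#include <string>

inline std::map<std::string, std::function<void()>> &OpTable() {
  static std::map<std::string, std::function<void()>> table;
  return table;
}

#define SKETCH_REGISTER_OP(name)                                    \
  static const bool name##_registered =                             \
      (OpTable()[#name] = [] { /* build the op's kernel here */ },  \
       true)

SKETCH_REGISTER_OP(elementwise_sub);
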
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#include "operators/elementwise_sub_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ElementwiseSubOp<Dtype, T>::InferShape() const {
  auto x_dim = this->param_.InputX()->dims();
  this->param_.Out()->Resize(x_dim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(elementwise_sub, ops::ElementwiseSubOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(elementwise_sub, ops::ElementwiseSubOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "kernel/elementwise_sub_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class ElementwiseSubOp : public framework::OperatorWithKernel<
                             DeviceType, ElementwiseSubParam<DeviceType>,
                             operators::ElementwiseSubKernel<DeviceType, T>> {
 public:
  ElementwiseSubOp(const string &type, const VariableNameMap &inputs,
                   const VariableNameMap &outputs,
                   const framework::AttributeMap &attrs,
                   std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<
            DeviceType, ElementwiseSubParam<DeviceType>,
            operators::ElementwiseSubKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}

  using framework::OperatorWithKernel<
      DeviceType, ElementwiseSubParam<DeviceType>,
      operators::ElementwiseSubKernel<DeviceType, T>>::OperatorWithKernel;

  void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
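
The class above just wires a device tag, a param type, and a kernel type into framework::OperatorWithKernel, overriding only InferShape(). As a reading aid, a stripped-down analogue of that pattern (illustrative names, not the framework's real classes):

#include <utility>

// Illustrative analogue of framework::OperatorWithKernel -- not the real one.
template <typename ParamT, typename KernelT>
class OperatorWithKernelSketch {
 public:
  explicit OperatorWithKernelSketch(ParamT param) : param_(std::move(param)) {}
  virtual ~OperatorWithKernelSketch() = default;

  // Each concrete op (e.g. ElementwiseSubOp above) overrides only this hook.
  virtual void InferShape() const = 0;

  void Run() {
    InferShape();  // fix the output shape before the kernel touches it
    if (kernel_.Init(&param_)) {
      kernel_.Compute(param_);
    }
  }

 protected:
  ParamT param_;
  KernelT kernel_;
};
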
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#include "operators/kernel/elementwise_sub_kernel.h"
#include "operators/kernel/central-arm-func/elementwise_sub_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ElementwiseSubKernel<CPU, float>::Init(ElementwiseSubParam<CPU> *param) {
  return true;
}

template <>
void ElementwiseSubKernel<CPU, float>::Compute(
    const ElementwiseSubParam<CPU> &param) const {
  ElementwiseSubCompute<float>(param);
  param.Out()->set_lod(param.InputX()->lod());
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#pragma once
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct SubFunctor {
  inline T operator()(T a, T b) const { return a - b; }
};

template <typename P>
void ElementwiseSubCompute(const ElementwiseSubParam<CPU> &param) {
  const Tensor *input_x = param.InputX();
  const Tensor *input_y = param.InputY();
  Tensor *Out = param.Out();
  Out->mutable_data<float>();
  int axis = param.Axis();
  ElementwiseComputeEx<SubFunctor<float>, float>(input_x, input_y, axis,
                                                 SubFunctor<float>(), Out);
}

template class ElementwiseSubKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
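
SubFunctor supplies the scalar a - b; ElementwiseComputeEx applies it under Paddle's broadcast rule, where Y's dims must match a contiguous run of X's dims starting at axis and Y is tiled across the rest. A self-contained sketch of that rule (independent of paddle-mobile; every name below is illustrative), using the usual pre/n/post decomposition:

#include <cassert>
#include <vector>

// x has shape x_dims; y's shape equals x_dims[axis .. axis + y_dims.size()).
// pre = dims before axis, n = y's element count, post = dims after y's span.
std::vector<float> BroadcastSub(const std::vector<float> &x,
                                const std::vector<int> &x_dims,
                                const std::vector<float> &y,
                                const std::vector<int> &y_dims, int axis) {
  int pre = 1, n = 1, post = 1;
  for (int i = 0; i < axis; ++i) pre *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) n *= y_dims[i];
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    post *= x_dims[i];
  }
  assert(x.size() == static_cast<size_t>(pre) * n * post);

  std::vector<float> out(x.size());
  for (int p = 0; p < pre; ++p) {
    for (int j = 0; j < n; ++j) {
      for (int q = 0; q < post; ++q) {
        int idx = (p * n + j) * post + q;
        out[idx] = x[idx] - y[j];  // y is indexed only by the broadcast span
      }
    }
  }
  return out;
}

For example, X of shape {2, 3, 4, 5} and Y of shape {3, 4} with axis = 1 give pre = 2, n = 12, post = 5, and the output keeps X's shape.
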
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#pragma once
#include "framework/operator.h"
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ElementwiseSubKernel
    : public framework::OpKernelBase<DeviceType,
                                     ElementwiseSubParam<DeviceType>> {
 public:
  void Compute(const ElementwiseSubParam<DeviceType> &param) const;
  bool Init(ElementwiseSubParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -488,6 +488,38 @@ template <typename Dtype>
using ElementwiseAddReluParam = ElementwiseAddParam<Dtype>;
#endif
#ifdef ELEMENTWISESUB_OP
template <typename Dtype>
class ElementwiseSubParam : OpParam {
  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
  typedef typename DtypeTensorTrait<Dtype>::rtype RType;

 public:
  ElementwiseSubParam(const VariableNameMap &inputs,
                      const VariableNameMap &outputs, const AttributeMap &attrs,
                      const Scope &scope) {
    input_x_ = InputXFrom<GType>(inputs, scope);
    input_y_ = InputYFrom<GType>(inputs, scope);
    out_ = OutFrom<GType>(outputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
  }

  const GType *InputX() const { return input_x_; }
  const GType *InputY() const { return input_y_; }
  GType *Out() const { return out_; }
  const int &Axis() const { return axis_; }

 private:
  GType *input_x_;
  GType *input_y_;
  GType *out_;
  int axis_;
};
#endif
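
For context, a worked example of the axis attribute this param carries (the standard Paddle elementwise broadcast convention, stated here as an assumption about this port):

// Hypothetical shapes:
//   X.dims = {2, 3, 4, 5},  Y.dims = {3, 4},  axis = 1
// Y aligns with X's dims [1, 2] and is broadcast over dims 0 and 3, so Out
// always takes X's shape -- which is why ElementwiseSubOp::InferShape above
// simply resizes Out to InputX()->dims().
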
#ifdef MUL_OP
template <typename Dtype>
class MulParam : OpParam {
......
@@ -173,6 +173,14 @@ if (NOT FOUND_MATCH)
target_link_libraries(test-elementwiseadd-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-elementwisesub-op operators/test_elementwise_sub_op.cpp test_helper.h test_include.h)
target_link_libraries(test-elementwisesub-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-im2sequence-op operators/test_im2sequence_op.cpp test_helper.h test_include.h)
target_link_libraries(test-im2sequence-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h)
target_link_libraries(test-concat-op paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "../test_helper.h"
#include "../test_include.h"
#include "operators/elementwise_sub_op.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype>
class TestElementwiseSubOp {
 public:
  explicit TestElementwiseSubOp(const Program<Dtype> &p) : program_(p) {
    if (use_optimize_) {
      to_predict_program_ = program_.optimizeProgram;
    } else {
      to_predict_program_ = program_.originProgram;
    }
    const std::vector<std::shared_ptr<BlockDesc>> blocks =
        to_predict_program_->Blocks();
    // DLOG << " **block size " << blocks.size();
    for (int i = 0; i < blocks.size(); ++i) {
      std::shared_ptr<BlockDesc> block_desc = blocks[i];
      std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
      // DLOG << " ops " << ops.size();
      for (int j = 0; j < ops.size(); ++j) {
        std::shared_ptr<OpDesc> op = ops[j];
        if (op->Type() == "elementwise_sub" &&
            op->Input("X")[0] == "sigmoid_1.tmp_0") {
          DLOG << " elementwise_sub attr size: " << op->GetAttrMap().size();
          DLOG << " inputs size: " << op->GetInputs().size();
          DLOG << " outputs size: " << op->GetOutputs().size();
          std::shared_ptr<operators::ElementwiseSubOp<Dtype, float>> sub_op =
              std::make_shared<operators::ElementwiseSubOp<Dtype, float>>(
                  op->Type(), op->GetInputs(), op->GetOutputs(),
                  op->GetAttrMap(), program_.scope);
          ops_of_block_[*block_desc.get()].push_back(sub_op);
        }
      }
    }
  }

  std::shared_ptr<Tensor> predict(const Tensor &t1, const Tensor &t2) {
    // feed
    auto scope = program_.scope;
    Variable *x1_feed_value = scope->Var("tmp_0");
    auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
    tensor_x1->ShareDataWith(t1);

    Variable *x2_feed_value = scope->Var("sigmoid_1.tmp_0");
    auto tensor_x2 = x2_feed_value->GetMutable<LoDTensor>();
    tensor_x2->ShareDataWith(t2);

    Variable *output = scope->Var("tmp_1");
    auto *output_tensor = output->GetMutable<LoDTensor>();
    output_tensor->mutable_data<float>({1, 1, 6, 6});
    // DLOG << typeid(output_tensor).name();
    // DLOG << "output_tensor dims: " << output_tensor->dims();

    // Wrap the scope-owned tensor without taking ownership: resetting a
    // fresh shared_ptr to the raw pointer would double-free it when the
    // scope is destroyed.
    std::shared_ptr<Tensor> out_tensor(output_tensor, [](Tensor *) {});

    predict(t1, t2, 0);
    return out_tensor;
  }

 private:
  const framework::Program<Dtype> program_;
  std::shared_ptr<ProgramDesc> to_predict_program_;
  std::map<framework::BlockDesc,
           std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
      ops_of_block_;
  bool use_optimize_ = false;

  void predict(const Tensor &t1, const Tensor &t2, int block_id) {
    std::shared_ptr<BlockDesc> to_predict_block =
        to_predict_program_->Block(block_id);
    for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
      auto op = ops_of_block_[*to_predict_block.get()][j];
      DLOG << "op -> run()";
      op->Run();
    }
  }
};

template class TestElementwiseSubOp<CPU>;
}  // namespace framework
}  // namespace paddle_mobile

int main() {
  DLOG << "----------**********----------";
  DLOG << "begin to run ElementwiseSub Test";
  paddle_mobile::Loader<paddle_mobile::CPU> loader;
  auto program = loader.Load(std::string(g_ocr) + "/model",
                             std::string(g_ocr) + "/params");

  /// input x1 (1,1,6,6)
  paddle_mobile::framework::Tensor inputx1;
  SetupTensor<float>(&inputx1, {1, 1, 6, 6}, static_cast<float>(0),
                     static_cast<float>(1));
  auto *inputx1_ptr = inputx1.data<float>();

  /// input x2 (1,1,6,6)
  paddle_mobile::framework::Tensor inputx2;
  SetupTensor<float>(&inputx2, {1, 1, 6, 6}, static_cast<float>(0),
                     static_cast<float>(1));
  auto *inputx2_ptr = inputx2.data<float>();

  paddle_mobile::framework::TestElementwiseSubOp<paddle_mobile::CPU>
      testElementwiseSubOp(program);
  auto output_op = testElementwiseSubOp.predict(inputx1, inputx2);
  auto *output_op_ptr = output_op->data<float>();

  auto inputx1_dim = inputx1.numel() / inputx1.dims()[0];
  DLOG << " input1 : ";
  for (int i = 0; i < inputx1.dims()[0]; ++i) {
    for (int j = 0; j < inputx1_dim; ++j) {
      DLOGF("%f ", inputx1_ptr[i * inputx1_dim + j]);
    }
    DLOGF("\n");
  }

  auto inputx2_dim = inputx2.numel() / inputx2.dims()[0];
  DLOG << " input2 : ";
  for (int i = 0; i < inputx2.dims()[0]; ++i) {
    for (int j = 0; j < inputx2_dim; ++j) {
      DLOGF("%f ", inputx2_ptr[i * inputx2_dim + j]);
    }
    DLOGF("\n");
  }

  auto output_dim = output_op->numel() / output_op->dims()[0];
  DLOG << " output : ";
  for (int i = 0; i < output_op->dims()[0]; ++i) {
    for (int j = 0; j < output_dim; ++j) {
      DLOGF("%f ", output_op_ptr[i * output_dim + j]);
    }
    DLOGF("\n");
  }
  return 0;
}
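
The test above only logs the tensors. A hedged sketch of an explicit check that could be appended to main() before `return 0;` (the inputs share one shape, so each output element should equal x1 - x2; assumes std::abs from <cmath> is reachable through the test headers):

// Hypothetical extra assertion, not part of the commit:
for (int i = 0; i < output_op->numel(); ++i) {
  float expect = inputx1_ptr[i] - inputx2_ptr[i];
  PADDLE_MOBILE_ENFORCE(std::abs(output_op_ptr[i] - expect) < 1e-5f,
                        "elementwise_sub result mismatch");
}
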
@@ -12,51 +12,129 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../executor_for_test.h"
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/im2sequence_op.h"
int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_ocr_recg);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
namespace paddle_mobile {
namespace framework {
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>>
executor(program, "im2sequence");
template <typename Dtype>
class TestIm2SequenceOp {
public:
explicit TestIm2SequenceOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
// 1. input_tensors;
vector<Tensor> input_tensors;
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "im2sequence" &&
op->Input("X")[0] == "conv2d_19.tmp_1") {
DLOG << " im2squence attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
Tensor input1;
auto input1_data = CreateInput<float>(&input1, {2, 2, 3, 3}, -1, 1);
input_tensors.push_back(input1);
std::shared_ptr<operators::Im2SequenceOp<Dtype, float>> lrn =
std::make_shared<operators::Im2SequenceOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(lrn);
}
}
}
}
// 2. input_names
vector<string> input_names({
"conv2d_19.tmp_1",
});
std::shared_ptr<Tensor> predict_bn(const Tensor &t1) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("conv2d_19.tmp_1");
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
tensor_x1->ShareDataWith(t1);
// 3. output_names
vector<string> output_names({"im2sequence_0.tmp_0"});
Variable *output = scope->Var("im2sequence_0.tmp_0");
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({2, 12});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
// 4. out_dims;
vector<DDim> out_ddims;
auto out_ddim = paddle_mobile::framework::make_ddim({8, 9});
out_ddims.push_back(out_ddim);
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
output_names, out_ddims);
predict_bn(t1, 0);
return out_tensor;
}
auto output0_data = output[0]->data<float>();
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
for (int j = 0; j < input_tensors[0].numel(); ++j) {
DLOG << " value of input: " << input1_data[j];
void predict_bn(const Tensor &t1, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestIm2SequenceOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
for (int j = 0; j < output[0]->numel(); ++j) {
DLOG << " value of output: " << output0_data[j];
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run Im2Sequence Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_eng) + "/model",
std::string(g_eng) + "/params");
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx;
SetupTensor<float>(&inputx, {1, 2, 6, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx_ptr = inputx.data<float>();
paddle_mobile::framework::TestIm2SequenceOp<paddle_mobile::CPU>
testIm2SequenceOp(program);
auto output_op = testIm2SequenceOp.predict_bn(inputx);
auto *output_op_ptr = output_op->data<float>();
auto input_dim = inputx.numel() / inputx.dims()[0];
DLOG << " input : ";
for (int i = 0; i < inputx.dims()[0]; ++i) {
for (int j = 0; j < input_dim; ++j) {
DLOGF("%f ", inputx_ptr[i * input_dim + j]);
}
DLOGF("\n");
}
auto output_dim = output_op->numel() / output_op->dims()[0];
DLOG << " output : ";
for (int i = 0; i < output_op->dims()[0]; ++i) {
for (int j = 0; j < output_dim; ++j) {
DLOGF("%f ", output_op_ptr[i * output_dim + j]);
}
DLOGF("\n");
}
return 0;
}
@@ -189,6 +189,8 @@ if(NOT FOUND_MATCH)
set(CONV_OP ON)
set(DEPTHWISECONV_OP ON)
set(ELEMENTWISEADD_OP ON)
set(ELEMENTWISESUB_OP ON)
set(IM2SEQUENCE_OP ON)
set(FUSION_CONVADD_OP ON)
set(FUSION_CONVADDPRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
@@ -264,6 +266,9 @@ endif()
if (ELEMENTWISEADD_OP)
add_definitions(-DELEMENTWISEADD_OP)
endif()
if (ELEMENTWISESUB_OP)
add_definitions(-DELEMENTWISESUB_OP)
endif()
if (FUSION_CONVADD_OP)
add_definitions(-DFUSION_CONVADD_OP)
endif()
......
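
Each switch above becomes a plain compile definition, which is what every #ifdef ELEMENTWISESUB_OP guard in the sources earlier in this diff keys off. A trivial sketch of the effect (illustrative file, not part of the commit):

// Built with -DELEMENTWISESUB_OP (what add_definitions emits), this
// translation unit contains the op; without the flag it compiles to nothing,
// keeping the library small when the op is not needed.
#ifdef ELEMENTWISESUB_OP
#include "operators/elementwise_sub_op.h"
#endif
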