提交 6173d413 编写于 作者: J Jiaying Zhao 提交者: GitHub

Merge pull request #1427 from smilejames/develop

add increment、is_empty op
......@@ -83,6 +83,8 @@ const char *G_OP_TYPE_LOGICAL_NOT = "logical_not";
const char *G_OP_TYPE_LOGICAL_XOR = "logical_xor";
const char *G_OP_TYPE_WRITE_TO_ARRAY = "write_to_array";
const char *G_OP_TYPE_READ_FROM_ARRAY = "read_from_array";
const char *G_OP_TYPE_IS_EMPTY = "is_empty";
const char *G_OP_TYPE_INCREMENT = "increment";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
......@@ -199,6 +201,8 @@ std::unordered_map<
{G_OP_TYPE_LOGICAL_NOT, {{"X"}, {"Out"}}},
{G_OP_TYPE_WRITE_TO_ARRAY, {{"X", "I"}, {"Out"}}},
{G_OP_TYPE_READ_FROM_ARRAY, {{"X", "I"}, {"Out"}}},
{G_OP_TYPE_IS_EMPTY, {{"X"}, {"Out"}}},
{G_OP_TYPE_INCREMENT, {{"X"}, {"Out"}}},
{G_OP_TYPE_SLICE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_ANCHOR_GENERATOR, {{"Input"}, {"Anchors", "Variances"}}},
{G_OP_TYPE_GENERATE_PROPOSALS,
......
......@@ -172,6 +172,8 @@ extern const char *G_OP_TYPE_LOGICAL_NOT;
extern const char *G_OP_TYPE_LOGICAL_XOR;
extern const char *G_OP_TYPE_WRITE_TO_ARRAY;
extern const char *G_OP_TYPE_READ_FROM_ARRAY;
extern const char *G_OP_TYPE_IS_EMPTY;
extern const char *G_OP_TYPE_INCREMENT;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
......
......@@ -306,6 +306,12 @@ LOAD_OP1(write_to_array, CPU);
#ifdef READ_FROM_ARRAY_OP
LOAD_OP1(read_from_array, CPU);
#endif
#ifdef IS_EMPTY_OP
LOAD_OP1(is_empty, CPU);
#endif
#ifdef INCREMENT_OP
LOAD_OP1(increment, CPU);
#endif
#ifdef ANCHOR_GENERATOR_OP
LOAD_OP1(anchor_generator, CPU);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#include "operators/increment_op.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void IncrementOp<Dtype, T>::InferShape() const {
auto input = this->param_.InputX();
auto out = this->param_.Out();
PADDLE_MOBILE_ENFORCE(input->numel() == 1, "input's numel should be 1");
out->Resize(input->dims());
out->set_lod(input->lod());
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(increment, ops::IncrementOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/increment_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class IncrementOp
: public framework::OperatorWithKernel<DeviceType,
IncrementParam<DeviceType>,
IncrementKernel<DeviceType, T>> {
public:
IncrementOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, IncrementParam<DeviceType>,
IncrementKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#include "operators/is_empty_op.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void IsEmptyOp<Dtype, T>::InferShape() const {
auto out = this->param_.Out();
out->Resize({1});
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(is_empty, ops::IsEmptyOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/is_empty_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class IsEmptyOp
: public framework::OperatorWithKernel<DeviceType, IsEmptyParam<DeviceType>,
IsEmptyKernel<DeviceType, T>> {
public:
IsEmptyOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, IsEmptyParam<DeviceType>,
IsEmptyKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#include "operators/kernel/increment_kernel.h"
#include <operators/kernel/central-arm-func/increment_arm_func.h>
namespace paddle_mobile {
namespace operators {
template <>
bool IncrementKernel<CPU, float>::Init(IncrementParam<CPU> *param) {
return true;
}
template <>
void IncrementKernel<CPU, float>::Compute(const IncrementParam<CPU> &param) {
IncrementCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#include "operators/kernel/is_empty_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool IsEmptyKernel<CPU, float>::Init(IsEmptyParam<CPU> *param) {
return true;
}
template <>
void IsEmptyKernel<CPU, float>::Compute(const IsEmptyParam<CPU> &param) {
const framework::Tensor *input = param.InputX();
framework::Tensor *out = param.Out();
out->mutable_data<bool>()[0] = input->numel() == 0;
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#pragma once
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void IncrementCompute(const IncrementParam<CPU> &param) {
const framework::Tensor *input = param.InputX();
framework::Tensor *out = param.Out();
int step = param.Step();
out->mutable_data<P>();
const P *input_data = input->data<P>();
P *out_data = out->data<P>();
*out_data = *input_data + step;
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class IncrementKernel
: public framework::OpKernelBase<DeviceType, IncrementParam<DeviceType>> {
public:
void Compute(const IncrementParam<DeviceType> &param);
bool Init(IncrementParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class IsEmptyKernel
: public framework::OpKernelBase<DeviceType, IsEmptyParam<DeviceType>> {
public:
void Compute(const IsEmptyParam<DeviceType> &param);
bool Init(IsEmptyParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
此差异已折叠。
......@@ -46,15 +46,6 @@ namespace math {
class Gemm {
public:
/*
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
*/
typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *);
typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
int);
......@@ -62,31 +53,31 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
FnPack procPackB;
FnAddDot procAddDot;
// 将 A 矩阵分块复制到连续内存(RowMajor)
// 将 A\B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
#endif
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
......@@ -106,22 +97,16 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
/*
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float
*C, int ldc, bool relu, float *new_scale, float *new_bias);
*/
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
#if __aarch64__
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
#endif
// 分块矩阵乘法结果回写
// C = A * B
......@@ -149,6 +134,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1);
// 向量矩阵乘法 (M = 1)
#if __aarch64__
#else
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias);
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
......@@ -158,14 +155,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
/*
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float
*new_scale, float *new_bias);
*/
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
#endif
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
......
......@@ -3066,5 +3066,52 @@ class ReadFromArrayParam : public OpParam {
};
#endif
#ifdef IS_EMPTY_OP
template <typename Dtype>
class IsEmptyParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
IsEmptyParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return output_; }
public:
GType *input_x_;
GType *output_;
};
#endif // IS_EMPTY_OP
#ifdef INCREMENT_OP
template <typename Dtype>
class IncrementParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
IncrementParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
step_ = OpParam::GetAttr<int>("step", attrs);
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return output_; }
int Step() const { return step_; }
public:
GType *input_x_;
GType *output_;
int step_;
};
#endif // INCREMENT_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -437,6 +437,14 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-logical-xor-op operators/test_logical_xor_op.cpp test_helper.h test_include.h)
target_link_libraries(test-logical-xor-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-increment-op operators/test_increment_op.cpp test_helper.h test_include.h)
target_link_libraries(test-increment-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-is-empty-op operators/test_is_empty_op.cpp test_helper.h test_include.h)
target_link_libraries(test-is-empty-op paddle-mobile)
ADD_EXECUTABLE(test-conv-bn-relu-op operators/test_conv_bn_relu_op.cpp test_helper.h test_include.h)
target_link_libraries(test-conv-bn-relu-op paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/increment_op.h"
namespace paddle_mobile {
template <typename T>
void Increment(const framework::Tensor *input, framework::Tensor *out,
int step) {
auto input_data = input->data<T>();
auto out_data = out->data<T>();
*out_data = *input_data + step;
}
int TestIncrementOp(const std::vector<int> input_shape, int step) {
framework::DDim input_dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"inputX"});
outputs["Out"] = std::vector<std::string>({"output"});
auto x_var = scope.get()->Var("inputX");
auto x = x_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(x, input_dims, 0, 100);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["step"].Set<int>(step);
auto *op = new operators::IncrementOp<CPU, float>("increment", inputs,
outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Increment<float>(x, &output_cmp, step);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestIncrementOp({1}, 4);
paddle_mobile::TestIncrementOp({1}, 10);
DLOG << "test increment op pass.";
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/is_empty_op.h"
namespace paddle_mobile {
void IsEmpty(const framework::Tensor *input, framework::Tensor *out) {
out->data<bool>()[0] = input->numel() == 0;
}
int TestIsEmptyOp(const std::vector<int> input_shape) {
framework::DDim input_dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"inputX"});
outputs["Out"] = std::vector<std::string>({"output"});
auto x_var = scope.get()->Var("inputX");
auto x = x_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(x, input_dims, 0, 100);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::IsEmptyOp<CPU, float>("is_empty", inputs, outputs,
attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
bool *output_cmp_data = output_cmp.mutable_data<bool>(output->dims());
IsEmpty(x, &output_cmp);
const bool *output_data = output->data<bool>();
for (int i = 0; i < output->numel(); ++i) {
if (output_data[i] != output_cmp_data[i]) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestIsEmptyOp({1, 3, 100, 100});
paddle_mobile::TestIsEmptyOp({0});
DLOG << "test is_empty op pass.";
return 0;
}
......@@ -288,6 +288,8 @@ if(NOT FOUND_MATCH)
set(WHILE_OP ON)
set(WRITE_TO_ARRAY_OP ON)
set(READ_FROM_ARRAY_OP ON)
set(IS_EMPTY_OP ON)
set(INCREMENT_OP ON)
set(ANCHOR_GENERATOR_OP ON)
set(PROPOSAL_OP ON)
set(PSROI_POOL_OP ON)
......@@ -575,6 +577,12 @@ endif()
if (READ_FROM_ARRAY_OP)
add_definitions(-DREAD_FROM_ARRAY_OP)
endif()
if (IS_EMPTY_OP)
add_definitions(-DIS_EMPTY_OP)
endif()
if (INCREMENT_OP)
add_definitions(-DINCREMENT_OP)
endif()
if (ANCHOR_GENERATOR_OP)
add_definitions(-DANCHOR_GENERATOR_OP)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册