Commit 7efcc19c authored by hjchen2

Add top_k and cast operators

Parent bd7216fd
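In outline, the commit adds two CPU operators: top_k keeps the k largest values along the last axis (plus their indices), and cast converts tensor elements between data types. A minimal self-contained sketch of the semantics, for a single row of floats (illustrative only; the real kernels below operate on framework::Tensor and dispatch on runtime dtypes):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// top_k: the k largest values of `row` in descending order, with their
// source indices (mirrors the partial_sort in the CPU kernel below).
inline std::vector<std::pair<float, int64_t>> TopKRow(
    const std::vector<float> &row, int k) {
  std::vector<std::pair<float, int64_t>> vec(row.size());
  for (size_t j = 0; j < row.size(); ++j) {
    vec[j] = std::make_pair(row[j], static_cast<int64_t>(j));
  }
  std::partial_sort(vec.begin(), vec.begin() + k, vec.end(),
                    [](const std::pair<float, int64_t> &l,
                       const std::pair<float, int64_t> &r) {
                      return l.first > r.first;
                    });
  vec.resize(k);
  return vec;
}

// cast: element-wise static_cast to the requested output type.
template <typename OutT>
inline std::vector<OutT> CastRow(const std::vector<float> &row) {
  std::vector<OutT> out(row.size());
  for (size_t j = 0; j < row.size(); ++j) {
    out[j] = static_cast<OutT>(row[j]);
  }
  return out;
}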
......@@ -63,12 +63,7 @@ struct PaddleMobileException : public std::exception {
#else
#define PADDLE_MOBILE_THROW_EXCEPTION(...)
#define PADDLE_MOBILE_ENFORCE(stat, ...) \
{ \
if (stat) { \
} else { \
} \
}
#define PADDLE_MOBILE_ENFORCE(stat, ...)
#endif
......
......@@ -69,6 +69,8 @@ const char *G_OP_TYPE_FLATTEN = "flatten";
const char *G_OP_TYPE_SHAPE = "shape";
const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_TOP_K = "top_k";
const char *G_OP_TYPE_CAST = "cast";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
......@@ -142,6 +144,8 @@ std::unordered_map<
{G_OP_TYPE_SHAPE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_CONV_TRANSPOSE, {{"Input"}, {"Output"}}},
{G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
{G_OP_TYPE_TOP_K, {{"X"}, {"Out", "Indices"}}},
{G_OP_TYPE_CAST, {{"X"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
......
......@@ -150,6 +150,8 @@ extern const char *G_OP_TYPE_CONV_TRANSPOSE;
extern const char *G_OP_TYPE_PRELU;
extern const char *G_OP_TYPE_SUM;
extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_TOP_K;
extern const char *G_OP_TYPE_CAST;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
......
......@@ -28,6 +28,10 @@ extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType(
extern std::type_index ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type);
inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) {
return static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(type);
}
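// VisitDataType dispatches visitor.apply<T>() for the T matching the runtime
// `type` tag; the int overload of ToDataType above lets proto attributes
// (stored as plain ints) be converted without a cast at every call site.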
template <typename Visitor>
inline void VisitDataType(_PaddleMobile__Framework__Proto__VarType__Type type,
Visitor visitor) {
......
......@@ -228,6 +228,12 @@ LOAD_FUSION_MATCHER(fusion_conv_bn);
#ifdef ELEMENTWISESUB_OP
LOAD_OP1(elementwise_sub, CPU)
#endif
#ifdef TOP_K_OP
LOAD_OP1(top_k, CPU)
#endif
#ifdef CAST_OP
LOAD_OP1(cast, CPU)
#endif
#ifdef QUANT_OP
LOAD_OP1(quantize, CPU);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CAST_OP
#include "operators/cast_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void CastOp<DeviceType, T>::InferShape() const {
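// cast is element-wise, so the output keeps the input's shape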
const auto &dims = this->param_.input_->dims();
this->param_.output_->Resize(dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(cast, ops::CastOp);
#endif
#endif // CAST_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CAST_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class CastOp : public framework::OperatorWithKernel<
DeviceType, CastParam<DeviceType>,
operators::CastKernel<DeviceType, T>> {
public:
CastOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, CastParam<DeviceType>,
operators::CastKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// infer output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // CAST_OP
......@@ -33,4 +33,4 @@ namespace ops = paddle_mobile::operators;
REGISTER_OPERATOR_CPU(dequantize, ops::DequantizeOp);
#endif
#endif
#endif // DEQUANT_OP
......@@ -44,4 +44,4 @@ class DequantizeOp
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // DEQUANT_OP
......@@ -14,19 +14,16 @@ limitations under the License. */
#ifdef GRU_OP
#include "operators/gru_op.h"
#include <iostream>
#include <vector>
#include "common/enforce.h"
#include "operators/gru_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void GruOp<Dtype, T>::InferShape() const {
auto lod_size = this->param_.InputInput()->lod().size();
PADDLE_MOBILE_ENFORCE((lod_size == 1),
"Current LoD only supports one dimension.");
auto input_dims = this->param_.InputInput()->dims();
auto weight_dims = this->param_.InputWeight()->dims();
int input_size = input_dims[1];
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef IM2SEQUENCE_OP
#include "operators/im2sequence_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
......@@ -29,20 +30,16 @@ int Im2SequenceOutputSize(int input_size, int kernel, int padding_1,
template <typename Dtype, typename T>
void Im2SequenceOp<Dtype, T>::InferShape() const {
auto in_x_dims = this->param_.Input()->dims();
const std::vector<int> &kernels = this->param_.Kernels();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(Im2SequenceOutputSize(in_x_dims[i + 2], kernels[i],
paddings[i], paddings[i + 2],
strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
......@@ -54,9 +51,5 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(im2sequence, ops::Im2SequenceOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
#endif // IM2SEQUENCE_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CAST_OP
#include <algorithm>
#include <iostream>
#include <vector>
#include "framework/data_type.h"
#include "operators/kernel/kernels.h"
namespace paddle_mobile {
namespace operators {
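// Casting is resolved with two VisitDataType dispatches: CastOpFunctor fixes
// the input element type InT, then CastOutOpFunctor<InT> fixes the output
// type OutT and performs the element-wise static_cast.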
template <typename InT>
struct CastOutOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
CastOutOpFunctor(const framework::Tensor* in, framework::Tensor* out)
: in_(in), out_(out) {}
template <typename OutT>
void apply() const {
const InT* input = in_->data<InT>();
OutT* output = out_->mutable_data<OutT>();
size_t numel = in_->numel();
for (size_t i = 0; i < numel; ++i) {
output[i] = static_cast<OutT>(input[i]);
}
}
};
struct CastOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
int output_type_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const int output_type)
: in_(in), out_(out), output_type_(output_type) {}
template <typename InT>
void apply() const {
framework::VisitDataType(framework::ToDataType(output_type_),
CastOutOpFunctor<InT>(in_, out_));
}
};
template <>
bool CastKernel<CPU, float>::Init(CastParam<CPU>* param) {
return true;
}
template <>
void CastKernel<CPU, float>::Compute(const CastParam<CPU>& param) {
const Tensor* input = param.input_;
Tensor* output = param.output_;
framework::VisitDataType(framework::ToDataType(param.input_type_),
CastOpFunctor(input, output, param.output_type_));
}
} // namespace operators
} // namespace paddle_mobile
#endif // CAST_OP
......@@ -55,10 +55,9 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Input()->dims()[2] <= 140 /* referred from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
framework::Tensor transformed_weight;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
&transformed_weight);
framework::TensorCopy(transformed_weight, param->Filter());
param->transformed_filter_ = new framework::Tensor;
operators::math::winograd_transform_weight<8, 3>(
*param->Filter(), param->transformed_filter_);
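// keep the transformed weights in param->transformed_filter_ so the
// original filter stays untouched and the transform runs only once,
// here at Init time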
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
......
......@@ -29,12 +29,6 @@ template <>
void GruKernel<CPU, float>::Compute(const GruParam<CPU> &param) {
GruCompute<float>(param);
param.OutHidden()->set_lod(param.InputInput()->lod());
// DLOG << "________________" << param.OutHidden()->dims();
// DLOG << "________________" << param.OutHidden()->numel();
// auto *hiden_data = param.OutHidden()->data<float>();
// for (int64_t i = 0; i < 10; i++) {
// DLOG << "****************" << hiden_data[i];
// }
}
template class GruKernel<CPU, float>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TOP_K_OP
#include <algorithm>
#include <iostream>
#include <vector>
#include "operators/kernel/kernels.h"
namespace paddle_mobile {
namespace operators {
template <>
bool TopKKernel<CPU, float>::Init(TopKParam<CPU> *param) {
return true;
}
template <>
void TopKKernel<CPU, float>::Compute(const TopKParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
Tensor *indices = param.indices_;
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
int64_t *indices_data = indices->mutable_data<int64_t>();
framework::DDim input_dims = input->dims();
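// view the input as a [row, col] matrix: all leading dimensions are
// flattened into `row`, and top-k runs independently over the last
// dimension `col` of each row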
const size_t row = framework::product(
framework::slice_ddim(input_dims, 0, input_dims.size() - 1));
const size_t col = input_dims[input_dims.size() - 1];
#pragma omp parallel for
for (size_t i = 0; i < row; i++) {
std::vector<std::pair<float, size_t>> vec(col);
const float *input_ptr = input_data + i * col;
float *output_ptr = output_data + i * param.k_;
int64_t *indices_ptr = indices_data + i * param.k_;
for (size_t j = 0; j < col; j++) {
vec[j] = std::make_pair(input_ptr[j], j);
}
std::partial_sort(
vec.begin(), vec.begin() + param.k_, vec.end(),
[](const std::pair<float, size_t> &l,
const std::pair<float, size_t> &r) { return l.first > r.first; });
for (int j = 0; j < param.k_; ++j) {
output_ptr[j] = vec[j].first;
indices_ptr[j] = static_cast<int64_t>(vec[j].second);
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // TOP_K_OP
......@@ -117,7 +117,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
template <int tile, int kernel>
inline void WinogradConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const Tensor *filter = param.transformed_filter_;
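// weights were already Winograd-transformed once in ConvKernel::Init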
Tensor *output = param.Output();
output->mutable_data<float>();
int batch_size = input->dims()[0];
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
};
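// For reference, DECLARE_KERNEL(TopKKernel, TopKParam) below expands to
// roughly the following class (sketch of the macro expansion):
//
//   template <typename DeviceType, typename T>
//   class TopKKernel
//       : public framework::OpKernelBase<DeviceType, TopKParam<DeviceType>> {
//    public:
//     bool Init(TopKParam<DeviceType> *param);
//     void Compute(const TopKParam<DeviceType> &param);
//   };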
#ifdef TOP_K_OP
DECLARE_KERNEL(TopKKernel, TopKParam)
#endif // TOP_K_OP
#ifdef CAST_OP
DECLARE_KERNEL(CastKernel, CastParam)
#endif // CAST_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -439,10 +439,11 @@ class ConvParam : public OpParam {
#endif
protected:
public:
RType *input_;
RType *output_;
RType *filter_;
RType *transformed_filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
......@@ -455,7 +456,7 @@ class ConvParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
public:
fpga::SplitConvArgs fpga_conv_args;
public:
......@@ -2515,6 +2516,52 @@ class ShapeParam : public OpParam {
};
#endif
#ifdef TOP_K_OP
template <typename Dtype>
class TopKParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
TopKParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<GType>("X", inputs, scope);
output_ = OpParam::GetVarValue<GType>("Out", outputs, scope);
indices_ = OpParam::GetVarValue<GType>("Indices", outputs, scope);
k_ = OpParam::GetAttr<int>("k", attrs);
}
public:
RType *input_;
RType *output_;
RType *indices_;
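// number of largest entries to keep along the last dimension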
int k_;
};
#endif // TOP_K_OP
#ifdef CAST_OP
template <typename Dtype>
class CastParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
CastParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<GType>("X", inputs, scope);
output_ = OpParam::GetVarValue<GType>("Out", outputs, scope);
input_type_ = OpParam::GetAttr<int>("in_dtype", attrs);
output_type_ = OpParam::GetAttr<int>("out_dtype", attrs);
}
public:
RType *input_;
RType *output_;
int input_type_;
int output_type_;
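// in_dtype/out_dtype arrive as plain ints holding proto VarType values;
// the kernel maps them back with framework::ToDataType before dispatching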
};
#endif // CAST_OP
#ifdef QUANT_OP
template <typename Dtype>
class QuantizeParam : public OpParam {
......
......@@ -36,4 +36,4 @@ namespace ops = paddle_mobile::operators;
REGISTER_OPERATOR_CPU(quantize, ops::QuantizeOp);
#endif
#endif
#endif // QUANT_OP
......@@ -43,4 +43,4 @@ class QuantizeOp : public framework::OperatorWithKernel<
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // QUANT_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TOP_K_OP
#include "operators/top_k_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void TopKOp<DeviceType, T>::InferShape() const {
const int k = this->param_.k_;
auto dims = this->param_.input_->dims();
// validate k before resizing: 1 <= k <= last dimension of the input
PADDLE_MOBILE_ENFORCE(k >= 1 && k <= dims[dims.size() - 1],
"k should be in range [1, last dimension of input]");
dims[dims.size() - 1] = k;
this->param_.output_->Resize(dims);
this->param_.indices_->Resize(dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(top_k, ops::TopKOp);
#endif
#endif // TOP_K_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TOP_K_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class TopKOp : public framework::OperatorWithKernel<
DeviceType, TopKParam<DeviceType>,
operators::TopKKernel<DeviceType, T>> {
public:
TopKOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, TopKParam<DeviceType>,
operators::TopKKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// infer output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // TOP_K_OP
......@@ -249,6 +249,8 @@ if(NOT FOUND_MATCH)
set(SHAPE_OP ON)
set(ELEMENTWISEMUL_OP ON)
set(SUM_OP ON)
set(TOP_K_OP ON)
set(CAST_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_BN_OP ON)
......@@ -457,7 +459,12 @@ endif()
if (SUM_OP)
add_definitions(-DSUM_OP)
endif()
if (TOP_K_OP)
add_definitions(-DTOP_K_OP)
endif()
if (CAST_OP)
add_definitions(-DCAST_OP)
endif()
if (QUANT_OP)
add_definitions(-DQUANT_OP)
endif()
......