提交 c889558c 编写于 作者: L lijiancheng0614

add fill constant op

上级 92bdb523
......@@ -22,6 +22,7 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm";
const char *G_OP_TYPE_BOX_CODER = "box_coder";
const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu";
......@@ -99,6 +100,7 @@ std::unordered_map<
{G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
{G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}},
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/data_type.h"
#include <stdint.h>
#include <string>
#include <unordered_map>
namespace paddle_mobile {
namespace framework {
struct DataTypeMap {
std::unordered_map<std::type_index,
_PaddleMobile__Framework__Proto__VarType__Type>
cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
};
static DataTypeMap* InitDataTypeMap();
// C++11 removes the need for manual locking. Concurrent execution shall wait if
// a static local variable is already being initialized.
// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
static DataTypeMap& gDataTypeMap() {
static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
return *g_data_type_map_;
}
template <typename T>
static inline void RegisterType(
DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type,
const std::string& name) {
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
}
static DataTypeMap* InitDataTypeMap() {
auto retv = new DataTypeMap();
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here.
// RegType(float16, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16);
RegType(float, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32);
RegType(double, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64);
RegType(int, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32);
RegType(int64_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64);
RegType(bool, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL);
RegType(size_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__SIZE_T);
RegType(int16_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16);
RegType(uint8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8);
RegType(int8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8);
#undef RegType
return retv;
}
_PaddleMobile__Framework__Proto__VarType__Type ToDataType(
std::type_index type) {
auto it = gDataTypeMap().cpp_to_proto_.find(type);
if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.name());
}
std::type_index ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type) {
auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as "
"tensor type",
static_cast<int>(type));
}
std::string DataTypeToString(
const _PaddleMobile__Framework__Proto__VarType__Type type) {
auto it = gDataTypeMap().proto_to_str_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_str_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as "
"tensor type",
static_cast<int>(type));
}
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <typeindex>
#include "common/enforce.h"
#include "framework/framework.pb-c.h"
namespace paddle_mobile {
namespace framework {
extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType(
std::type_index type);
extern std::type_index ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type);
template <typename Visitor>
inline void VisitDataType(_PaddleMobile__Framework__Proto__VarType__Type type,
Visitor visitor) {
switch (type) {
// case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16:
// visitor.template apply<float16>();
// break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32:
visitor.template apply<float>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64:
visitor.template apply<double>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32:
visitor.template apply<int>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64:
visitor.template apply<int64_t>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL:
visitor.template apply<bool>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8:
visitor.template apply<uint8_t>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16:
visitor.template apply<int16_t>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8:
visitor.template apply<int8_t>();
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Not supported %d", type);
}
}
extern std::string DataTypeToString(
const _PaddleMobile__Framework__Proto__VarType__Type type);
inline std::ostream& operator<<(
std::ostream& out,
const _PaddleMobile__Framework__Proto__VarType__Type& type) {
out << DataTypeToString(type);
return out;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -64,6 +64,9 @@ limitations under the License. */
// load requared ops
LOAD_OP(feed)
LOAD_OP(fetch)
#ifdef FILL_CONSTANT_OP
LOAD_OP(fill_constant)
#endif
#ifdef BATCHNORM_OP
LOAD_OP2(batch_norm, CPU, MALI_GPU);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FILL_CONSTANT_OP
#include "operators/fill_constant_op.h"
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fill_constant, ops::FillConstantOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(fill_constant, ops::FillConstantOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fill_constant, ops::FillConstantOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FILL_CONSTANT_OP
#pragma once
#include <string>
#include "framework/data_type.h"
#include "framework/operator.h"
#include "framework/selected_rows.h"
#include "operators/math/math_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class FillConstantOp : public framework::OperatorBase<DeviceType> {
public:
FillConstantOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const {
auto data_type =
static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(
param_.DataDtype());
framework::Tensor *tensor = nullptr;
auto value = param_.Value();
auto *outvar = param_.OutVar();
if (outvar->template IsType<framework::LoDTensor>()) {
tensor = outvar->template GetMutable<framework::LoDTensor>();
} else if (outvar->template IsType<framework::SelectedRows>()) {
tensor = outvar->template GetMutable<framework::SelectedRows>()
->mutable_value();
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"fill constant op's output only"
"supports SelectedRows and LoDTensor");
}
tensor->Resize(framework::make_ddim(param_.Shape()));
tensor->mutable_data(framework::ToTypeIndex(data_type));
math::set_constant(tensor, value);
}
void Init() {}
void InferShape() const {
PADDLE_MOBILE_ENFORCE(
param_.Out() != nullptr,
"Output (Out) of fill_constant op should not be null.");
framework::DDim ddim = framework::make_ddim(param_.Shape());
param_.Out()->Resize(ddim);
}
protected:
FillConstantParam<DeviceType> param_;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -15,12 +15,31 @@ limitations under the License. */
#include "operators/math/math_function.h"
#include <cstring>
#include <string>
#include "framework/data_type.h"
#include "framework/tensor.h"
#include "operators/math/gemm.h"
namespace paddle_mobile {
namespace operators {
namespace math {
struct TensorSetConstant {
TensorSetConstant(framework::Tensor *tensor, float value)
: tensor_(tensor), value_(value) {}
template <typename T>
void apply() const {
auto *begin = tensor_->mutable_data<T>();
std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
}
framework::Tensor *tensor_;
float value_;
};
void set_constant(framework::Tensor *tensor, float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstant(tensor, value));
}
template <>
void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
......
......@@ -21,6 +21,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void set_constant(framework::Tensor *tensor, float value);
template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
......
......@@ -1063,6 +1063,42 @@ class FetchParam : public OpParam {
RType *out_;
};
#ifdef FILL_CONSTANT_OP
template <typename Dtype>
class FillConstantParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FillConstantParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
out_var_ = OutVarFrom(outputs, scope);
out_ = OutFrom<GType>(outputs, scope);
dtype_ = GetAttr<int>("dtype", attrs);
shape_ = GetAttr<vector<int>>("shape", attrs);
value_ = GetAttr<float>("value", attrs);
}
Variable *OutVar() const { return out_var_; }
RType *Out() const { return out_; }
const int &DataDtype() const { return dtype_; }
const vector<int> &Shape() const { return shape_; }
const float &Value() const { return value_; }
private:
Variable *out_var_;
RType *out_;
int dtype_;
vector<int> shape_;
float value_;
};
#endif
#ifdef TRANSPOSE_OP
template <typename Dtype>
class TransposeParam : public OpParam {
......
......@@ -188,6 +188,7 @@ if(NOT FOUND_MATCH)
set(ELEMENTWISEADD_OP ON)
set(ELEMENTWISESUB_OP ON)
set(IM2SEQUENCE_OP ON)
set(FILL_CONSTANT_OP ON)
set(FUSION_CONVADD_OP ON)
set(FUSION_CONVADDPRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
......@@ -231,6 +232,7 @@ endif()
# option(CONV_OP "" ON)
# option(DEPTHWISECONV_OP "" ON)
# option(ELEMENTWISEADD_OP "" ON)
# option(FILL_CONSTANT_OP "" ON)
# option(FUSION_CONVADD_OP "" ON)
# option(FUSION_CONVADDRELU_OP "" ON)
# option(FUSION_FC_OP "" ON)
......@@ -268,6 +270,9 @@ endif()
if (ELEMENTWISESUB_OP)
add_definitions(-DELEMENTWISESUB_OP)
endif()
if (FILL_CONSTANT_OP)
add_definitions(-DFILL_CONSTANT_OP)
endif()
if (FUSION_CONVADD_OP)
add_definitions(-DFUSION_CONVADD_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册