diff --git a/src/common/types.cpp b/src/common/types.cpp index 2a8b8c8a151e58d13093c99c753cc47e8eef64a3..8c8de7765161e61dc75036a87a34fc6abd2df43e 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -22,6 +22,7 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm"; const char *G_OP_TYPE_BOX_CODER = "box_coder"; const char *G_OP_TYPE_CONCAT = "concat"; const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add"; +const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant"; const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu"; const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu"; @@ -99,6 +100,7 @@ std::unordered_map< {G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}}, {G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}}, {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, + {G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}}, diff --git a/src/framework/data_type.cpp b/src/framework/data_type.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bcf7d9f67dae28db5a316476778b4132b39b274 --- /dev/null +++ b/src/framework/data_type.cpp @@ -0,0 +1,107 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "framework/data_type.h" +#include +#include +#include + +namespace paddle_mobile { +namespace framework { + +struct DataTypeMap { + std::unordered_map + cpp_to_proto_; + std::unordered_map proto_to_cpp_; + std::unordered_map proto_to_str_; + std::unordered_map cpp_to_size_; +}; + +static DataTypeMap* InitDataTypeMap(); +// C++11 removes the need for manual locking. Concurrent execution shall wait if +// a static local variable is already being initialized. +// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex +static DataTypeMap& gDataTypeMap() { + static DataTypeMap* g_data_type_map_ = InitDataTypeMap(); + return *g_data_type_map_; +} + +template +static inline void RegisterType( + DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type, + const std::string& name) { + map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T)); + map->cpp_to_proto_.emplace(typeid(T), proto_type); + map->proto_to_str_.emplace(static_cast(proto_type), name); + map->cpp_to_size_.emplace(typeid(T), sizeof(T)); +} + +static DataTypeMap* InitDataTypeMap() { + auto retv = new DataTypeMap(); + +#define RegType(cc_type, proto_type) \ + RegisterType(retv, proto_type, #cc_type) + + // NOTE: Add your customize type here. + // RegType(float16, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16); + RegType(float, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32); + RegType(double, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64); + RegType(int, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32); + RegType(int64_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64); + RegType(bool, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL); + RegType(size_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__SIZE_T); + RegType(int16_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16); + RegType(uint8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8); + RegType(int8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8); + +#undef RegType + return retv; +} + +_PaddleMobile__Framework__Proto__VarType__Type ToDataType( + std::type_index type) { + auto it = gDataTypeMap().cpp_to_proto_.find(type); + if (it != gDataTypeMap().cpp_to_proto_.end()) { + return it->second; + } + PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.name()); +} + +std::type_index ToTypeIndex( + _PaddleMobile__Framework__Proto__VarType__Type type) { + auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_cpp_.end()) { + return it->second; + } + PADDLE_MOBILE_THROW_EXCEPTION( + "Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as " + "tensor type", + static_cast(type)); +} + +std::string DataTypeToString( + const _PaddleMobile__Framework__Proto__VarType__Type type) { + auto it = gDataTypeMap().proto_to_str_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_str_.end()) { + return it->second; + } + PADDLE_MOBILE_THROW_EXCEPTION( + "Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as " + "tensor type", + static_cast(type)); +} + +} // namespace framework +} // namespace paddle_mobile diff --git a/src/framework/data_type.h b/src/framework/data_type.h new file mode 100644 index 0000000000000000000000000000000000000000..2e3623fdedcb527cb0c85bbb7a2eaf04d91a2193 --- /dev/null +++ b/src/framework/data_type.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "common/enforce.h" +#include "framework/framework.pb-c.h" + +namespace paddle_mobile { + +namespace framework { + +extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType( + std::type_index type); +extern std::type_index ToTypeIndex( + _PaddleMobile__Framework__Proto__VarType__Type type); + +template +inline void VisitDataType(_PaddleMobile__Framework__Proto__VarType__Type type, + Visitor visitor) { + switch (type) { + // case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16: + // visitor.template apply(); + // break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8: + visitor.template apply(); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Not supported %d", type); + } +} + +extern std::string DataTypeToString( + const _PaddleMobile__Framework__Proto__VarType__Type type); +inline std::ostream& operator<<( + std::ostream& out, + const _PaddleMobile__Framework__Proto__VarType__Type& type) { + out << DataTypeToString(type); + return out; +} + +} // namespace framework +} // namespace paddle_mobile diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 7fd6290704a573aac535685b4bdf48092b35e98b..2b76b0158fe06e8678208f6f98fcdb71f8d91e51 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -64,6 +64,9 @@ limitations under the License. */ // load requared ops LOAD_OP(feed) LOAD_OP(fetch) +#ifdef FILL_CONSTANT_OP +LOAD_OP(fill_constant) +#endif #ifdef BATCHNORM_OP LOAD_OP2(batch_norm, CPU, MALI_GPU); #endif diff --git a/src/operators/fill_constant_op.cpp b/src/operators/fill_constant_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6d7c4f44f1b769c47d6f741d139118158292a40f --- /dev/null +++ b/src/operators/fill_constant_op.cpp @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FILL_CONSTANT_OP + +#include "operators/fill_constant_op.h" + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fill_constant, ops::FillConstantOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +REGISTER_OPERATOR_MALI_GPU(fill_constant, ops::FillConstantOp); +#endif +#ifdef PADDLE_MOBILE_FPGA +REGISTER_OPERATOR_FPGA(fill_constant, ops::FillConstantOp); +#endif + +#endif diff --git a/src/operators/fill_constant_op.h b/src/operators/fill_constant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..78eb162efc8ccd42b9fba363d49d1dbc4052f6b2 --- /dev/null +++ b/src/operators/fill_constant_op.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FILL_CONSTANT_OP + +#pragma once + +#include +#include "framework/data_type.h" +#include "framework/operator.h" +#include "framework/selected_rows.h" +#include "operators/math/math_function.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; + +template +class FillConstantOp : public framework::OperatorBase { + public: + FillConstantOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap attrs, + std::shared_ptr scope) + : framework::OperatorBase(type, inputs, outputs, attrs, + scope), + param_(inputs, outputs, attrs, *scope) {} + void RunImpl() const { + auto data_type = + static_cast<_PaddleMobile__Framework__Proto__VarType__Type>( + param_.DataDtype()); + framework::Tensor *tensor = nullptr; + auto value = param_.Value(); + auto *outvar = param_.OutVar(); + + if (outvar->template IsType()) { + tensor = outvar->template GetMutable(); + } else if (outvar->template IsType()) { + tensor = outvar->template GetMutable() + ->mutable_value(); + } else { + PADDLE_MOBILE_THROW_EXCEPTION( + "fill constant op's output only" + "supports SelectedRows and LoDTensor"); + } + tensor->Resize(framework::make_ddim(param_.Shape())); + tensor->mutable_data(framework::ToTypeIndex(data_type)); + + math::set_constant(tensor, value); + } + + void Init() {} + + void InferShape() const { + PADDLE_MOBILE_ENFORCE( + param_.Out() != nullptr, + "Output (Out) of fill_constant op should not be null."); + framework::DDim ddim = framework::make_ddim(param_.Shape()); + param_.Out()->Resize(ddim); + } + + protected: + FillConstantParam param_; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index fc4c385add5ccf30ebe42695fb616e41deb1a827..4365bf5716b8b5811f6ac66217b2fe74ae116f52 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -15,12 +15,31 @@ limitations under the License. */ #include "operators/math/math_function.h" #include #include +#include "framework/data_type.h" +#include "framework/tensor.h" #include "operators/math/gemm.h" namespace paddle_mobile { namespace operators { namespace math { +struct TensorSetConstant { + TensorSetConstant(framework::Tensor *tensor, float value) + : tensor_(tensor), value_(value) {} + template + void apply() const { + auto *begin = tensor_->mutable_data(); + std::fill(begin, begin + tensor_->numel(), static_cast(value_)); + } + framework::Tensor *tensor_; + float value_; +}; + +void set_constant(framework::Tensor *tensor, float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstant(tensor, value)); +} + template <> void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index da8f3e042ac441c44d99072fad9593042735c008..b91242c1868398e4541c3727567a905e5b0c8714 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -22,6 +22,8 @@ namespace paddle_mobile { namespace operators { namespace math { +void set_constant(framework::Tensor *tensor, float value); + template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, diff --git a/src/operators/op_param.h b/src/operators/op_param.h index ba5bf40aab2ab81d42de4957c6aa12e30e84924d..568cf77b8e4e81732cd9a783c1a9ea64d347102b 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1063,6 +1063,42 @@ class FetchParam : public OpParam { RType *out_; }; +#ifdef FILL_CONSTANT_OP +template +class FillConstantParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FillConstantParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + const Scope &scope) { + out_var_ = OutVarFrom(outputs, scope); + out_ = OutFrom(outputs, scope); + dtype_ = GetAttr("dtype", attrs); + shape_ = GetAttr>("shape", attrs); + value_ = GetAttr("value", attrs); + } + + Variable *OutVar() const { return out_var_; } + + RType *Out() const { return out_; } + + const int &DataDtype() const { return dtype_; } + + const vector &Shape() const { return shape_; } + + const float &Value() const { return value_; } + + private: + Variable *out_var_; + RType *out_; + int dtype_; + vector shape_; + float value_; +}; +#endif + #ifdef TRANSPOSE_OP template class TransposeParam : public OpParam { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6f38a353765769f060e4c5468bd44d7d9a9b3b32..e3d79edf483c8182cd91bc0b82ad9989e211f671 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -185,6 +185,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h) target_link_libraries(test-polygon-box-transform-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-fill-constant-op operators/test_fill_constant_op.cpp test_helper.h test_include.h) + target_link_libraries(test-fill-constant-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) target_link_libraries(test-reshape-op paddle-mobile) diff --git a/test/operators/test_fill_constant_op.cpp b/test/operators/test_fill_constant_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b099217d1641eb221b3d0d86d780fb6ecfa929bd --- /dev/null +++ b/test/operators/test_fill_constant_op.cpp @@ -0,0 +1,113 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "../test_include.h" +#include "operators/fill_constant_op.h" + +namespace paddle_mobile { +namespace framework { + +template +class TestFillConstantOp { + public: + explicit TestFillConstantOp(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + const std::vector> blocks = + to_predict_program_->Blocks(); + for (auto block_desc : blocks) { + std::vector> ops = block_desc->Ops(); + for (auto op : ops) { + if (op->Type() == "fill_constant") { + DLOG << " attr size: " << op->GetAttrMap().size(); + std::unordered_map attrs = op->GetAttrMap(); + for (std::unordered_map::iterator it = + attrs.begin(); + it != attrs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + DLOG << " inputs size: " << op->GetInputs().size(); + DLOG << " outputs size: " << op->GetOutputs().size(); + DLOG << " output is : " << op->Output("Out")[0]; + output_var_name = op->Output("Out")[0]; + std::shared_ptr> op_ptr = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(op_ptr); + } + } + } + } + + std::shared_ptr predict() { + auto scope = program_.scope; + + Variable *output = scope->Var(output_var_name); + auto *output_tensor = output->GetMutable(); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + predict(0); + + return out_tensor; + } + + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; + string output_var_name; + + void predict(int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + op->Run(); + } + } +}; + +template class TestFillConstantOp; +} // namespace framework +} // namespace paddle_mobile + +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run FillConstant Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_ocr) + "/model", + std::string(g_ocr) + "/params"); + + paddle_mobile::framework::TestFillConstantOp + testFillConstantOp(program); + + auto output = testFillConstantOp.predict(); + auto *output_ptr = output->data(); + + DLOG << "output : "; + for (int i = 0; i < output->numel(); ++i) { + DLOG << " index " << i << " : " << output_ptr[i]; + } + return 0; +} diff --git a/tools/op.cmake b/tools/op.cmake index 68f363eef4d641a7d0a65ed9132ef6e54616ffcf..f7a6ed4b134f78ddb23487cd3a861f244e6a86db 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -188,6 +188,7 @@ if(NOT FOUND_MATCH) set(ELEMENTWISEADD_OP ON) set(ELEMENTWISESUB_OP ON) set(IM2SEQUENCE_OP ON) + set(FILL_CONSTANT_OP ON) set(FUSION_CONVADD_OP ON) set(FUSION_CONVADDPRELU_OP ON) set(FUSION_CONVADDRELU_OP ON) @@ -233,6 +234,7 @@ endif() # option(CONV_OP "" ON) # option(DEPTHWISECONV_OP "" ON) # option(ELEMENTWISEADD_OP "" ON) + # option(FILL_CONSTANT_OP "" ON) # option(FUSION_CONVADD_OP "" ON) # option(FUSION_CONVADDRELU_OP "" ON) # option(FUSION_FC_OP "" ON) @@ -270,6 +272,9 @@ endif() if (ELEMENTWISESUB_OP) add_definitions(-DELEMENTWISESUB_OP) endif() +if (FILL_CONSTANT_OP) + add_definitions(-DFILL_CONSTANT_OP) +endif() if (FUSION_CONVADD_OP) add_definitions(-DFUSION_CONVADD_OP) endif()