From a1d2abc16d9c7b42af6dcb41902423ae2904ee9a Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 14 Dec 2016 18:46:40 +0800 Subject: [PATCH] add Function --- paddle/math/Function.cpp | 47 +++++++++++++ paddle/math/Function.h | 84 ++++++++++++++++++++++++ paddle/math/cross_map_normal_op.cpp | 46 +++++++++++++ paddle/math/cross_map_normal_op.h | 20 +----- paddle/math/tests/test_matrixCompare.cpp | 15 +++-- 5 files changed, 188 insertions(+), 24 deletions(-) create mode 100644 paddle/math/Function.cpp create mode 100644 paddle/math/Function.h diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp new file mode 100644 index 00000000000..21d27191728 --- /dev/null +++ b/paddle/math/Function.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" + +namespace paddle { + +template <> +size_t FuncConfig::get(const std::string& key) const { + auto it = valueMap_.find(key); + CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'"; + return it->second.s; +} + +template <> +real FuncConfig::get(const std::string& key) const { + auto it = valueMap_.find(key); + CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'"; + return it->second.r; +} + +template <> +void FuncConfig::set(const std::string& key, size_t v) { + CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; + valueMap_[key].s = v; +} + +template <> +void FuncConfig::set(const std::string& key, real v) { + CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; + valueMap_[key].r = v; +} + +ClassRegistrar FunctionBase::funcRegistrar_; + +} // namespace paddle diff --git a/paddle/math/Function.h b/paddle/math/Function.h new file mode 100644 index 00000000000..b41ba2a13d3 --- /dev/null +++ b/paddle/math/Function.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/utils/ClassRegistrar.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2, +}; + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +typedef std::vector Arguments; + +class FuncConfig { +public: + union value { + size_t s; + real r; + }; + + template + T get(const std::string& key) const; + + template + void set(const std::string& key, T v); + +protected: + std::map valueMap_; +}; + +class FunctionBase { +public: + virtual ~FunctionBase() {} + + virtual void init(const FuncConfig& config) {} + + virtual void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) {} + + static ClassRegistrar funcRegistrar_; +}; + +#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName + +#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ + static InitFunction __reg_type_##typeName([]() { \ + FunctionBase::funcRegistrar_ \ + .registerClass>( \ + FUNC_NAME(typeName, deviceName)); \ + }) + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index be242926aff..0b727320638 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -128,4 +128,50 @@ void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, } } +template +class CrossMapNormalFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + size_ = config.get("size"); + scale_ = config.get("scale"); + pow_ = config.get("pow"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(1, inputs.size()); + CHECK_EQ(2, outputs.size()); + CHECK_EQ(0, inouts.size()); + + auto input = dynamic_cast::type&>(inputs[0]); + auto output = + dynamic_cast::type&>(outputs[0]); + auto denom = + dynamic_cast::type&>(outputs[1]); + + CHECK(input.isContiguous()); + CHECK(output.isContiguous()); + CHECK(denom.isContiguous()); + CHECK_EQ(output.getHeight(), input.getHeight()); + CHECK_EQ(output.getWidth(), input.getWidth()); + CHECK_EQ(output.getHeight(), denom.getHeight()); + CHECK_EQ(output.getWidth(), denom.getWidth()); + + // CrossMapNormal cross; + // need: + // size_t channels, + // size_t imgSizeH, + // size_t imgSizeW, + // cross(output, denom, input, ); + } + +private: + size_t size_; + real scale_; + real pow_; +}; + +REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); + } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index c2bb95f6b11..86f54abde10 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -14,29 +14,11 @@ limitations under the License. */ #pragma once +#include "Function.h" #include "paddle/math/Matrix.h" namespace paddle { -enum DeviceType { - DEVICE_TYPE_UNSPECIFIED = 0, - DEVICE_TYPE_CPU = 1, - DEVICE_TYPE_GPU = 2, -}; - -template -struct MatrixT; - -template <> -struct MatrixT { - using type = CpuMatrix; -}; - -template <> -struct MatrixT { - using type = GpuMatrix; -}; - template struct CrossMapNormal { void operator()(typename MatrixT::type& outputs, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 8d7a4fb94d0..0b75785528f 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/utils/Stat.h" #include "TensorCheck.h" #include "paddle/math/cross_map_normal_op.h" +#include "paddle/math/Function.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1280,6 +1281,15 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); + FuncConfig config; + config.set("size", (size_t)sizeX); + config.set("scale", scale); + config.set("pow", pow); + FunctionBase* cpu = + FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); + cpu->init(config); + // cpu->calc(); + CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); @@ -1295,11 +1305,6 @@ void testCrossMapNormalFwd( scale, pow); -#if 0 - outputsGpu.crossMapNormalFwd( - inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); -#endif - TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); } -- GitLab