提交 a1d2abc1 编写于 作者: H hedaoyuan

add Function

上级 e357f271
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
namespace paddle {
template <>
size_t FuncConfig::get<size_t>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.s;
}
template <>
real FuncConfig::get<real>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.r;
}
template <>
void FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].s = v;
}
template <>
void FuncConfig::set<real>(const std::string& key, real v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].r = v;
}
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/math/Matrix.h"
namespace paddle {
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
typedef std::vector<Matrix> Arguments;
class FuncConfig {
public:
union value {
size_t s;
real r;
};
template <typename T>
T get(const std::string& key) const;
template <typename T>
void set(const std::string& key, T v);
protected:
std::map<std::string, value> valueMap_;
};
class FunctionBase {
public:
virtual ~FunctionBase() {}
virtual void init(const FuncConfig& config) {}
virtual void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {}
static ClassRegistrar<FunctionBase> funcRegistrar_;
};
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \
static InitFunction __reg_type_##typeName([]() { \
FunctionBase::funcRegistrar_ \
.registerClass<className<DEVICE_TYPE_##deviceName>>( \
FUNC_NAME(typeName, deviceName)); \
})
} // namespace paddle
......@@ -128,4 +128,50 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
}
}
template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(1, inputs.size());
CHECK_EQ(2, outputs.size());
CHECK_EQ(0, inouts.size());
auto input = dynamic_cast<const typename MatrixT<Device>::type&>(inputs[0]);
auto output =
dynamic_cast<const typename MatrixT<Device>::type&>(outputs[0]);
auto denom =
dynamic_cast<const typename MatrixT<Device>::type&>(outputs[1]);
CHECK(input.isContiguous());
CHECK(output.isContiguous());
CHECK(denom.isContiguous());
CHECK_EQ(output.getHeight(), input.getHeight());
CHECK_EQ(output.getWidth(), input.getWidth());
CHECK_EQ(output.getHeight(), denom.getHeight());
CHECK_EQ(output.getWidth(), denom.getWidth());
// CrossMapNormal<Device> cross;
// need:
// size_t channels,
// size_t imgSizeH,
// size_t imgSizeW,
// cross(output, denom, input, );
}
private:
size_t size_;
real scale_;
real pow_;
};
REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
} // namespace paddle
......@@ -14,29 +14,11 @@ limitations under the License. */
#pragma once
#include "Function.h"
#include "paddle/math/Matrix.h"
namespace paddle {
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
template <DeviceType Device>
struct CrossMapNormal {
void operator()(typename MatrixT<Device>::type& outputs,
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/utils/Stat.h"
#include "TensorCheck.h"
#include "paddle/math/cross_map_normal_op.h"
#include "paddle/math/Function.h"
using namespace paddle; // NOLINT
using namespace std; // NOLINT
......@@ -1280,6 +1281,15 @@ void testCrossMapNormalFwd(
inputsGpu.copyFrom(inputs);
outputsGpu.copyFrom(outputs);
FuncConfig config;
config.set("size", (size_t)sizeX);
config.set("scale", scale);
config.set("pow", pow);
FunctionBase* cpu =
FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
cpu->init(config);
// cpu->calc();
CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
cpuCross(
outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
......@@ -1295,11 +1305,6 @@ void testCrossMapNormalFwd(
scale,
pow);
#if 0
outputsGpu.crossMapNormalFwd(
inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow);
#endif
TensorCheckErr(outputs, outputsGpu);
TensorCheckErr(denoms, denomsGpu);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册