提交 225a8fa1 编写于 作者: H hedaoyuan

Add numInputs_ and numOutputs_

上级 50e525ca
...@@ -162,14 +162,19 @@ template <DeviceType Device> ...@@ -162,14 +162,19 @@ template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase { class CrossMapNormalFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments
size_ = config.get<size_t>("size"); size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale"); scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow"); pow_ = config.get<real>("pow");
// number of inputs and outputs
numInputs_ = 1;
numOutputs_ = 2;
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size()); CHECK_EQ((size_t)numInputs_, inputs.size());
CHECK_EQ((size_t)2, outputs.size()); CHECK_EQ((size_t)numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape()); CHECK(inputs[0].shape() == outputs[0].shape());
...@@ -236,14 +241,19 @@ template <DeviceType Device> ...@@ -236,14 +241,19 @@ template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase { class CrossMapNormalGradFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments
size_ = config.get<size_t>("size"); size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale"); scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow"); pow_ = config.get<real>("pow");
// number of inputs and outputs
numInputs_ = 4;
numOutputs_ = 1;
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)4, inputs.size()); CHECK_EQ((size_t)numInputs_, inputs.size());
CHECK_EQ((size_t)1, outputs.size()); CHECK_EQ((size_t)numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape()); CHECK(inputs[0].shape() == inputs[1].shape());
......
...@@ -153,7 +153,20 @@ public: ...@@ -153,7 +153,20 @@ public:
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
int getNumInputs() const { return numInputs_; }
int getNumOutputs() const { return numOutputs_; }
static ClassRegistrar<FunctionBase> funcRegistrar_; static ClassRegistrar<FunctionBase> funcRegistrar_;
protected:
// numInputs_ and numOutputs_ represents the maximum
// input and output supported by Function.
// Some functions are optimized for input and output,
// so when comparing the number of arguments, for these functions
// inputs.size() <= numInputs_ or outputs.size() <= numOutputs_
size_t numInputs_;
size_t numOutputs_;
}; };
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册