提交 225a8fa1 编写于 作者: H hedaoyuan

Add numInputs_ and numOutputs_

上级 50e525ca
......@@ -162,14 +162,19 @@ template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");
// number of inputs and outputs
numInputs_ = 1;
numOutputs_ = 2;
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
CHECK_EQ((size_t)numInputs_, inputs.size());
CHECK_EQ((size_t)numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
......@@ -236,14 +241,19 @@ template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");
// number of inputs and outputs
numInputs_ = 4;
numOutputs_ = 1;
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)4, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK_EQ((size_t)numInputs_, inputs.size());
CHECK_EQ((size_t)numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
......
......@@ -153,7 +153,20 @@ public:
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
int getNumInputs() const { return numInputs_; }
int getNumOutputs() const { return numOutputs_; }
static ClassRegistrar<FunctionBase> funcRegistrar_;
protected:
// numInputs_ and numOutputs_ represents the maximum
// input and output supported by Function.
// Some functions are optimized for input and output,
// so when comparing the number of arguments, for these functions
// inputs.size() <= numInputs_ or outputs.size() <= numOutputs_
size_t numInputs_;
size_t numOutputs_;
};
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册