diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 5c0bdd933b1e4a62e49981798c56c70907d16424..3fab2127a151161080cbe38771aa237f84a3b6e0 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -162,14 +162,19 @@ template class CrossMapNormalFunc : public FunctionBase { public: void init(const FuncConfig& config) override { + // function arguments size_ = config.get("size"); scale_ = config.get("scale"); pow_ = config.get("pow"); + + // number of inputs and outputs + numInputs_ = 1; + numOutputs_ = 2; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)1, inputs.size()); - CHECK_EQ((size_t)2, outputs.size()); + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == outputs[0].shape()); @@ -236,14 +241,19 @@ template class CrossMapNormalGradFunc : public FunctionBase { public: void init(const FuncConfig& config) override { + // function arguments size_ = config.get("size"); scale_ = config.get("scale"); pow_ = config.get("pow"); + + // number of inputs and outputs + numInputs_ = 4; + numOutputs_ = 1; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)4, inputs.size()); - CHECK_EQ((size_t)1, outputs.size()); + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == inputs[1].shape()); diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 9215c137eb8e85a9a03575104d7f89bbce441eba..4a6c79b6ebdf820aa8144519ecf43167788f8e2a 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,7 +153,20 @@ public: virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + int getNumInputs() const { return numInputs_; } + + int getNumOutputs() const { return numOutputs_; } + static ClassRegistrar funcRegistrar_; + +protected: + // numInputs_ and numOutputs_ represents the maximum + // input and output supported by Function. + // Some functions are optimized for input and output, + // so when comparing the number of arguments, for these functions + // inputs.size() <= numInputs_ or outputs.size() <= numOutputs_ + size_t numInputs_; + size_t numOutputs_; }; #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName