From 225a8fa14b8fa04c814da02ff9f240f1819373f3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 23 Jan 2017 11:26:20 +0800 Subject: [PATCH] Add numInputs_ and numOutputs_ --- paddle/function/CrossMapNormalOp.cpp | 18 ++++++++++++++---- paddle/function/Function.h | 13 +++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 5c0bdd933..3fab2127a 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -162,14 +162,19 @@ template class CrossMapNormalFunc : public FunctionBase { public: void init(const FuncConfig& config) override { + // function arguments size_ = config.get("size"); scale_ = config.get("scale"); pow_ = config.get("pow"); + + // number of inputs and outputs + numInputs_ = 1; + numOutputs_ = 2; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)1, inputs.size()); - CHECK_EQ((size_t)2, outputs.size()); + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == outputs[0].shape()); @@ -236,14 +241,19 @@ template class CrossMapNormalGradFunc : public FunctionBase { public: void init(const FuncConfig& config) override { + // function arguments size_ = config.get("size"); scale_ = config.get("scale"); pow_ = config.get("pow"); + + // number of inputs and outputs + numInputs_ = 4; + numOutputs_ = 1; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)4, inputs.size()); - CHECK_EQ((size_t)1, outputs.size()); + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == inputs[1].shape()); diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 9215c137e..4a6c79b6e 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,7 +153,20 @@ public: virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + int getNumInputs() const { return numInputs_; } + + int getNumOutputs() const { return numOutputs_; } + static ClassRegistrar funcRegistrar_; + +protected: + // numInputs_ and numOutputs_ represents the maximum + // input and output supported by Function. + // Some functions are optimized for input and output, + // so when comparing the number of arguments, for these functions + // inputs.size() <= numInputs_ or outputs.size() <= numOutputs_ + size_t numInputs_; + size_t numOutputs_; }; #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName -- GitLab