diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp
index 5c0bdd933b1e4a62e49981798c56c70907d16424..ef878bfbba961bdd3d5212e19fb83bb1e285e47f 100644
--- a/paddle/function/CrossMapNormalOp.cpp
+++ b/paddle/function/CrossMapNormalOp.cpp
@@ -162,38 +162,64 @@ template <DeviceType Device>
 class CrossMapNormalFunc : public FunctionBase {
 public:
   void init(const FuncConfig& config) override {
+    // function arguments
     size_ = config.get<size_t>("size");
     scale_ = config.get<real>("scale");
     pow_ = config.get<real>("pow");
+
+    // number of inputs and outputs
+    numInputs_ = 1;
+    numOutputs_ = 2;
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ((size_t)1, inputs.size());
-    CHECK_EQ((size_t)2, outputs.size());
-
-    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
-    CHECK(inputs[0].shape() == outputs[0].shape());
-    CHECK(inputs[0].shape() == outputs[1].shape());
-
+    check(inputs, outputs);
+    // The ArgType check is still done here;
+    // not sure whether it would be better to put it inside check().
    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
     CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
-    size_t samples = inputs[0].shape()[0];
-    size_t channels = inputs[0].shape()[1];
-    size_t height = inputs[0].shape()[2];
-    size_t width = inputs[0].shape()[3];
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
 
     CrossMapNormal<Device>(outputs[0].data<real>(),
                            outputs[1].data<real>(),
                            inputs[0].data<real>(),
-                           samples,
-                           channels,
-                           height,
-                           width,
+                           batchSize,
+                           maps,
+                           rows,
+                           columns,
                            size_,
                            scale_,
                            pow_);
   }
 
+  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(numInputs_, inputs.size());
+    CHECK_EQ(numOutputs_, outputs.size());
+
+    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
+    CHECK(inputs[0].shape() == outputs[0].shape());
+    CHECK(inputs[0].shape() == outputs[1].shape());
+  }
+
+  // Only the shape of the input is needed to calculate the
+  // number of floating-point operations.
+  size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ((size_t)numInputs_, inputs.size());
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
+
+    // number of floating-point operations
+    // (an approximate value)
+    size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3);
+
+    return ops;
+  }
+
 private:
   size_t size_;
   real scale_;
@@ -236,21 +262,18 @@ template <DeviceType Device>
 class CrossMapNormalGradFunc : public FunctionBase {
 public:
   void init(const FuncConfig& config) override {
+    // function arguments
     size_ = config.get<size_t>("size");
     scale_ = config.get<real>("scale");
     pow_ = config.get<real>("pow");
+
+    // number of inputs and outputs
+    numInputs_ = 4;
+    numOutputs_ = 1;
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ((size_t)4, inputs.size());
-    CHECK_EQ((size_t)1, outputs.size());
-
-    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
-    CHECK(inputs[0].shape() == inputs[1].shape());
-    CHECK(inputs[0].shape() == inputs[2].shape());
-    CHECK(inputs[0].shape() == inputs[3].shape());
-    CHECK(inputs[0].shape() == outputs[0].shape());
-
+    check(inputs, outputs);
     if (outputs[0].getArgType() != ADD_TO) {
       // Currently, some algorithm implementations are ASSIGN_TO mode,
       // if need to support the ADD_TO calculation, need to clear the output.
@@ -259,25 +282,52 @@ public:
       tmp.zero();
     }
 
-    size_t samples = inputs[0].shape()[0];
-    size_t channels = inputs[0].shape()[1];
-    size_t height = inputs[0].shape()[2];
-    size_t width = inputs[0].shape()[3];
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
 
     CrossMapNormalGrad<Device>(outputs[0].data<real>(),
                                inputs[0].data<real>(),
                                inputs[1].data<real>(),
                                inputs[2].data<real>(),
                                inputs[3].data<real>(),
-                               samples,
-                               channels,
-                               height,
-                               width,
+                               batchSize,
+                               maps,
+                               rows,
+                               columns,
                                size_,
                                scale_,
                                pow_);
   }
 
+  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(numInputs_, inputs.size());
+    CHECK_EQ(numOutputs_, outputs.size());
+
+    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
+    CHECK(inputs[0].shape() == inputs[1].shape());
+    CHECK(inputs[0].shape() == inputs[2].shape());
+    CHECK(inputs[0].shape() == inputs[3].shape());
+    CHECK(inputs[0].shape() == outputs[0].shape());
+  }
+
+  // Only the shape of one input is needed to calculate the
+  // number of floating-point operations.
+  size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_LT((size_t)1, inputs.size());
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
+
+    // number of floating-point operations
+    // (an approximate value)
+    size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2);
+
+    return ops;
+  }
+
 private:
   size_t size_;
   real scale_;
diff --git a/paddle/function/Function.h b/paddle/function/Function.h
index 9215c137eb8e85a9a03575104d7f89bbce441eba..3bbeb6e525f85bdde9a54c8d60146eaa30a1bb4d 100644
--- a/paddle/function/Function.h
+++ b/paddle/function/Function.h
@@ -153,7 +153,36 @@ public:
   virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
 
+  // This member function checks whether the BufferType and shape of the
+  // inputs and outputs arguments of the Function are correct.
+  // The general calc function will call this check to validate the arguments.
+  // The caller can also check its own arguments before calc is called.
+  virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}
+
+  // Calculate the number of floating-point operations of this Function.
+  // The inputs and outputs arguments do not need to contain the actual data,
+  // only the shape.
+  // Since some Functions have the same input and output shapes,
+  // you may not need to pass the complete set of arguments.
+  // However, passing the full arguments is always correct for this interface.
+  virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
+    return 0;
+  }
+
+  int getNumInputs() const { return numInputs_; }
+
+  int getNumOutputs() const { return numOutputs_; }
+
   static ClassRegistrar<FunctionBase> funcRegistrar_;
+
+protected:
+  // numInputs_ and numOutputs_ represent the maximum number of
+  // inputs and outputs supported by the Function.
+  // Some Functions are optimized for their inputs and outputs,
+  // so when comparing the number of arguments, for these Functions
+  // inputs.size() <= numInputs_ or outputs.size() <= numOutputs_ may hold.
+  size_t numInputs_;
+  size_t numOutputs_;
 };
 
 #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
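As a side note (not part of the patch): the minimal standalone sketch below reproduces the approximate FLOP counts returned by the two ops() overrides above, assuming an NCHW input of shape batchSize x maps x rows x columns and a cross-map window of `size` elements. The helper names forwardFlops and backwardFlops and the example dimensions are made up for illustration only.

#include <cstddef>
#include <cstdio>

// Mirrors CrossMapNormalFunc::ops(): roughly 2 FLOPs per window element
// (multiply + accumulate for the squared sum) plus a few extra operations
// for the scale/power step, for every element of the input.
size_t forwardFlops(size_t batchSize, size_t maps, size_t rows,
                    size_t columns, size_t size) {
  return batchSize * maps * rows * columns * (size * 2 + 3);
}

// Mirrors CrossMapNormalGradFunc::ops(): the backward pass does roughly
// twice the per-window work, giving size * 4 + 2 operations per element.
size_t backwardFlops(size_t batchSize, size_t maps, size_t rows,
                     size_t columns, size_t size) {
  return batchSize * maps * rows * columns * (size * 4 + 2);
}

int main() {
  // Example: batch of 64, 32 feature maps of 28 x 28, window size 5.
  std::printf("forward  ~ %zu FLOPs\n", forwardFlops(64, 32, 28, 28, 5));
  std::printf("backward ~ %zu FLOPs\n", backwardFlops(64, 32, 28, 28, 5));
  return 0;
}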