diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 8749a48327604bd2460ed32035d0f6baac5c4e26..99af02ac7441460e6e2bf98cff4bd59ba6930185 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -173,13 +173,9 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)numInputs_, inputs.size()); - CHECK_EQ((size_t)numOutputs_, outputs.size()); - - CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); - CHECK(inputs[0].shape() == outputs[0].shape()); - CHECK(inputs[0].shape() == outputs[1].shape()); - + check(inputs, outputs); + // ArgType check still on here, + // not sure whether it is better to put inside the check. CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO); size_t batchSize = inputs[0].shape()[0]; @@ -199,6 +195,15 @@ public: pow_); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK(inputs[0].shape() == outputs[0].shape()); + CHECK(inputs[0].shape() == outputs[1].shape()); + } + // Only need the shape of the input, can calculate the // floating-point operation. size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { @@ -211,6 +216,8 @@ public: // number of floating-point operations // an approximate value size_t ops = batchSize * maps * ((rows * columns) * size_); + + return ops; } private: diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 65688eebee975c0ba2ff72f5285c44af04b1dc3a..4802c2e846cfa2e1b51e96b6a4c612b5172c0708 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,6 +153,12 @@ public: virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + // This member function is used to check whether the BufferType and shape of + // the inputs and outputs arguments of the Function are correct. + // General calc function which will call this check to do arguments check. + // Also before the call calc, the caller can also check their own arguments. + virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {} + // Calculate the number of floating-point operations of this Function. // The inputs and outputs arguments do not need to contain the actual data, // only the shape.