From a9228e2a406ecb3588ea0c2d112971260d87e1a3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 23 Jan 2017 13:49:19 +0800 Subject: [PATCH] Fix CrossMapNormalGradFunc --- paddle/function/CrossMapNormalOp.cpp | 59 ++++++++++++++++++---------- paddle/function/Function.h | 5 ++- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 99af02ac744..ef878bfbba9 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -196,8 +196,8 @@ public: } void check(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)numInputs_, inputs.size()); - CHECK_EQ((size_t)numOutputs_, outputs.size()); + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == outputs[0].shape()); @@ -215,7 +215,7 @@ public: // number of floating-point operations // an approximate value - size_t ops = batchSize * maps * ((rows * columns) * size_); + size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3); return ops; } @@ -273,15 +273,7 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)numInputs_, inputs.size()); - CHECK_EQ((size_t)numOutputs_, outputs.size()); - - CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); - CHECK(inputs[0].shape() == inputs[1].shape()); - CHECK(inputs[0].shape() == inputs[2].shape()); - CHECK(inputs[0].shape() == inputs[3].shape()); - CHECK(inputs[0].shape() == outputs[0].shape()); - + check(inputs, outputs); if (outputs[0].getArgType() != ADD_TO) { // Currently, some algorithm implementations are ASSIGN_TO mode, // if need to support the ADD_TO calculation, need to clear the output. @@ -290,25 +282,52 @@ public: tmp.zero(); } - size_t samples = inputs[0].shape()[0]; - size_t channels = inputs[0].shape()[1]; - size_t height = inputs[0].shape()[2]; - size_t width = inputs[0].shape()[3]; + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; CrossMapNormalGrad(outputs[0].data(), inputs[0].data(), inputs[1].data(), inputs[2].data(), inputs[3].data(), - samples, - channels, - height, - width, + batchSize, + maps, + rows, + columns, size_, scale_, pow_); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK(inputs[0].shape() == inputs[1].shape()); + CHECK(inputs[0].shape() == inputs[2].shape()); + CHECK(inputs[0].shape() == inputs[3].shape()); + CHECK(inputs[0].shape() == outputs[0].shape()); + } + + // Only need the shape of one input, can calculate the + // floating-point operation. + size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_LT((size_t)1, inputs.size()); + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; + + // number of floating-point operations + // an approximate value + size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2); + + return ops; + } + private: size_t size_; real scale_; diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 4802c2e846c..3bbeb6e525f 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -156,12 +156,15 @@ public: // This member function is used to check whether the BufferType and shape of // the inputs and outputs arguments of the Function are correct. // General calc function which will call this check to do arguments check. - // Also before the call calc, the caller can also check their own arguments. + // And before the calc called, the caller can also check their own arguments. virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {} // Calculate the number of floating-point operations of this Function. // The inputs and outputs arguments do not need to contain the actual data, // only the shape. + // And some Functions have the same input and output shapes, + // so you may not need to enter the complete number of arguments. + // But entering the full arguments is always correct for this interface. virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) { return 0; } -- GitLab