diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp
index 3fab2127a151161080cbe38771aa237f84a3b6e0..8749a48327604bd2460ed32035d0f6baac5c4e26 100644
--- a/paddle/function/CrossMapNormalOp.cpp
+++ b/paddle/function/CrossMapNormalOp.cpp
@@ -182,23 +182,37 @@ public:
     CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
     CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
 
-    size_t samples = inputs[0].shape()[0];
-    size_t channels = inputs[0].shape()[1];
-    size_t height = inputs[0].shape()[2];
-    size_t width = inputs[0].shape()[3];
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
 
     CrossMapNormal<Device>(outputs[0].data<real>(),
                            outputs[1].data<real>(),
                            inputs[0].data<real>(),
-                           samples,
-                           channels,
-                           height,
-                           width,
+                           batchSize,
+                           maps,
+                           rows,
+                           columns,
                            size_,
                            scale_,
                            pow_);
   }
 
+  // Estimates the floating-point operation count from the input shape alone;
+  // only inputs[0].shape() is read, never the buffer contents.
+  size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ((size_t)numInputs_, inputs.size());
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
+
+    // Approximate value: each of the batchSize * maps * rows * columns
+    // elements is charged size_ floating-point operations.
+    return batchSize * maps * ((rows * columns) * size_);
+  }
+
 private:
   size_t size_;
   real scale_;
diff --git a/paddle/function/Function.h b/paddle/function/Function.h
index 4a6c79b6ebdf820aa8144519ecf43167788f8e2a..65688eebee975c0ba2ff72f5285c44af04b1dc3a 100644
--- a/paddle/function/Function.h
+++ b/paddle/function/Function.h
@@ -153,6 +153,13 @@ public:
 
   virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
 
+  // Calculate the number of floating-point operations of this Function.
+  // The inputs and outputs arguments do not need to contain the actual data,
+  // only the shape.
+  virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
+    return 0;
+  }
+
   int getNumInputs() const { return numInputs_; }
   int getNumOutputs() const { return numOutputs_; }
 