From 9896f15e7cabd5d68ec03157439a44bbb709c221 Mon Sep 17 00:00:00 2001 From: hedaoyuan <hedaoyuan@github.com> Date: Mon, 23 Jan 2017 12:44:03 +0800 Subject: [PATCH] Add FunctionBase::ops() --- paddle/function/CrossMapNormalOp.cpp | 30 ++++++++++++++++++++-------- paddle/function/Function.h | 7 +++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 3fab2127a15..8749a483276 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -182,23 +182,37 @@ public: CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO); - size_t samples = inputs[0].shape()[0]; - size_t channels = inputs[0].shape()[1]; - size_t height = inputs[0].shape()[2]; - size_t width = inputs[0].shape()[3]; + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; CrossMapNormal<Device>(outputs[0].data<real>(), outputs[1].data<real>(), inputs[0].data<real>(), - samples, - channels, - height, - width, + batchSize, + maps, + rows, + columns, size_, scale_, pow_); } + // Only need the shape of the input, can calculate the + // floating-point operation. + size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)numInputs_, inputs.size()); + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; + + // number of floating-point operations + // an approximate value + size_t ops = batchSize * maps * ((rows * columns) * size_); + } + private: size_t size_; real scale_; diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 4a6c79b6ebd..65688eebee9 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,6 +153,13 @@ public: virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + // Calculate the number of floating-point operations of this Function. + // The inputs and outputs arguments do not need to contain the actual data, + // only the shape. + virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) { + return 0; + } + int getNumInputs() const { return numInputs_; } int getNumOutputs() const { return numOutputs_; } -- GitLab