From 9896f15e7cabd5d68ec03157439a44bbb709c221 Mon Sep 17 00:00:00 2001
From: hedaoyuan <hedaoyuan@github.com>
Date: Mon, 23 Jan 2017 12:44:03 +0800
Subject: [PATCH] Add FunctionBase::ops()

---
 paddle/function/CrossMapNormalOp.cpp | 30 ++++++++++++++++++++--------
 paddle/function/Function.h           |  7 +++++++
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp
index 3fab2127a15..8749a483276 100644
--- a/paddle/function/CrossMapNormalOp.cpp
+++ b/paddle/function/CrossMapNormalOp.cpp
@@ -182,23 +182,37 @@ public:
 
     CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
     CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
-    size_t samples = inputs[0].shape()[0];
-    size_t channels = inputs[0].shape()[1];
-    size_t height = inputs[0].shape()[2];
-    size_t width = inputs[0].shape()[3];
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
 
     CrossMapNormal<Device>(outputs[0].data<real>(),
                            outputs[1].data<real>(),
                            inputs[0].data<real>(),
-                           samples,
-                           channels,
-                           height,
-                           width,
+                           batchSize,
+                           maps,
+                           rows,
+                           columns,
                            size_,
                            scale_,
                            pow_);
   }
 
+  // Only need the shape of the input, can calculate the
+  // floating-point operation.
+  size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ((size_t)numInputs_, inputs.size());
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
+
+    // number of floating-point operations
+    // an approximate value
+    size_t ops = batchSize * maps * ((rows * columns) * size_);
+  }
+
 private:
   size_t size_;
   real scale_;
diff --git a/paddle/function/Function.h b/paddle/function/Function.h
index 4a6c79b6ebd..65688eebee9 100644
--- a/paddle/function/Function.h
+++ b/paddle/function/Function.h
@@ -153,6 +153,13 @@ public:
 
   virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
 
+  // Calculate the number of floating-point operations of this Function.
+  // The inputs and outputs arguments do not need to contain the actual data,
+  // only the shape.
+  virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
+    return 0;
+  }
+
   int getNumInputs() const { return numInputs_; }
 
   int getNumOutputs() const { return numOutputs_; }
-- 
GitLab