提交 9896f15e 编写于 作者: H hedaoyuan

Add FunctionBase::ops()

上级 225a8fa1
......@@ -182,23 +182,37 @@ public:
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];
CrossMapNormal<Device>(outputs[0].data<real>(),
outputs[1].data<real>(),
inputs[0].data<real>(),
samples,
channels,
height,
width,
batchSize,
maps,
rows,
columns,
size_,
scale_,
pow_);
}
// Only need the shape of the input, can calculate the
// floating-point operation.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)numInputs_, inputs.size());
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];
// number of floating-point operations
// an approximate value
size_t ops = batchSize * maps * ((rows * columns) * size_);
}
private:
size_t size_;
real scale_;
......
......@@ -153,6 +153,13 @@ public:
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
// Calculate the number of floating-point operations of this Function.
// The inputs and outputs arguments do not need to contain the actual data,
// only the shape.
virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
return 0;
}
int getNumInputs() const { return numInputs_; }
int getNumOutputs() const { return numOutputs_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册