提交 c4437fa2 编写于 作者: H hedaoyuan

Add FunctionBase::check()

上级 9896f15e
...@@ -173,13 +173,9 @@ public: ...@@ -173,13 +173,9 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)numInputs_, inputs.size()); check(inputs, outputs);
CHECK_EQ((size_t)numOutputs_, outputs.size()); // ArgType check still on here,
// not sure whether it is better to put inside the check.
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = inputs[0].shape()[0];
...@@ -199,6 +195,15 @@ public: ...@@ -199,6 +195,15 @@ public:
pow_); pow_);
} }
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)numInputs_, inputs.size());
CHECK_EQ((size_t)numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
}
// Only need the shape of the input, can calculate the // Only need the shape of the input, can calculate the
// floating-point operation. // floating-point operation.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -211,6 +216,8 @@ public: ...@@ -211,6 +216,8 @@ public:
// number of floating-point operations // number of floating-point operations
// an approximate value // an approximate value
size_t ops = batchSize * maps * ((rows * columns) * size_); size_t ops = batchSize * maps * ((rows * columns) * size_);
return ops;
} }
private: private:
......
...@@ -153,6 +153,12 @@ public: ...@@ -153,6 +153,12 @@ public:
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
// This member function is used to check whether the BufferType and shape of
// the inputs and outputs arguments of the Function are correct.
// General calc function which will call this check to do arguments check.
// Also before the call calc, the caller can also check their own arguments.
virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}
// Calculate the number of floating-point operations of this Function. // Calculate the number of floating-point operations of this Function.
// The inputs and outputs arguments do not need to contain the actual data, // The inputs and outputs arguments do not need to contain the actual data,
// only the shape. // only the shape.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册