提交 7872f376 编写于 作者: H hedaoyuan

Fix some compile error.

上级 8266546e
...@@ -68,12 +68,10 @@ public: ...@@ -68,12 +68,10 @@ public:
numOutputs_ = 1; numOutputs_ = 1;
} }
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
// input can be INPUT and INPUT_GRAD // input can be INPUT and INPUT_GRAD
// filter can be FILTER and FILTER_GRAD // filter can be FILTER and FILTER_GRAD
// output can be OUTPUT and OUTPUT_GRAD // output can be OUTPUT and OUTPUT_GRAD
void check(const TensorShape& input, void checkShape(const TensorShape& input,
const TensorShape& filter, const TensorShape& filter,
const TensorShape& output) { const TensorShape& output) {
// inputs and outputs arguments should be 4-dimensional. // inputs and outputs arguments should be 4-dimensional.
......
...@@ -117,15 +117,23 @@ public: ...@@ -117,15 +117,23 @@ public:
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// TODO(hedaoyuan): Need to define some index macros, // TODO(hedaoyuan): Need to define some index macros,
// to avoid useing 0 and 1. // to avoid useing 0 and 1.
const TensorShape& input = inputs[0].shape(); const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape(); const TensorShape& output = outputs[0].shape();
check(input, filter, output);
real beta; real beta;
if (outputs[0].getArgType() == ADD_TO) { if (outputs[0].getArgType() == ADD_TO) {
...@@ -209,16 +217,24 @@ public: ...@@ -209,16 +217,24 @@ public:
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
const TensorShape& output = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& input = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// Since the implementation of Col2ImFunctor is ADD_TO, // Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode. // this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO); CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& output = inputs[0].shape(); const TensorShape& output = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& input = outputs[0].shape(); const TensorShape& input = outputs[0].shape();
check(input, filter, output);
size_t batchSize = input[0]; size_t batchSize = input[0];
size_t inputChannels = input[1]; size_t inputChannels = input[1];
...@@ -295,13 +311,21 @@ public: ...@@ -295,13 +311,21 @@ public:
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape();
const TensorShape& filter = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
const TensorShape& output = inputs[0].shape(); const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape(); const TensorShape& input = inputs[1].shape();
const TensorShape& filter = outputs[0].shape(); const TensorShape& filter = outputs[0].shape();
check(input, filter, output);
real beta; real beta;
if (outputs[0].getArgType() == ADD_TO) { if (outputs[0].getArgType() == ADD_TO) {
......
...@@ -90,14 +90,19 @@ public: ...@@ -90,14 +90,19 @@ public:
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { virtual void check(const BufferArgs& inputs,
CHECK_EQ(numInputs_, inputs.size()); const BufferArgs& outputs) override {
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& input = inputs[0].shape(); const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape(); const TensorShape& output = outputs[0].shape();
check(input, filter, output); checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
check(inputs, outputs);
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = inputs[0].shape()[0];
size_t inputChannels = inputs[0].shape()[1]; size_t inputChannels = inputs[0].shape()[1];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册