提交 c6e010d0 编写于 作者: H hedaoyuan

Follow comments.

上级 2608c485
......@@ -104,19 +104,11 @@ public:
protected:
size_t getFilterHeight(const TensorShape& filter) const {
if (filter.ndims() == 5) {
return filter[3];
} else {
return filter[2];
}
filter[filter.ndims() - 2];
}
size_t getFilterWidth(const TensorShape& filter) const {
if (filter.ndims() == 5) {
return filter[4];
} else {
return filter[3];
}
filter[filter.ndims() - 1];
}
std::vector<size_t> strides_;
......
......@@ -296,9 +296,9 @@ public:
compareOutputs();
}
std::shared_ptr<FunctionBase> getCpuFunction() const { return function1_; }
std::shared_ptr<FunctionBase> getFunction1() const { return function1_; }
std::shared_ptr<FunctionBase> getGpuFunction() const { return function2_; }
std::shared_ptr<FunctionBase> getFunction2() const { return function2_; }
protected:
// only init cpu argument, gpu argument copy from cpu argument.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册