提交 2d4c66d4 编写于 作者: T tensor-tang

add comments and todo lists

上级 a475a57d
...@@ -29,7 +29,10 @@ protected: ...@@ -29,7 +29,10 @@ protected:
// input layer size, can not be change after init // input layer size, can not be change after init
size_t iLayerSize_; // == ic * ih * iw size_t iLayerSize_; // == ic * ih * iw
// if has already init the weight
bool hasInitedWgt_; bool hasInitedWgt_;
// if input layer has image size info (ih>1 && iw>1)
bool hasSpatial_; bool hasSpatial_;
// fc weight and bias // fc weight and bias
......
...@@ -123,7 +123,8 @@ void MKLDNNTester::checkForward() { ...@@ -123,7 +123,8 @@ void MKLDNNTester::checkForward() {
} }
void MKLDNNTester::checkBackwardData() { void MKLDNNTester::checkBackwardData() {
const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm"; // TODO(TJ): uncomment me when batch norm ready
// const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) { for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad(); const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad();
const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad(); const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad();
...@@ -134,10 +135,11 @@ void MKLDNNTester::checkBackwardData() { ...@@ -134,10 +135,11 @@ void MKLDNNTester::checkBackwardData() {
double delta = compareMatrix(dnnDiff, refDiff); double delta = compareMatrix(dnnDiff, refDiff);
EXPECT_LE(fabs(delta), eps_); EXPECT_LE(fabs(delta), eps_);
if (isBN) { // TODO(TJ): uncomment me when batch norm ready
// the other two inputs in batch norm are for moving mean and var // if (isBN) {
break; // // the other two inputs in batch norm are for moving mean and var
} // break;
// }
} }
} }
......
...@@ -27,9 +27,9 @@ namespace paddle { ...@@ -27,9 +27,9 @@ namespace paddle {
*/ */
class MKLDNNTester { class MKLDNNTester {
enum { enum {
DNN = 0, DNN = 0, // MKLDNN layer
REF = 1, REF = 1, // Reference layer
NUM = 2, NUM = 2, // Number of total
}; };
protected: protected:
...@@ -107,7 +107,8 @@ private: ...@@ -107,7 +107,8 @@ private:
* Get delta percent * Get delta percent
* if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the * if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the
* max(diff/ref) * max(diff/ref)
* else return sum(abs(a-b)) / sum(abs(b)) should smaller than eps * else return sum(abs(a-b)) / sum(abs(b))
* The return value should smaller than eps when passing.
*/ */
double getDelta(const real* d1, double getDelta(const real* d1,
const real* d2, const real* d2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册