diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index d20215571dac807054bca4cc0ff88460515bb457..de1635be2af37cd0ba49010199a417090865b0e4 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -345,10 +345,10 @@ void MKLDNNTester::run(const TestConfig& dnn, return; } - // After run some iters, the mkldnn weight has been stored in dnnLayer - // and we can also get the mkldnn weight paramter header format - // Weight param should always be index 0 (and bias index 1). - // TODO(TJ): should also considerate mean and var format when batchnorm ready + // After run some iterations, the mkldnn weight has been stored in dnnLayer + // and we can also get the mkldnn weight parameter header format. + // Weight parameter should always be index 0 (and bias index 1). + // TODO(TJ): should also consider mean and var format when batchnorm ready int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat(); int refWgtFmt = parameters_[REF][0]->getHeaderFormat(); if (dnnWgtFmt == refWgtFmt) { diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index cffd3aa92e30ba66cd29f66fec8febb741bc5e85..e31cbc3dee6c57851c241e117dbbd9b701db9d2c 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -35,9 +35,17 @@ limitations under the License. */ namespace paddle { typedef enum { - PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format - PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi - PARAM_FORMAT_ITEMS, // the total format items numbers + /// The paddle original basic format + PARAM_FORMAT_ORIGINAL = 0, + + /// See mkldnn_memory_format_t in + /// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h + /// for a detailed description. + /// 2D weights tensor in the format (output channels, input channels). + PARAM_FORMAT_MKLDNN_OI, + + /// The total format items numbers + PARAM_FORMAT_ITEMS, } PARAM_FORMAT; class SparsePrefetchRowCpuMatrix; @@ -256,19 +264,19 @@ public: }; /** - * @brief Is the header supported + * @brief Is the header format supported. */ static bool isHeaderFormatSupported(int32_t fmt) { return fmt < PARAM_FORMAT_ITEMS; } /** - * @brief Get the format in header + * @brief Get the format in header. */ int getHeaderFormat() { return headerFormat_; } /** - * @brief Set the format in header + * @brief Set the format in header. */ void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } @@ -343,7 +351,7 @@ protected: bool updated_; SparseFormat format_; - // The header format for saving or loading param + /// The header format for saving or loading param int32_t headerFormat_; std::vector> updaterHooks_;