提交 07d16e3e 编写于 作者: T tensor-tang

refine comments

上级 635b8672
...@@ -345,10 +345,10 @@ void MKLDNNTester::run(const TestConfig& dnn, ...@@ -345,10 +345,10 @@ void MKLDNNTester::run(const TestConfig& dnn,
return; return;
} }
// After run some iters, the mkldnn weight has been stored in dnnLayer // After run some iterations, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight paramter header format // and we can also get the mkldnn weight parameter header format.
// Weight param should always be index 0 (and bias index 1). // Weight parameter should always be index 0 (and bias index 1).
// TODO(TJ): should also considerate mean and var format when batchnorm ready // TODO(TJ): should also consider mean and var format when batchnorm ready
int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat(); int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat();
int refWgtFmt = parameters_[REF][0]->getHeaderFormat(); int refWgtFmt = parameters_[REF][0]->getHeaderFormat();
if (dnnWgtFmt == refWgtFmt) { if (dnnWgtFmt == refWgtFmt) {
......
...@@ -35,9 +35,17 @@ limitations under the License. */ ...@@ -35,9 +35,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
typedef enum { typedef enum {
PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format /// The paddle original basic format
PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi PARAM_FORMAT_ORIGINAL = 0,
PARAM_FORMAT_ITEMS, // the total format items numbers
/// See mkldnn_memory_format_t in
/// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h
/// for a detailed description.
/// 2D weights tensor in the format (output channels, input channels).
PARAM_FORMAT_MKLDNN_OI,
/// The total format items numbers
PARAM_FORMAT_ITEMS,
} PARAM_FORMAT; } PARAM_FORMAT;
class SparsePrefetchRowCpuMatrix; class SparsePrefetchRowCpuMatrix;
...@@ -256,19 +264,19 @@ public: ...@@ -256,19 +264,19 @@ public:
}; };
/** /**
* @brief Is the header supported * @brief Is the header format supported.
*/ */
static bool isHeaderFormatSupported(int32_t fmt) { static bool isHeaderFormatSupported(int32_t fmt) {
return fmt < PARAM_FORMAT_ITEMS; return fmt < PARAM_FORMAT_ITEMS;
} }
/** /**
* @brief Get the format in header * @brief Get the format in header.
*/ */
int getHeaderFormat() { return headerFormat_; } int getHeaderFormat() { return headerFormat_; }
/** /**
* @brief Set the format in header * @brief Set the format in header.
*/ */
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
...@@ -343,7 +351,7 @@ protected: ...@@ -343,7 +351,7 @@ protected:
bool updated_; bool updated_;
SparseFormat format_; SparseFormat format_;
// The header format for saving or loading param /// The header format for saving or loading param
int32_t headerFormat_; int32_t headerFormat_;
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_; std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册