提交 07d16e3e 编写于 作者: T tensor-tang

refine comments

上级 635b8672
......@@ -345,10 +345,10 @@ void MKLDNNTester::run(const TestConfig& dnn,
return;
}
// After run some iters, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight paramter header format
// Weight param should always be index 0 (and bias index 1).
// TODO(TJ): should also considerate mean and var format when batchnorm ready
// After run some iterations, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight parameter header format.
// Weight parameter should always be index 0 (and bias index 1).
// TODO(TJ): should also consider mean and var format when batchnorm ready
int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat();
int refWgtFmt = parameters_[REF][0]->getHeaderFormat();
if (dnnWgtFmt == refWgtFmt) {
......
......@@ -35,9 +35,17 @@ limitations under the License. */
namespace paddle {
typedef enum {
PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format
PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi
PARAM_FORMAT_ITEMS, // the total format items numbers
/// The paddle original basic format
PARAM_FORMAT_ORIGINAL = 0,
/// See mkldnn_memory_format_t in
/// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h
/// for a detailed description.
/// 2D weights tensor in the format (output channels, input channels).
PARAM_FORMAT_MKLDNN_OI,
/// The total format items numbers
PARAM_FORMAT_ITEMS,
} PARAM_FORMAT;
class SparsePrefetchRowCpuMatrix;
......@@ -256,19 +264,19 @@ public:
};
/**
* @brief Is the header supported
* @brief Is the header format supported.
*/
static bool isHeaderFormatSupported(int32_t fmt) {
return fmt < PARAM_FORMAT_ITEMS;
}
/**
* @brief Get the format in header
* @brief Get the format in header.
*/
int getHeaderFormat() { return headerFormat_; }
/**
* @brief Set the format in header
* @brief Set the format in header.
*/
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
......@@ -343,7 +351,7 @@ protected:
bool updated_;
SparseFormat format_;
// The header format for saving or loading param
/// The header format for saving or loading param
int32_t headerFormat_;
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册