提交 d4413a54 编写于 作者: A Adam 提交者: Tao Luo

Add common CreateKey for mkldnn handlers (#19767)

test=develop
上级 0d6ea529
......@@ -163,8 +163,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
const std::string key = platform::ReorderMKLDNNHandler::GetHash(
in_tz, in_format, out_format, std::to_string(in_type));
const std::string key = platform::CreateKey(in_tz, in_format, out_format,
std::to_string(in_type));
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key);
......
......@@ -70,7 +70,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto out_format = platform::MKLDNNFormatForSize(
x_dims.size(), MKLDNNMemoryFormat::nchw);
const std::string key = platform::ReorderMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
src_x_tz, x->format(), out_format, std::to_string(in_type));
platform::ReorderMKLDNNHandler handler(src_x_tz, x->type(), in_type,
......@@ -136,7 +136,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
std::vector<memory::primitive_desc> srcs_pd;
std::vector<float> scales = {1.0f, 1.0f};
const std::string key = platform::GetHash(
const std::string key = platform::CreateKey(
src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) +
std::to_string(y->format()));
......
......@@ -27,20 +27,6 @@ using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
namespace {
std::string gethash(const mkldnn::memory::dims &operand_dims,
const mkldnn::algorithm algorithm) {
auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dim2str(operand_dims) + std::to_string(algorithm);
}
} // namespace
template <typename Functor>
class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......
......@@ -120,22 +120,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
return batch_norm_p;
}
static std::string GetHash(const memory::dims &input_dims, float epsilon,
unsigned flag, bool is_test,
MKLDNNMemoryFormat format,
const std::string &suffix = "") {
auto dims2str = [](const memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + std::to_string(epsilon) +
std::to_string(flag) + std::to_string(is_test) +
std::to_string(format) + suffix;
}
private:
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
};
......@@ -236,8 +220,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
// keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, global_stats, input_format,
const std::string key =
platform::CreateKey(src_tz, epsilon, flags, global_stats, input_format,
ctx.op().Output("SavedMean"));
BatchNormMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
......@@ -369,15 +353,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
unsigned flags = mkldnn::use_scale_shift;
// keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, input_format,
const std::string key =
platform::CreateKey(src_tz, epsilon, flags, false, input_format,
ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse
const std::string key_with_hash =
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
input_format);
key + platform::CreateKey(src_tz, epsilon, flags, false, input_format);
const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p =
......
......@@ -66,27 +66,6 @@ static const mkldnn::engine& GetMKLDNNEngine(
return dev_ctx.GetEngine();
}
std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<const Tensor*> multi_input,
const int64_t& concat_axis, const memory::data_type& dt) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
for (size_t i = 0; i < multi_input.size(); i++) {
platform::AppendKeyDims(
&key, paddle::framework::vectorize<int>(multi_input[i]->dims()));
}
platform::AppendKey(&key, std::to_string(concat_axis));
platform::AppendKey(&key, ctx.op().Output("Out"));
platform::AppendKey(&key, std::to_string(dt));
platform::AppendKey(&key, std::to_string(multi_input[0]->format()));
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
platform::AppendKey(&key, "-t:");
platform::AppendKey(&key, platform::ThreadIDasStr());
}
return key;
}
template <typename T>
class ConcatPrimitiveFactory {
public:
......@@ -175,7 +154,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
ConcatPrimitiveFactory<T> prim_creator;
std::string key = CreateKey(ctx, multi_input, concat_axis, dt);
std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()), concat_axis,
ctx.op().Output("Out"), dt, multi_input[0]->format(),
platform::ThreadIDasStr());
const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd";
const std::string key_srcs = key + "@concat_srcs";
......
......@@ -190,7 +190,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_tz = paddle::framework::vectorize<int>(output->dims());
// Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
src_tz, weights_tz, fuse_activation, strides, paddings, dilations,
groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
......@@ -415,10 +415,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(input->type());
// Get unique name for storing MKLDNN primitives
std::string key;
key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::CreateKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
const std::string key = platform::CreateKey(
src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_activation, fuse_residual_conn,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
......@@ -715,7 +713,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "input" and "Filter" variable
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
const std::string key = platform::ConvMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
src_tz, weights_tz, "", strides, paddings, dilations, groups,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
......
......@@ -127,9 +127,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_tz = paddle::framework::vectorize<int>(output->dims());
// Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvTransposeMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output"));
const std::string key =
platform::CreateKey(src_tz, weights_tz, strides, paddings, dilations,
groups, ctx.op().Output("Output"));
std::vector<mkldnn::primitive> pipeline;
......
......@@ -31,18 +31,6 @@ using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const mkldnn::memory::data_type& src_dt,
const std::vector<int>& src_tz, const float scale_data) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::AppendKey(&key, std::to_string(src_dt));
platform::AppendKeyDims(&key, src_tz);
platform::AppendKey(&key, std::to_string(scale_data));
platform::AppendKey(&key, ctx.op().Output("Output"));
return key;
}
template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> {
public:
......@@ -64,7 +52,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format();
std::string key = CreateKey(ctx, src_dt, src_tz, reorder_scale[0]);
std::string key = platform::CreateKey(src_dt, src_tz, reorder_scale[0],
ctx.op().Output("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
......
......@@ -221,25 +221,14 @@ class FCPrimitiveFactory {
boost::optional<inner_product_forward> fc_;
};
static std::string GetHash(const Tensor* input, const Tensor* weights,
const std::string& suffix) {
auto dim2str = [](const DDim& operand_dims) {
std::string str = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
str += std::to_string(operand_dims[i]) + "-";
}
return str;
};
return std::to_string((unsigned)input->format()) + dim2str(weights->dims()) +
suffix;
}
template <typename T>
std::shared_ptr<FCPrimitiveFactory<T>> GetPrimitiveFactory(
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
const Tensor* input, const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = GetHash(input, weights, ctx.op().Output("Out"));
const std::string key = platform::CreateKey(
input->format(), framework::vectorize<int>(weights->dims()),
ctx.op().Output("Out"));
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T>>(dev_ctx.GetBlob(key));
......
......@@ -62,7 +62,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), x->format());
const std::string key = platform::LRNMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out"));
platform::LRNMKLDNNHandler handler(ctx.Attr<bool>("is_test"), dev_ctx,
......@@ -121,7 +121,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto dims = paddle::framework::vectorize<int>(x->dims());
const std::string key = platform::LRNMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out"));
platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key);
......
......@@ -332,33 +332,17 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
}
};
static std::string GetHash(const Tensor *input_x, const Tensor *input_y,
const std::string &suffix) {
auto dim2str = [](const DDim &operand_dims) {
std::string str = "";
for (int i = 0; i < operand_dims.size(); ++i) {
str += std::to_string(operand_dims[i]) + "-";
}
return str;
};
std::string hash = std::to_string((unsigned)input_x->format()) +
std::to_string((unsigned)input_x->type()) +
dim2str(input_x->dims()) +
std::to_string((unsigned)input_y->format()) +
std::to_string((unsigned)input_y->type()) +
dim2str(input_y->dims()) + suffix;
return hash;
}
/* OT: output data type */
template <typename XT, typename YT, typename OT>
std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx,
const Tensor *input_x, const Tensor *input_y,
const mkldnn::engine &mkldnn_engine, bool enable_quant) {
const std::string key = GetHash(input_x, input_y, ctx.op().Output("Out"));
const std::string key = platform::CreateKey(
input_x->format(), input_x->type(),
framework::vectorize<int>(input_x->dims()), input_y->format(),
input_y->type(), framework::vectorize<int>(input_y->dims()),
ctx.op().Output("Out"));
auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
dev_ctx.GetBlob(key));
......
......@@ -79,9 +79,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(input->type());
auto fmt = input->format();
const std::string key = platform::PoolingMKLDNNHandler::GetHash(
src_tz, pooling_type, ksize, strides, paddings, dt, fmt,
ctx.op().Output("Out"));
const std::string key =
platform::CreateKey(src_tz, pooling_type, ksize, strides, paddings, dt,
fmt, ctx.op().Output("Out"));
platform::PoolingMKLDNNHandler handler(pooling_type, dt,
ctx.Attr<bool>("is_test"), dev_ctx,
......@@ -171,7 +171,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key = platform::PoolingMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, in_x->format(), ctx.op().Input("Out"));
......
......@@ -30,18 +30,6 @@ using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<int>& src_tz, const float scale_data,
const bool is_negative) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::AppendKeyDims(&key, src_tz);
platform::AppendKey(&key, std::to_string(scale_data));
platform::AppendKey(&key, std::to_string(is_negative));
platform::AppendKey(&key, ctx.op().Output("Output"));
return key;
}
template <typename T>
class QuantOpKernel : public framework::OpKernel<T> {
public:
......@@ -60,7 +48,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
bool is_negative = ctx.Attr<bool>("is_negative_input");
std::string key = CreateKey(ctx, src_tz, scale_data, is_negative);
std::string key = platform::CreateKey(src_tz, scale_data, is_negative,
ctx.op().Output("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
......
......@@ -40,7 +40,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::GetHash(dims, uniq_name)),
platform::CreateKey(dims, uniq_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
......@@ -53,7 +53,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::GetHash(dims, uniq_name)),
platform::CreateKey(dims, uniq_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
......@@ -218,6 +218,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto dst_tz = src_tz;
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
ctx.GetPlace(), ctx.op().Output("Out"));
// Currently only NC data format is supported
......
......@@ -45,9 +45,9 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int>(input->dims());
const std::string key = platform::TransposeMKLDNNHandler::GetHash(
nchw_tz, axis,
ctx.op().Output("Out") + std::to_string(input->format()));
const std::string key =
platform::CreateKey(nchw_tz, axis, ctx.op().Output("Out") +
std::to_string(input->format()));
platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx,
mkldnn_engine, key);
......@@ -99,7 +99,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int>(out_grad->dims());
const std::string key = platform::TransposeMKLDNNHandler::GetHash(
const std::string key = platform::CreateKey(
nchw_tz, axis, ctx.op().Output(framework::GradVarName("X")));
platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx,
......
......@@ -184,28 +184,31 @@ inline std::string ThreadIDasStr(void) {
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
inline std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
template <typename T>
inline void AppendKey(std::string* key, const T& num) {
key->append(std::to_string(num));
}
inline void AppendKey(std::string* key, const std::string& s) {
key->append(s);
inline void AppendKey(std::string* key, const std::string& str) {
key->append(str);
}
inline std::string GetHash(const mkldnn::memory::dims& operand_dims,
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
inline void AppendKey(std::string* key, const char* str) { key->append(str); }
inline void AppendKeyDims(std::string* key, const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
inline void AppendKey(std::string* key, const std::vector<int>& dims) {
for (size_t i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
template <typename... ArgTypes>
inline std::string CreateKey(ArgTypes&&... args) {
std::string key;
key.reserve(256);
using expand_type = int[];
expand_type{0, (AppendKey(&key, args), 0)...};
return key;
}
} // namespace platform
} // namespace paddle
......@@ -198,9 +198,6 @@ class MKLDNNHandler {
mkldnn::engine engine_;
std::string key_;
std::string key_common_;
public:
static constexpr int MaxKeyLength = 256;
};
class SumMKLDNNHandler : public MKLDNNHandler {
......@@ -267,10 +264,9 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
platform::Place cpu_place,
const std::string& unique_name)
: platform::MKLDNNHandler(
dev_ctx, dev_ctx.GetEngine(),
platform::ActivationMKLDNNHandler<T>::GetHash(
dims, algorithm, fmt, alpha, beta, unique_name)),
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::CreateKey(dims, algorithm, fmt, alpha,
beta, unique_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
......@@ -288,10 +284,9 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
platform::Place cpu_place,
const std::string& unique_name)
: platform::MKLDNNHandler(
dev_ctx, dev_ctx.GetEngine(),
platform::ActivationMKLDNNHandler<T>::GetHash(
dims, algorithm, fmt, alpha, beta, unique_name)),
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::CreateKey(dims, algorithm, fmt, alpha,
beta, unique_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
......@@ -383,21 +378,6 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
return eltwise_bwd_p;
}
static std::string GetHash(const memory::dims& input_dims,
const mkldnn::algorithm algorithm,
const MKLDNNMemoryFormat fmt, const float alpha,
const float beta, const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::AppendKeyDims(&key, input_dims);
platform::AppendKey(&key, std::to_string(algorithm));
platform::AppendKey(&key, std::to_string(fmt));
platform::AppendKey(&key, std::to_string(alpha));
platform::AppendKey(&key, std::to_string(beta));
platform::AppendKey(&key, suffix);
return key;
}
protected:
void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
mkldnn::algorithm algorithm,
......@@ -597,22 +577,6 @@ class LRNMKLDNNHandler : public MKLDNNHandler {
return lrn_bwd_p;
}
static std::string GetHash(const memory::dims& input_dims, const int n,
const float alpha, const float beta, const float k,
const MKLDNNMemoryFormat& fmt,
const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::AppendKeyDims(&key, input_dims);
platform::AppendKey(&key, std::to_string(n));
platform::AppendKey(&key, std::to_string(alpha));
platform::AppendKey(&key, std::to_string(beta));
platform::AppendKey(&key, std::to_string(k));
platform::AppendKey(&key, std::to_string(fmt));
platform::AppendKey(&key, suffix);
return key;
}
private:
bool is_test_;
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
......@@ -790,24 +754,6 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
return pooling_bwd_p;
}
static std::string GetHash(
const memory::dims& input_dims, const std::string& pooling_type,
const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, const memory::data_type& dt,
const MKLDNNMemoryFormat& fmt, const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::AppendKeyDims(&key, input_dims);
platform::AppendKey(&key, pooling_type);
platform::AppendKeyDims(&key, ksize);
platform::AppendKeyDims(&key, strides);
platform::AppendKeyDims(&key, paddings);
platform::AppendKey(&key, std::to_string(dt));
platform::AppendKey(&key, std::to_string(fmt));
platform::AppendKey(&key, suffix);
return key;
}
private:
static inline int ComputeCeiledOutput(int input_size, int kernel_size,
int padding, int stride) {
......@@ -905,12 +851,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
return transpose_p;
}
static std::string GetHash(std::vector<int>& shape, // NOLINT
std::vector<int>& axis, // NOLINT
const std::string& suffix) {
return dims2str(shape) + dims2str(axis) + suffix;
}
protected:
mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
......@@ -999,14 +939,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
return reorder_p;
}
static std::string GetHash(std::vector<int>& shape, // NOLINT
MKLDNNMemoryFormat in_fmt,
MKLDNNMemoryFormat out_fmt,
const std::string& suffix) {
return dims2str(shape) + std::to_string(in_fmt) + "->" +
std::to_string(out_fmt) + "#" + suffix;
}
private:
std::vector<int> dims_;
framework::proto::VarType::Type vtype_;
......@@ -1346,58 +1278,6 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
return conv_bwd_data_p;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
mkldnn::memory::dims& weights_dims, // NOLINT
const std::string& fuse_activation, // NOLINT
std::vector<int>& strides, // NOLINT
std::vector<int>& paddings, // NOLINT
std::vector<int>& dilations, // NOLINT
int groups, const std::string& suffix) {
return dims2str(input_dims) + dims2str(weights_dims) + fuse_activation +
dims2str(strides) + dims2str(paddings) + dims2str(dilations) +
std::to_string(groups) + suffix;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
mkldnn::memory::dims& weights_dims, // NOLINT
std::vector<int>& strides, // NOLINT
std::vector<int>& paddings, // NOLINT
std::vector<int>& dilations, // NOLINT
int groups, const std::string& suffix) {
return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
suffix;
}
static void CreateKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyDims(key, strides);
AppendKeyDims(key, paddings);
AppendKeyDims(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, fuse_activation);
AppendKey(key, std::to_string(residual));
AppendKey(key, suffix);
}
private:
std::shared_ptr<typename forward_t::primitive_desc> conv_pd_;
std::shared_ptr<typename backward_weights_t::primitive_desc>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册