未验证 提交 7afb1df1 编写于 作者: W Wojciech Uss 提交者: GitHub

Decouple weights and bias from fc primitive in MKLDNN cache (#26708)

* decouple weights and bias from fc primitive in cache

* removed reduntant update of pointers
上级 f32ae272
......@@ -44,6 +44,7 @@ class FCPrimitiveFactory {
void ExecuteFcPrimitive(const LoDTensor* input, const Tensor* weights,
const Tensor* bias, LoDTensor* output,
const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
RecomputeOutputDims(ctx, input, weights, output);
// If primitive has already been created and cached, don't create new one,
......@@ -74,8 +75,8 @@ class FCPrimitiveFactory {
"input format is equal to ncw."));
}
// Transform weights to default MKL-DNN format
weights_ = TransposeWeights(weights);
weights_ = CreateWeightsMemory(weights);
// Since MKL-DNN has a lot of limitations on what the input/weights/output
// dimensions should be, to simplify the code, the creation of primitive
// descriptor has been divided into separate cases, based on the number
......@@ -112,10 +113,13 @@ class FCPrimitiveFactory {
// Quantize weights and reorder to format chosen by FC primitive descriptor.
QuantizeWeights(ctx, fc_prim_desc->weights_desc());
bias_ = CreateMemory<float>(fc_prim_desc->bias_desc(), bias);
bias_ = CreateMemoryToBeCached<float>(fc_prim_desc->bias_desc(), bias);
// If int8 is desired, quantize bias into 32-bit signed int
QuantizeBias(*fc_prim_desc, ctx);
// Store weights and bias in the mkldnn cache
CacheWeightsAndBias(dev_ctx, ctx);
// Based on format determined by inner_product, create output in desired
// memory format
output_ = CreateDstMemory(*fc_prim_desc, ctx, output);
......@@ -262,14 +266,15 @@ class FCPrimitiveFactory {
}
// Convert data from one data format to another
mkldnn::memory Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc, void* src_data) {
std::shared_ptr<mkldnn::memory> Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc,
void* src_data) {
auto src_mem = memory(src_desc, engine_, src_data);
auto dst_mem = memory(dst_desc, engine_);
auto dst_mem = std::make_shared<memory>(dst_desc, engine_);
auto reorder = mkldnn::reorder(src_mem, dst_mem);
auto reorder = mkldnn::reorder(src_mem, *dst_mem);
mkldnn::stream astream(engine_);
reorder.execute(astream, src_mem, dst_mem);
reorder.execute(astream, src_mem, *dst_mem);
astream.wait();
return dst_mem;
......@@ -277,9 +282,10 @@ class FCPrimitiveFactory {
// Convert data from one data format to another and rescale it.
// If the desired data type is (un)signed int8, quantization occurs here.
mkldnn::memory Reorder(const memory& src_mem, const memory::desc& dst_md,
const std::vector<float>& scale_data) {
mkldnn::memory dst_mem = mkldnn::memory(dst_md, engine_);
std::shared_ptr<mkldnn::memory> ReorderWithScale(
const std::shared_ptr<memory> src_mem, const memory::desc& dst_md,
const std::vector<float>& scale_data) {
auto dst_mem = std::make_shared<mkldnn::memory>(dst_md, engine_);
mkldnn::primitive_attr attributes;
// According to MKL-DNN's documentation mask determines along which
// dimensions should the scale be applied.
......@@ -289,11 +295,11 @@ class FCPrimitiveFactory {
// becuase we perform per-output-channel quantization
int mask = CreateMask(0, scale_data.size() > 1);
attributes.set_output_scales(mask, scale_data);
auto reorder = mkldnn::reorder(src_mem, dst_mem, attributes);
auto reorder = mkldnn::reorder(*src_mem, *dst_mem, attributes);
mkldnn::stream astream(engine_);
reorder.execute(astream,
{{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
{{MKLDNN_ARG_FROM, *src_mem}, {MKLDNN_ARG_TO, *dst_mem}});
astream.wait();
return dst_mem;
......@@ -323,16 +329,38 @@ class FCPrimitiveFactory {
return memory(desc, engine_, data);
}
// Transpose weights through MKL-DNN's reorder from io to oi format.
mkldnn::memory TransposeWeights(const Tensor* weights) {
template <typename T>
std::shared_ptr<mkldnn::memory> CreateMemoryToBeCached(
const mkldnn::memory::desc& desc, const Tensor* tensor) {
return CreateMemoryToBeCached(desc,
platform::to_void_cast<T>(tensor->data<T>()));
}
std::shared_ptr<mkldnn::memory> CreateMemoryToBeCached(
const mkldnn::memory::desc& desc, void* data) {
return std::make_shared<memory>(desc, engine_, data);
}
// Create weights memory and transform to default MKL-DNN format
std::shared_ptr<mkldnn::memory> CreateWeightsMemory(const Tensor* weights) {
auto dims = framework::vectorize(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi);
// Transpose weights through MKL-DNN's reorder from io to oi format.
return Reorder(src_desc, dst_desc,
platform::to_void_cast<float>(weights->data<float>()));
}
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
const std::string key = platform::CreateKey(platform::ThreadIDasStr());
const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_);
dev_ctx.SetBlob(bias_key, bias_);
}
// Compute the bias scales so that its values correspond to the
// scale of data being an output of weights and input multiplication
std::vector<float> ComputeBiasScales(const ExecutionContext& ctx) {
......@@ -388,14 +416,14 @@ class FCPrimitiveFactory {
}
void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) {
weights_ =
Reorder(*weights_, dst, ctx.Attr<std::vector<float>>("Scale_weights"));
weights_ = ReorderWithScale(weights_, dst,
ctx.Attr<std::vector<float>>("Scale_weights"));
}
void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx) {
auto bias_scales = ComputeBiasScales(ctx);
bias_ = Reorder(*bias_, fc_prim_desc.bias_desc(), bias_scales);
bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
}
// Fuse relu into FC with activation type attribute has been set to 'relu'
......@@ -463,10 +491,10 @@ class FCPrimitiveFactory {
private:
const mkldnn::engine& engine_;
boost::optional<memory> bias_;
boost::optional<memory> input_;
boost::optional<memory> output_;
boost::optional<memory> weights_;
std::shared_ptr<memory> bias_;
std::shared_ptr<memory> weights_;
boost::optional<inner_product_forward> fc_;
};
......@@ -476,19 +504,13 @@ class FCPrimitiveFactory {
template <typename T_in, typename T_w, typename T_out>
static std::shared_ptr<FCPrimitiveFactory<T_in, T_w, T_out>>
GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx, const Tensor* input,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
const std::string& key) {
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetBlob(key));
if (prim_creator == nullptr) {
prim_creator =
std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(mkldnn_engine);
prim_creator = std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetEngine());
dev_ctx.SetBlob(key, prim_creator);
}
......@@ -498,24 +520,24 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template <typename T_in, typename T_w>
static void ExecuteFc(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx, const LoDTensor* input,
static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, const Tensor* bias, LoDTensor* output,
const mkldnn::engine& mkldnn_engine, bool fuse_relu,
bool force_fp32_output) {
bool fuse_relu, bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string prim_key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
if (!is_int8 || force_fp32_output) {
GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, ctx, input, w, mkldnn_engine)
->ExecuteFcPrimitive(input, w, bias, output, ctx);
GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else if (fuse_relu) {
GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, ctx, input, w,
mkldnn_engine)
->ExecuteFcPrimitive(input, w, bias, output, ctx);
GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else {
GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, ctx, input, w,
mkldnn_engine)
->ExecuteFcPrimitive(input, w, bias, output, ctx);
GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
}
}
......@@ -526,9 +548,6 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace."));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto input = ctx.Input<LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
......@@ -537,8 +556,8 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
ExecuteFc<T_in, T_w>(dev_ctx, ctx, input, w, bias, output, mkldnn_engine,
fuse_relu, force_fp32_output);
ExecuteFc<T_in, T_w>(ctx, input, w, bias, output, fuse_relu,
force_fp32_output);
output->set_layout(DataLayout::kMKLDNN);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册