Unverified commit 7afb1df1, authored by Wojciech Uss, committed by GitHub

Decouple weights and bias from fc primitive in MKLDNN cache (#26708)

* decouple weights and bias from fc primitive in cache

* removed redundant update of pointers
Parent f32ae272
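The change stores the FC weights and bias memory objects in the MKL-DNN blob cache under their own keys instead of keeping them only inside the cached primitive factory. Below is a minimal, self-contained sketch of that storage-and-reuse pattern; the Memory and DeviceContext types here are simplified stand-ins invented for illustration, while the real code (shown in the diff) uses mkldnn::memory, MKLDNNDeviceContext::SetBlob/GetBlob, platform::CreateKey, platform::ThreadIDasStr and ctx.InputName("W")/("Bias").

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for mkldnn::memory.
struct Memory {
  explicit Memory(std::string d) : desc(std::move(d)) {}
  std::string desc;
};

// Hypothetical stand-in for MKLDNNDeviceContext's string-keyed blob cache.
class DeviceContext {
 public:
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

int main() {
  DeviceContext dev_ctx;
  // The real key is built from CreateKey(ThreadIDasStr()) plus the op's
  // input names ("W", "Bias"); a fixed prefix stands in for it here.
  const std::string key = "tid_";

  // First run: create (reorder/quantize) weights and bias once, then store
  // them under their own keys, decoupled from the cached fc primitive.
  dev_ctx.SetBlob(key + "W", std::make_shared<Memory>("weights in oi format"));
  dev_ctx.SetBlob(key + "Bias", std::make_shared<Memory>("quantized bias"));

  // Subsequent runs: fetch the prepared memory objects instead of redoing
  // the reorder/quantization work.
  auto weights = std::static_pointer_cast<Memory>(dev_ctx.GetBlob(key + "W"));
  auto bias = std::static_pointer_cast<Memory>(dev_ctx.GetBlob(key + "Bias"));
  if (weights && bias) {
    std::cout << "reusing: " << weights->desc << ", " << bias->desc << "\n";
  }
  return 0;
}

This is also why Reorder/ReorderWithScale now return std::shared_ptr<mkldnn::memory> and why the weights_ and bias_ members switch from boost::optional<memory> to std::shared_ptr<memory>: a blob placed in the device-context cache must share ownership with the factory that created it.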
@@ -44,6 +44,7 @@ class FCPrimitiveFactory {
   void ExecuteFcPrimitive(const LoDTensor* input, const Tensor* weights,
                           const Tensor* bias, LoDTensor* output,
+                          const MKLDNNDeviceContext& dev_ctx,
                           const ExecutionContext& ctx) {
     RecomputeOutputDims(ctx, input, weights, output);
     // If primitive has already been created and cached, don't create new one,
@@ -74,8 +75,8 @@ class FCPrimitiveFactory {
           "input format is equal to ncw."));
     }
-    // Transform weights to default MKL-DNN format
-    weights_ = TransposeWeights(weights);
+    weights_ = CreateWeightsMemory(weights);
     // Since MKL-DNN has a lot of limitations on what the input/weights/output
     // dimensions should be, to simplify the code, the creation of primitive
     // descriptor has been divided into separate cases, based on the number
@@ -112,10 +113,13 @@ class FCPrimitiveFactory {
     // Quantize weights and reorder to format chosen by FC primitive descriptor.
     QuantizeWeights(ctx, fc_prim_desc->weights_desc());
-    bias_ = CreateMemory<float>(fc_prim_desc->bias_desc(), bias);
+    bias_ = CreateMemoryToBeCached<float>(fc_prim_desc->bias_desc(), bias);
     // If int8 is desired, quantize bias into 32-bit signed int
     QuantizeBias(*fc_prim_desc, ctx);
+    // Store weights and bias in the mkldnn cache
+    CacheWeightsAndBias(dev_ctx, ctx);
     // Based on format determined by inner_product, create output in desired
     // memory format
     output_ = CreateDstMemory(*fc_prim_desc, ctx, output);
@@ -262,14 +266,15 @@ class FCPrimitiveFactory {
   }
   // Convert data from one data format to another
-  mkldnn::memory Reorder(const memory::desc& src_desc,
-                         const memory::desc& dst_desc, void* src_data) {
+  std::shared_ptr<mkldnn::memory> Reorder(const memory::desc& src_desc,
+                                          const memory::desc& dst_desc,
+                                          void* src_data) {
     auto src_mem = memory(src_desc, engine_, src_data);
-    auto dst_mem = memory(dst_desc, engine_);
-    auto reorder = mkldnn::reorder(src_mem, dst_mem);
+    auto dst_mem = std::make_shared<memory>(dst_desc, engine_);
+    auto reorder = mkldnn::reorder(src_mem, *dst_mem);
     mkldnn::stream astream(engine_);
-    reorder.execute(astream, src_mem, dst_mem);
+    reorder.execute(astream, src_mem, *dst_mem);
     astream.wait();
     return dst_mem;
@@ -277,9 +282,10 @@ class FCPrimitiveFactory {
   // Convert data from one data format to another and rescale it.
   // If the desired data type is (un)signed int8, quantization occurs here.
-  mkldnn::memory Reorder(const memory& src_mem, const memory::desc& dst_md,
+  std::shared_ptr<mkldnn::memory> ReorderWithScale(
+      const std::shared_ptr<memory> src_mem, const memory::desc& dst_md,
       const std::vector<float>& scale_data) {
-    mkldnn::memory dst_mem = mkldnn::memory(dst_md, engine_);
+    auto dst_mem = std::make_shared<mkldnn::memory>(dst_md, engine_);
     mkldnn::primitive_attr attributes;
     // According to MKL-DNN's documentation mask determines along which
     // dimensions should the scale be applied.
@@ -289,11 +295,11 @@ class FCPrimitiveFactory {
     // becuase we perform per-output-channel quantization
     int mask = CreateMask(0, scale_data.size() > 1);
     attributes.set_output_scales(mask, scale_data);
-    auto reorder = mkldnn::reorder(src_mem, dst_mem, attributes);
+    auto reorder = mkldnn::reorder(*src_mem, *dst_mem, attributes);
     mkldnn::stream astream(engine_);
     reorder.execute(astream,
-                    {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
+                    {{MKLDNN_ARG_FROM, *src_mem}, {MKLDNN_ARG_TO, *dst_mem}});
     astream.wait();
     return dst_mem;
@@ -323,16 +329,38 @@ class FCPrimitiveFactory {
     return memory(desc, engine_, data);
   }
-  // Transpose weights through MKL-DNN's reorder from io to oi format.
-  mkldnn::memory TransposeWeights(const Tensor* weights) {
+  template <typename T>
+  std::shared_ptr<mkldnn::memory> CreateMemoryToBeCached(
+      const mkldnn::memory::desc& desc, const Tensor* tensor) {
+    return CreateMemoryToBeCached(desc,
+                                  platform::to_void_cast<T>(tensor->data<T>()));
+  }
+  std::shared_ptr<mkldnn::memory> CreateMemoryToBeCached(
+      const mkldnn::memory::desc& desc, void* data) {
+    return std::make_shared<memory>(desc, engine_, data);
+  }
+  // Create weights memory and transform to default MKL-DNN format
+  std::shared_ptr<mkldnn::memory> CreateWeightsMemory(const Tensor* weights) {
     auto dims = framework::vectorize(weights->dims());
     std::swap(dims[0], dims[1]);  // Correct output dimensions
     auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io);
     auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi);
+    // Transpose weights through MKL-DNN's reorder from io to oi format.
     return Reorder(src_desc, dst_desc,
                    platform::to_void_cast<float>(weights->data<float>()));
   }
+  void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
+                           const ExecutionContext& ctx) {
+    const std::string key = platform::CreateKey(platform::ThreadIDasStr());
+    const std::string weights_key = key + ctx.InputName("W");
+    const std::string bias_key = key + ctx.InputName("Bias");
+    dev_ctx.SetBlob(weights_key, weights_);
+    dev_ctx.SetBlob(bias_key, bias_);
+  }
   // Compute the bias scales so that its values correspond to the
   // scale of data being an output of weights and input multiplication
   std::vector<float> ComputeBiasScales(const ExecutionContext& ctx) {
@@ -388,14 +416,14 @@ class FCPrimitiveFactory {
   }
   void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) {
-    weights_ =
-        Reorder(*weights_, dst, ctx.Attr<std::vector<float>>("Scale_weights"));
+    weights_ = ReorderWithScale(weights_, dst,
+                                ctx.Attr<std::vector<float>>("Scale_weights"));
   }
   void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc,
                     const ExecutionContext& ctx) {
     auto bias_scales = ComputeBiasScales(ctx);
-    bias_ = Reorder(*bias_, fc_prim_desc.bias_desc(), bias_scales);
+    bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
   }
   // Fuse relu into FC with activation type attribute has been set to 'relu'
@@ -463,10 +491,10 @@ class FCPrimitiveFactory {
  private:
   const mkldnn::engine& engine_;
-  boost::optional<memory> bias_;
   boost::optional<memory> input_;
   boost::optional<memory> output_;
-  boost::optional<memory> weights_;
+  std::shared_ptr<memory> bias_;
+  std::shared_ptr<memory> weights_;
   boost::optional<inner_product_forward> fc_;
 };
@@ -476,19 +504,13 @@ class FCPrimitiveFactory {
 template <typename T_in, typename T_w, typename T_out>
 static std::shared_ptr<FCPrimitiveFactory<T_in, T_w, T_out>>
 GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
-                    const ExecutionContext& ctx, const Tensor* input,
-                    const Tensor* weights,
-                    const mkldnn::engine& mkldnn_engine) {
-  const std::string key = platform::CreateKey(
-      platform::ThreadIDasStr(), input->format(), input->dims()[0],
-      framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
+                    const std::string& key) {
   auto prim_creator =
       std::static_pointer_cast<FCPrimitiveFactory<T_in, T_w, T_out>>(
           dev_ctx.GetBlob(key));
   if (prim_creator == nullptr) {
-    prim_creator =
-        std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(mkldnn_engine);
+    prim_creator = std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(
+        dev_ctx.GetEngine());
     dev_ctx.SetBlob(key, prim_creator);
   }
@@ -498,24 +520,24 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
 // Choose appropriate primitive factory implementation based on inferred
 // output type (uint8, int8 or float).
 template <typename T_in, typename T_w>
-static void ExecuteFc(const MKLDNNDeviceContext& dev_ctx,
-                      const ExecutionContext& ctx, const LoDTensor* input,
+static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
                       const Tensor* w, const Tensor* bias, LoDTensor* output,
-                      const mkldnn::engine& mkldnn_engine, bool fuse_relu,
-                      bool force_fp32_output) {
+                      bool fuse_relu, bool force_fp32_output) {
+  auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const std::string prim_key = platform::CreateKey(
+      platform::ThreadIDasStr(), input->format(), input->dims()[0],
+      framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
   constexpr bool is_int8 =
       std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
   if (!is_int8 || force_fp32_output) {
-    GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, ctx, input, w, mkldnn_engine)
-        ->ExecuteFcPrimitive(input, w, bias, output, ctx);
+    GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, prim_key)
+        ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
   } else if (fuse_relu) {
-    GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, ctx, input, w,
-                                            mkldnn_engine)
-        ->ExecuteFcPrimitive(input, w, bias, output, ctx);
+    GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, prim_key)
+        ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
   } else {
-    GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, ctx, input, w,
-                                           mkldnn_engine)
-        ->ExecuteFcPrimitive(input, w, bias, output, ctx);
+    GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, prim_key)
+        ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
   }
 }
@@ -526,9 +548,6 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
     PADDLE_ENFORCE_EQ(
         platform::is_cpu_place(ctx.GetPlace()), true,
         platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace."));
-    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
     auto input = ctx.Input<LoDTensor>("Input");
     auto w = ctx.Input<Tensor>("W");
     auto bias = ctx.Input<Tensor>("Bias");
@@ -537,8 +556,8 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
     bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
     bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-    ExecuteFc<T_in, T_w>(dev_ctx, ctx, input, w, bias, output, mkldnn_engine,
-                         fuse_relu, force_fp32_output);
+    ExecuteFc<T_in, T_w>(ctx, input, w, bias, output, fuse_relu,
+                         force_fp32_output);
     output->set_layout(DataLayout::kMKLDNN);
   }
......