提交 4bc2987d 编写于 作者: B Brian Liu 提交者: Tao Luo

Fix bug in quantize kernel which cause crash in vgg16/19 model (#17964)

* Fix bug in quantize kernel which cause crash in vgg16/19 model

test=develop

* refine the code to reduce verbose code; test=develop

* remove useless code; test=develop
上级 dd3f9d19
......@@ -85,11 +85,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
std::shared_ptr<mkldnn::memory::primitive_desc> dst_pd;
if (is_negative) {
platform::ConvMKLDNNHandler::SetDstMemory<int8_t>(
ctx, output, dst_tz, engine, dst_pd, dst_memory);
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_pd, dst_memory);
} else {
platform::ConvMKLDNNHandler::SetDstMemory<uint8_t>(
ctx, output, dst_tz, engine, dst_pd, dst_memory);
platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
dst_pd, dst_memory);
}
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, *dst_pd, attri));
......
......@@ -27,6 +27,7 @@ namespace paddle {
namespace platform {
using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory;
class MKLDNNHandler {
public:
......@@ -196,21 +197,6 @@ class MKLDNNHandler {
return dims2str(operand_dims) + suffix;
}
template <typename T>
static void SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
std::vector<int> dst_tz, const mkldnn::engine& engine,
std::shared_ptr<mkldnn::memory::primitive_desc>& dst_pd, // NOLINT
std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<T>::DataType),
mkldnn::memory::format::nhwc);
dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine));
dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
}
static void AppendKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
......@@ -915,5 +901,26 @@ static void SetDstMemoryHandler(
(*dst_memory_p)->set_data_handle(to_void_cast<T>(output_data));
}
template <typename T>
static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx, framework::Tensor* output,
std::vector<int> dst_tz, const mkldnn::engine& engine,
std::shared_ptr<mkldnn::memory::primitive_desc>& dst_pd, // NOLINT
std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size();
memory::format dst_fmt;
PADDLE_ENFORCE(dst_dims <= 5,
"Dst memory for quantization can not have dims > 5");
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, memory::format::nhwc);
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<T>::DataType),
dst_fmt);
dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine));
dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册