未验证 提交 d9262145 编写于 作者: L lidanqing 提交者: GitHub

[UT coverage] improve the mul_mkldnn_op line coverage (#22408)

* improve the mul_mkldnn_op line coverage
test=develop

* remove fp32 mul mkldnn kernel
test=develop

* locally refactoring
test=develop

* change according to reviews
test=develop
上级 c65c6ae5
......@@ -40,17 +40,24 @@ class MulPrimitiveFactory {
explicit MulPrimitiveFactory(const mkldnn::engine &engine)
: engine_(engine) {}
virtual ~MulPrimitiveFactory() {}
virtual inner_product_forward CreateMulPrimitive(
const Tensor *input_x, const Tensor *input_y, Tensor *output,
inner_product_forward CreateMulPrimitive(const Tensor *x_input,
const Tensor *y_input,
Tensor *output,
const ExecutionContext &ctx) {
/* check format and reorder if need */
/* check data format and reorder if need */
int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
auto x_matrix = UpdateDataFormat<XT>(input_x, x_num_col_dims, ctx);
auto y_matrix = UpdateDataFormat<YT>(input_y, y_num_col_dims, ctx);
// TODO(intel-minghui) : Remove the restriction that only supports Input(Y)
// as weights
PADDLE_ENFORCE_EQ(
(std::is_same<YT, float>::value), true,
platform::errors::InvalidArgument(
"Input(Y) must be fp32 data type since only fp32 data type is "
"supported in the current design of MKLDNN INT8."));
auto x_matrix = UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx);
auto y_matrix = UpdateDataFormat<YT>(y_input, y_num_col_dims, ctx);
auto output_dim = output->dims();
if (output_dim.size() != 2) {
......@@ -60,17 +67,110 @@ class MulPrimitiveFactory {
if (mul_) {
UpdateDataPointers(ctx, output, &x_matrix);
Execute();
return *mul_;
return *(mul_);
}
auto src_desc = CreateMemDescriptor<XT>(&x_matrix, MKLDNNMemoryFormat::nc);
x_input_ = CreateMemory<XT>(src_desc, &x_matrix);
if (is_int8_) {
const auto trans_y = TransposeInputY(&y_matrix);
auto scale_y = ctx.Attr<std::vector<float>>("scale_y");
y_input_ = QuantInputY(trans_y, scale_y);
} else {
y_input_ = TransposeInputY(&y_matrix);
}
auto dst_desc = CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);
mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx);
Execute();
return *mul_;
return *(mul_);
}
private:
memory ReorderWithScale(const memory::desc &src_desc,
const memory::desc &dst_desc, void *src_data,
const std::vector<float> &scale) {
auto mask = scale.size() > 1 ? 1 : 0;
mkldnn::primitive_attr attr;
attr.set_output_scales(mask, scale);
auto src_mem = memory(src_desc, engine_, src_data);
auto dst_mem = memory(dst_desc, engine_);
auto reorder_pd = mkldnn::reorder::primitive_desc(src_mem, dst_mem, attr);
auto reorder = mkldnn::reorder(reorder_pd);
mkldnn::stream astream(engine_);
reorder.execute(astream, src_mem, dst_mem);
astream.wait();
return dst_mem;
}
memory QuantInputY(memory input_y, const std::vector<float> &scale_y) {
const auto &dims = input_y.get_desc().data.dims;
auto ndims = input_y.get_desc().data.ndims;
auto y_dims = std::vector<int64_t>(dims, dims + ndims);
auto user_y_desc = CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi);
auto y_desc = CreateMemDescriptor<int8_t>(y_dims, MKLDNNMemoryFormat::oi);
return ReorderWithScale(user_y_desc, y_desc, input_y.get_data_handle(),
scale_y);
}
mkldnn::primitive_attr CreateMulAttr(const ExecutionContext &ctx,
bool force_fp32_output) {
mkldnn::primitive_attr mul_attr;
auto scale_y_data = ctx.Attr<std::vector<float>>("scale_y");
auto scale_x_data = ctx.Attr<float>("scale_x");
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("scale_out");
bool is_multi_channel = scale_y_data.size() > 1;
int count = is_multi_channel ? scale_y_data.size() : 1;
std::vector<float> output_shift_scale(count);
for (int i = 0; i < count; i++) {
if (scale_y_data[i] == 0.0)
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
scale_out_data / (scale_x_data * scale_y_data[i]);
}
int mul_mask = is_multi_channel ? 1 : 0;
mul_attr.set_output_scales(mul_mask, output_shift_scale);
return mul_attr;
}
inner_product_forward CreateMulPrimitive(const memory &x_memory,
const memory &y_memory,
const memory::desc &dst_desc,
Tensor *output,
const ExecutionContext &ctx) {
const auto x_desc = x_memory.get_desc();
const auto y_desc = y_memory.get_desc();
inner_product_forward::primitive_desc mul_prim_desc;
const auto &mul_desc = inner_product_forward::desc(
prop_kind::forward, x_desc, y_desc, dst_desc);
if (is_int8_) {
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto mul_attr = CreateMulAttr(ctx, force_fp32_output);
mul_prim_desc =
inner_product_forward::primitive_desc(mul_desc, mul_attr, engine_);
} else {
mul_prim_desc = inner_product_forward::primitive_desc(mul_desc, engine_);
}
output_ = CreateDstMemory(mul_prim_desc, ctx, output);
return inner_product_forward(mul_prim_desc);
}
void Execute() {
......@@ -81,7 +181,6 @@ class MulPrimitiveFactory {
astream.wait();
}
protected:
template <typename T>
Tensor UpdateDataFormat(const Tensor *data, int num_col_dims,
const ExecutionContext &ctx) {
......@@ -176,176 +275,13 @@ class MulPrimitiveFactory {
return Reorder(src_desc, dst_desc, to_void_cast<YT>(input_y->data<YT>()));
}
inner_product_forward CreateMulPrimitive(const memory &x_memory,
const memory &y_memory,
const memory::desc &dst_desc,
Tensor *output,
const ExecutionContext &ctx) {
const auto y_desc = y_memory.get_desc();
const auto x_desc = x_memory.get_desc();
auto mul_prim_desc = CreateMulPrimDesc(x_desc, y_desc, dst_desc);
output_ = CreateDstMemory(mul_prim_desc, ctx, output);
return inner_product_forward(mul_prim_desc);
}
inner_product_forward::primitive_desc CreateMulPrimDesc(
const memory::desc &x_desc, const memory::desc &y_desc,
const memory::desc &dst_desc) {
auto mul_desc = inner_product_forward::desc(prop_kind::forward, x_desc,
y_desc, dst_desc);
return inner_product_forward::primitive_desc(mul_desc, engine_);
}
protected:
const mkldnn::engine &engine_;
boost::optional<memory> x_input_;
boost::optional<memory> y_input_;
boost::optional<memory> output_;
boost::optional<inner_product_forward> mul_;
}; // namespace operators
template <typename XT, typename YT, typename OT>
class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
public:
using MulPrimitiveFactory<XT, YT, OT>::MulPrimitiveFactory;
virtual inner_product_forward CreateMulPrimitive(
const Tensor *x_input, const Tensor *y_input, Tensor *output,
const ExecutionContext &ctx) {
/* check data format and reorder if need */
int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
auto scale_y = ctx.Attr<std::vector<float>>("scale_y");
// TODO(intel-minghui) : Remove the restriction that only supports Input(Y)
// as weights
bool enforce = std::is_same<YT, float>::value;
PADDLE_ENFORCE(
enforce == true,
"Input(Y) supposed to be fp32 data type since only fp32 data type is "
"supported in the current design of MKLDNN INT8.");
auto x_matrix =
this->template UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx);
auto y_matrix =
this->template UpdateDataFormat<YT>(y_input, y_num_col_dims, ctx);
auto output_dim = output->dims();
if (output_dim.size() != 2) {
output->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
if (this->mul_) {
this->UpdateDataPointers(ctx, output, &x_matrix);
this->Execute();
return *(this->mul_);
}
auto src_desc = this->template CreateMemDescriptor<XT>(
&x_matrix, MKLDNNMemoryFormat::nc);
this->x_input_ = this->template CreateMemory<XT>(src_desc, &x_matrix);
const auto trans_y = this->TransposeInputY(&y_matrix);
this->y_input_ = QuantInputY(trans_y, scale_y);
auto dst_desc =
this->template CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);
this->mul_ = CreateMulPrimitive(*(this->x_input_), *(this->y_input_),
dst_desc, output, ctx);
this->Execute();
return *(this->mul_);
}
memory ReorderWithScale(const memory::desc &src_desc,
const memory::desc &dst_desc, void *src_data,
const std::vector<float> &scale) {
auto mask = scale.size() > 1 ? 1 : 0;
mkldnn::primitive_attr attr;
attr.set_output_scales(mask, scale);
auto src_mem = memory(src_desc, this->engine_, src_data);
auto dst_mem = memory(dst_desc, this->engine_);
auto reorder_pd = mkldnn::reorder::primitive_desc(src_mem, dst_mem, attr);
auto reorder = mkldnn::reorder(reorder_pd);
mkldnn::stream astream(this->engine_);
reorder.execute(astream, src_mem, dst_mem);
astream.wait();
return dst_mem;
}
memory QuantInputY(memory input_y, const std::vector<float> &scale_y) {
const auto &dims = input_y.get_desc().data.dims;
auto ndims = input_y.get_desc().data.ndims;
auto y_dims = std::vector<int64_t>(dims, dims + ndims);
auto user_y_desc =
this->template CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi);
auto y_desc = this->template CreateMemDescriptor<int8_t>(
y_dims, MKLDNNMemoryFormat::oi);
return ReorderWithScale(user_y_desc, y_desc, input_y.get_data_handle(),
scale_y);
}
mkldnn::primitive_attr CreateMulAttr(const ExecutionContext &ctx,
bool force_fp32_output) {
mkldnn::primitive_attr mul_attr;
auto scale_y_data = ctx.Attr<std::vector<float>>("scale_y");
auto scale_x_data = ctx.Attr<float>("scale_x");
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("scale_out");
bool is_multi_channel = scale_y_data.size() > 1;
int count = is_multi_channel ? scale_y_data.size() : 1;
std::vector<float> output_shift_scale(count);
for (int i = 0; i < count; i++) {
if (scale_y_data[i] == 0.0)
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
scale_out_data / (scale_x_data * scale_y_data[i]);
}
int mul_mask = is_multi_channel ? 1 : 0;
mul_attr.set_output_scales(mul_mask, output_shift_scale);
return mul_attr;
}
inner_product_forward CreateMulPrimitive(const memory &x_memory,
const memory &y_memory,
const memory::desc &dst_desc,
Tensor *output,
const ExecutionContext &ctx) {
const auto x_desc = x_memory.get_desc();
const auto y_desc = y_memory.get_desc();
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
mkldnn::primitive_attr mul_attr = CreateMulAttr(ctx, force_fp32_output);
auto mul_prim_desc = CreateMulPrimDesc(x_desc, y_desc, dst_desc, mul_attr);
this->output_ = this->CreateDstMemory(mul_prim_desc, ctx, output);
return inner_product_forward(mul_prim_desc);
}
inner_product_forward::primitive_desc CreateMulPrimDesc(
const memory::desc &x_desc, const memory::desc &y_desc,
const memory::desc &dst_desc, const mkldnn::primitive_attr &mul_attr) {
const auto &mul_desc = inner_product_forward::desc(
prop_kind::forward, x_desc, y_desc, dst_desc);
return inner_product_forward::primitive_desc(mul_desc, mul_attr,
this->engine_);
}
static constexpr bool is_int8_ =
std::is_same<XT, int8_t>::value || std::is_same<XT, uint8_t>::value;
};
/* OT: output data type */
......@@ -353,7 +289,7 @@ template <typename XT, typename YT, typename OT>
std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx,
const Tensor *input_x, const Tensor *input_y,
const mkldnn::engine &mkldnn_engine, bool enable_quant) {
const mkldnn::engine &mkldnn_engine) {
const std::string key = platform::CreateKey(
input_x->type(), framework::vectorize(input_x->dims()), input_y->type(),
framework::vectorize(input_y->dims()), ctx.OutputName("Out"));
......@@ -363,10 +299,7 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
if (prim_creator == nullptr) {
prim_creator =
enable_quant
? std::make_shared<QuantMulPrimitiveFactory<XT, YT, OT>>(
mkldnn_engine)
: std::make_shared<MulPrimitiveFactory<XT, YT, OT>>(mkldnn_engine);
std::make_shared<MulPrimitiveFactory<XT, YT, OT>>(mkldnn_engine);
dev_ctx.SetBlob(key, prim_creator);
}
......@@ -379,18 +312,18 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
const Tensor *input_x,
const Tensor *input_y, Tensor *output,
const mkldnn::engine &mkldnn_engine) {
bool enable_quant =
constexpr bool is_int8 =
std::is_same<XT, int8_t>::value || std::is_same<XT, uint8_t>::value;
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
if (enable_quant && !force_fp32_output) {
if (is_int8 && !force_fp32_output) {
return GetPrimitiveFactory<XT, YT, int8_t>(dev_ctx, ctx, input_x, input_y,
mkldnn_engine, enable_quant)
mkldnn_engine)
->CreateMulPrimitive(input_x, input_y, output, ctx);
} else {
return GetPrimitiveFactory<XT, YT, float>(dev_ctx, ctx, input_x, input_y,
mkldnn_engine, enable_quant)
mkldnn_engine)
->CreateMulPrimitive(input_x, input_y, output, ctx);
}
}
......
......@@ -52,11 +52,11 @@ class TestMKLDNNMulOpS8S8(OpTest):
# limit random range inside |-127, 127| to avoid overflow on SKL
if self.srctype == np.int8:
A_data = np.random.randint(-127, 127, (2, 5)).astype(np.int8)
A_data = np.random.randint(-127, 127, (20, 5)).astype(np.int8)
else:
A_data = np.random.randint(0, 127, (2, 5)).astype(np.uint8)
A_data = np.random.randint(0, 127, (20, 5)).astype(np.uint8)
B_data = np.random.uniform(-127, 127, (5, 3)).astype(np.float32)
B_data = np.random.uniform(-127, 127, (5, 20)).astype(np.float32)
quant_B = np.round(B_data * self.scale_y[0]).astype(np.int)
output = np.dot(A_data, quant_B)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册