未验证 提交 653885a5 编写于 作者: J Jacek Czaja 提交者: GitHub

[WIP] Matmul v1 & v2 unification -- part 1 (#44640)

* - Unit tests to be debugged

- fix

- refactor

- diagnostic

- more diagnostic

- fix

- Fix number two

- fix

- fix

- fix

- alpha added

- more fixes

- compilation fix

- removed diagnostic code

- cosmetic fixes

* lint
上级 a6c50a6c
...@@ -416,7 +416,8 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -416,7 +416,8 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
bool trans_y, bool trans_y,
Tensor *out) const { Tensor *out) const {
static const std::vector<int64_t> vec_placeholder; static const std::vector<int64_t> vec_placeholder;
MatMulV2MKLDNNHandler<XT> handler(onednn_engine, MatMulV2MKLDNNHandler<XT> handler(ctx,
onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x_dims, x_dims,
trans_x, trans_x,
......
...@@ -778,6 +778,59 @@ class BroadcastDataMKLDNNHandler ...@@ -778,6 +778,59 @@ class BroadcastDataMKLDNNHandler
} }
}; };
static void AppendActivation(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops,
float activation_scale = 1.0f) {
const auto invalid_attribute =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation").empty()
: true;
if (invalid_attribute) return;
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
const auto fuse_alpha =
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
const auto fuse_beta =
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
if (fuse_activation == "hard_sigmoid") {
post_ops.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_ops.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else {
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"mish", dnnl::algorithm::eltwise_mish},
{"relu", dnnl::algorithm::eltwise_relu},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}};
const auto& activation_type = activation_map.find(fuse_activation);
PADDLE_ENFORCE_NE(
activation_type,
activation_map.end(),
platform::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper",
fuse_activation));
post_ops.append_eltwise(
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
}
}
template <typename T> template <typename T>
class ReductionMKLDNNHandler class ReductionMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> {
...@@ -810,7 +863,8 @@ template <typename T> ...@@ -810,7 +863,8 @@ template <typename T>
class MatMulV2MKLDNNHandler class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> { : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public: public:
MatMulV2MKLDNNHandler(const dnnl::engine engine, MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine engine,
paddle::platform::Place cpu_place, paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, const std::vector<int64_t>& x_org_dims,
bool trans_x, bool trans_x,
...@@ -888,7 +942,26 @@ class MatMulV2MKLDNNHandler ...@@ -888,7 +942,26 @@ class MatMulV2MKLDNNHandler
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides); auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides); auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md); const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
// TODO(jczaja) : Adapt to int8
dnnl::primitive_attr CreateMatmulAttrs(
const framework::ExecutionContext& ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
if (alpha != 1.0f) {
matmul_attrs.set_output_scales(0, {alpha});
}
AppendActivation(ctx, post_operations);
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
} }
std::vector<int64_t> FakeTransposeStrides( std::vector<int64_t> FakeTransposeStrides(
...@@ -1013,59 +1086,6 @@ class ActivationMKLDNNHandler ...@@ -1013,59 +1086,6 @@ class ActivationMKLDNNHandler
} }
}; };
static void AppendActivation(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops,
float activation_scale = 1.0f) {
const auto invalid_attribute =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation").empty()
: true;
if (invalid_attribute) return;
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
const auto fuse_alpha =
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
const auto fuse_beta =
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
if (fuse_activation == "hard_sigmoid") {
post_ops.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_ops.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else {
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"mish", dnnl::algorithm::eltwise_mish},
{"relu", dnnl::algorithm::eltwise_relu},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}};
const auto& activation_type = activation_map.find(fuse_activation);
PADDLE_ENFORCE_NE(
activation_type,
activation_map.end(),
platform::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper",
fuse_activation));
post_ops.append_eltwise(
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
}
}
static std::unordered_map<std::string, std::string> GetAttributeMap( static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) { std::string act_type) {
std::unordered_map<std::string, std::string> attr_map; std::unordered_map<std::string, std::string> attr_map;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册