Commit 06f9c2e7 authored by: T TensorFlower Gardener

Merge pull request #24055 from Intel-tensorflow:guizili/leakyrelu

PiperOrigin-RevId: 225237733
@@ -262,6 +262,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.fused_conv2d = "_FusedConv2D";
csinfo_.identity = "Identity";
csinfo_.leakyrelu = "LeakyRelu";
csinfo_.leakyrelu_grad = "LeakyReluGrad";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
@@ -392,6 +394,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, LrnGradRewrite});
rinfo_.push_back({csinfo_.leakyrelu,
mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
CopyAttrsLeakyRelu, LeakyReluRewrite});
rinfo_.push_back({csinfo_.leakyrelu_grad,
mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
CopyAttrsLeakyRelu, LeakyReluRewrite});
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
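Each rinfo_ entry above bundles the original op name, its MKL counterpart (GetMklOpName prepends the "_Mkl" prefix, as the tests below confirm), an attribute-copying callback, and a rewrite predicate. A minimal standalone sketch of how such an entry gates rewriting; the FakeNode and RewriteInfo types here are illustrative stand-ins, not the pass's real Node/NodeBuilder machinery:

#include <functional>
#include <string>
#include <vector>

struct FakeNode { std::string op; float alpha; };  // stand-in for Node

struct RewriteInfo {
  std::string name;      // original op, e.g. "LeakyRelu"
  std::string new_name;  // MKL op, e.g. "_MklLeakyRelu"
  std::function<bool(const FakeNode&)> rewrite_rule;
};

int main() {
  const std::vector<RewriteInfo> rinfo = {
      {"LeakyRelu", "_MklLeakyRelu",
       [](const FakeNode& n) { return n.alpha <= 1.0f; }}};
  const FakeNode a{"LeakyRelu", 0.1f};  // would be rewritten to _MklLeakyRelu
  const FakeNode b{"LeakyRelu", 2.0f};  // would stay on the Eigen op
  for (const auto& ri : rinfo) {
    const bool rewrite_a = (a.op == ri.name) && ri.rewrite_rule(a);  // true
    const bool rewrite_b = (b.op == ri.name) && ri.rewrite_rule(b);  // false
    (void)rewrite_a;
    (void)rewrite_b;
  }
  return 0;
}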
@@ -671,6 +679,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string fused_batch_norm_grad;
string fused_conv2d;
string identity;
string leakyrelu;
string leakyrelu_grad;
string lrn;
string lrn_grad;
string matmul;
@@ -1148,6 +1158,30 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return do_rewrite;
}
// MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or
// feature * alpha (otherwise),
// while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
// These two definitions are inconsistent when alpha > 1,
// so we rewrite LeakyRelu to the MKL op only when alpha <= 1.
static bool LeakyReluRewrite(const Node* n) {
DCHECK(n);
float alpha;
bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok();
DCHECK(has_attr);
// If alpha is less than or equal to 1, rewrite the node to the MKL op.
// Otherwise, the Eigen op is used instead.
if (alpha <= 1) {
return true;
}
VLOG(1) << "LeakyReluRewrite: alpha is greater than 1, a case that is not "
<< "optimized by Intel MKL; using the Eigen op for LeakyRelu "
<< "instead.";
return false;
}
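To make the alpha <= 1 restriction concrete, here is a small self-contained check; MklDnnLeakyRelu and TfLeakyRelu are illustrative helpers mirroring the two definitions in the comment above, not TensorFlow or MKL-DNN APIs:

#include <algorithm>
#include <cassert>

float MklDnnLeakyRelu(float x, float alpha) { return x > 0 ? x : x * alpha; }
float TfLeakyRelu(float x, float alpha) { return std::max(x, x * alpha); }

int main() {
  // alpha <= 1: the two definitions coincide (both give -1.5f here).
  assert(MklDnnLeakyRelu(-3.0f, 0.5f) == TfLeakyRelu(-3.0f, 0.5f));
  // alpha > 1: MKL-DNN yields x * alpha = -6.0f, while TensorFlow yields
  // max(-3.0f, -6.0f) = -3.0f, so the node must stay on the Eigen op.
  assert(MklDnnLeakyRelu(-3.0f, 2.0f) == -6.0f);
  assert(TfLeakyRelu(-3.0f, 2.0f) == -3.0f);
  return 0;
}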
static bool MaxpoolGradRewrite(const Node* n) {
CHECK_NOTNULL(n);
bool do_rewrite = false;
@@ -1358,6 +1392,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
bool change_format = false);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsLeakyRelu(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
@@ -2061,6 +2097,21 @@ void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
nb->Attr("beta", beta);
}
void MklLayoutRewritePass::CopyAttrsLeakyRelu(const Node* orig_node,
NodeBuilder* nb,
bool change_format) {
DataType T;
float alpha;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
// Add attributes to new node.
nb->Attr("T", T);
nb->Attr("alpha", alpha);
}
void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
NodeBuilder* nb,
bool change_format) {
......
@@ -1648,6 +1648,85 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Relu6Grad_Positive) {
"DMT/_1->C:2");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyRelu_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LeakyRelu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLeakyRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
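For readers new to these tests: in the expected string, the part before '|' lists nodes as name(op) and the part after lists edges as src[:port]->dst[:port]. The DMT/_N(Const) nodes are the dummy Mkl-metadata tensors the pass inserts to feed each rewritten op's uint8 inputs. An annotated reading of the expectation above (interpretation, not additional test output):

// A(Input);B(_MklLeakyRelu);C(Zeta);DMT/_0(Const)  four nodes; B was rewritten
// A->B                        A feeds B's features input
// A:control->DMT/_0:control   control edge tying the dummy Const to A
// B->C:1                      B's activations feed Zeta's second input
// DMT/_0->B:1                 dummy metadata feeds B's mkl_features input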
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyRelu_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LeakyRelu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 2.0 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(LeakyRelu);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'LeakyReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklLeakyReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluGrad_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'LeakyReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 2.0 } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(LeakyReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluLeakyReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LeakyRelu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A'] }"
"node { name: 'C' op: 'LeakyReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLeakyRelu);C(_MklLeakyReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
"DMT/_1->C:2");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
......
@@ -204,7 +204,7 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
~MklEltwiseFwdPrimitiveFactory() {}
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
string prefix = "eltwise_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@@ -422,8 +422,8 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
private:
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
string prefix = "eltwise_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@@ -856,9 +856,9 @@ class MklReluOpBase : public OpKernel {
Tensor* dst_tensor = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{static_cast<const int>(src_index)},
static_cast<const int>(dst_index),
tf_shape_dst, &dst_tensor));
AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
T* dst_data = dst_tensor->flat<T>().data();
@@ -867,18 +867,19 @@ class MklReluOpBase : public OpKernel {
eltwise_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
}
}
private:
engine cpu_engine = engine(engine::cpu, 0);
std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
protected:
float alpha_;
float beta_;
};
@@ -947,11 +948,11 @@ class MklReluGradOpBase : public OpKernel {
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
src_dims = (src_tensor.dims() == 4)
? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
diff_dst_tf_data_format)
: TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
} else {
@@ -1001,8 +1002,7 @@ class MklReluGradOpBase : public OpKernel {
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
dnn_shape_diff_src.SetMklTensor(true);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
@@ -1012,9 +1012,10 @@ class MklReluGradOpBase : public OpKernel {
dnn_shape_src.GetSizesAsMklDnnDims(),
dnn_shape_src.GetTfDataFormat());
} else {
dnn_shape_diff_src.SetTfLayout(
dnn_shape_diff_dst.GetDimension(),
dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
dnn_shape_diff_dst.GetTfDataFormat());
}
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
@@ -1045,6 +1046,8 @@ class MklReluGradOpBase : public OpKernel {
private:
engine cpu_engine = engine(engine::cpu, 0);
std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
protected:
float alpha_;
float beta_;
};
@@ -1312,8 +1315,86 @@ class MklRelu6GradOp
T* out_o = diff_src_tensor->flat<T>().data();
T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
// Gradient flows only where 0 < input < RELU6_UPPER_BOUND.
out_o[0] = user_g[0] * (user_i[0] > 0 &&
(user_i[0] < static_cast<T>(RELU6_UPPER_BOUND)));
return;
}
};
template <typename Device, typename T>
class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
public:
~MklLeakyReluOp() {}
explicit MklLeakyReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES(
context, alpha <= 1,
errors::InvalidArgument("MKL LeakyRelu only supports alpha <= 1. "
"alpha is: ",
alpha));
this->alpha_ = alpha;
}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
const size_t dst_index = 0; // index of dst output tensor
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);
Tensor* dst_tensor = nullptr;
T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
MklDnnShape dnn_shape_dst;
dnn_shape_dst.SetMklTensor(false);
AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
src_tensor.shape(), dnn_shape_dst);
T* out_o = dst_tensor->flat<T>().data();
out_o[0] = user_i[0] >= 0 ? user_i[0] : user_i[0] * this->alpha_;
return;
}
};
template <typename Device, typename T>
class MklLeakyReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
public:
~MklLeakyReluGradOp() {}
explicit MklLeakyReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES(
context, alpha <= 1,
errors::InvalidArgument("MKL LeakyRelu only supports alpha <= 1. "
"alpha is: ",
alpha));
this->alpha_ = alpha;
}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
const size_t src_index = 1; // index of src input tensor
const size_t diff_src_index = 0; // index of diff_src output tensor
const Tensor& src_tensor = MklGetInput(context, src_index);
const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
Tensor* diff_src_tensor = nullptr;
MklDnnShape dnn_shape_diff_dst;
GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
MklDnnShape dnn_shape_diff_src;
dnn_shape_diff_src.SetMklTensor(false);
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
diff_dst_tensor.shape(), dnn_shape_diff_src);
T* out_o = diff_src_tensor->flat<T>().data();
T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
out_o[0] = user_i[0] >= 0 ? user_g[0] : user_g[0] * this->alpha_;
return;
}
};
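Both scalar fallbacks above implement the usual LeakyRelu derivative: the incoming gradient passes through unchanged for non-negative inputs and is scaled by alpha otherwise. A standalone sanity check of that arithmetic, using exactly representable constants; it is illustrative only, not part of the kernels:

#include <cassert>

int main() {
  const float alpha = 0.5f;
  const float x = -2.0f;  // input (src)
  const float dy = 3.0f;  // incoming gradient (diff_dst)
  // Forward: y = x for x >= 0, x * alpha otherwise.
  const float y = x >= 0 ? x : x * alpha;     // -1.0f
  // Backward: dx = dy for x >= 0, dy * alpha otherwise.
  const float dx = x >= 0 ? dy : dy * alpha;  // 1.5f
  assert(y == -1.0f);
  assert(dx == 1.5f);
  return 0;
}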
@@ -1376,6 +1457,19 @@ TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
MklRelu6GradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);
#define REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER(Name("_MklLeakyRelu") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklLeakyReluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("_MklLeakyReluGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklLeakyReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
#endif
} // namespace tensorflow
......
@@ -1964,6 +1964,40 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklLeakyRelu")
.Input("features: T")
.Input("mkl_features: uint8")
.Output("activations: T")
.Output("mkl_activations: uint8")
.Attr("T: {half, float, double} = DT_FLOAT")
.Attr("alpha: float = 0.2")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
MKL version of the LeakyRelu operator. Uses MKL DNN APIs to implement the
LeakyRelu operator.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklLeakyReluGrad")
.Input("gradients: T")
.Input("features: T")
.Input("mkl_gradients: uint8")
.Input("mkl_features: uint8")
.Output("backprops: T")
.Output("mkl_backprops: uint8")
.Attr("T: {half, float, double} = DT_FLOAT")
.Attr("alpha: float = 0.2")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
MKL version of the LeakyReluGrad operator. Uses MKL DNN APIs to compute leaky
rectified linear gradients for the LeakyRelu operation.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklElu")
.Input("features: T")
.Input("mkl_features: uint8")
......