提交 acdf7cbd 编写于 作者: J Jacek Czaja

- Added EPS for softmax MKLDNN op

- EPS added to softmax mkldnn primitive outcome is limited to training
phase

Fixes after review

clang format fixes

clang format fixes
上级 a097d082
......@@ -73,6 +73,15 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
softmax_dst_memory);
std::vector<primitive> pipeline{softmax};
stream(stream::kind::eager).submit(pipeline).wait();
const bool is_test = ctx.Attr<bool>("is_test");
if (!is_test) {
T threshold = exp(-64);
for (size_t i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
output_data[i] =
output_data[i] < threshold ? threshold : output_data[i];
}
}
}
};
......
......@@ -97,6 +97,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("is_test",
"Disable epsilon adding to softmax results. Used by MKLDNN.")
.SetDefault(false);
AddComment(R"DOC(
Softmax Operator.
......
......@@ -87,6 +87,7 @@ def fc(input,
bias_attr=None,
use_mkldnn=False,
act=None,
is_test=False,
name=None):
"""
**Fully Connected Layer**
......@@ -133,6 +134,7 @@ def fc(input,
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
library is installed. Default: False
name (str, default None): The name of this layer.
......@@ -177,7 +179,9 @@ def fc(input,
"W": w},
outputs={"Out": tmp},
attrs={"use_mkldnn": use_mkldnn,
"bias_attr": bias_attr})
"is_test": is_test,
"bias_attr": bias_attr
})
return helper.append_activation(tmp)
else:
for input_var, param_attr in helper.iter_inputs_and_params():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册