未验证 提交 066a8063 编写于 作者: W wangxinxin08 提交者: GitHub

fix attr missing in conv cudnn kernel (#38827)

上级 b4dd7828
......@@ -65,7 +65,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
FLAGS_cudnn_exhaustive_search || (ctx.HasAttr("exhaustive_search") &&
ctx.Attr<bool>("exhaustive_search"));
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic, false,
......@@ -386,7 +387,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
FLAGS_cudnn_exhaustive_search || (ctx.HasAttr("exhaustive_search") &&
ctx.Attr<bool>("exhaustive_search"));
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic, false,
......@@ -437,7 +439,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
ctx, input_grad, &transformed_input_grad_channel);
// NOTE(zhiqiu): If inplace_addto strategy is enabled, we need to copy
// the data of input_grad to transformed_input_grad_channel.
if (ctx.Attr<bool>("use_addto")) {
if (ctx.HasAttr("use_addto") && ctx.Attr<bool>("use_addto")) {
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input_grad, &transformed_input_grad_channel);
}
......@@ -703,15 +705,17 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// MIOPEN ONLY support beta to be 0.0f
ScalingParamType<T> beta = 0.0f;
#else
ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
ScalingParamType<T> beta =
(ctx.HasAttr("use_addto") && ctx.Attr<bool>("use_addto")) ? 1.0f : 0.0f;
#endif
VLOG(4) << "Conv_grad: use_addto = " << ctx.Attr<bool>("use_addto");
VLOG(4) << "Conv_grad: use_addto = "
<< (ctx.HasAttr("use_addto") && ctx.Attr<bool>("use_addto"));
if (input_grad) {
// When beta is 0, it is unnecessary to reset input_grad.
// When beta is 1, the output cannot be reset since addt strategy used.
#ifdef PADDLE_WITH_HIP
if (ctx.Attr<bool>("use_addto")) {
if (ctx.HasAttr("use_addto") && ctx.Attr<bool>("use_addto")) {
Tensor temp_tensor(transformed_input_grad.type());
temp_tensor.Resize(transformed_input_grad.dims());
T* temp_tensor_data = temp_tensor.mutable_data<T>(ctx.GetPlace());
......@@ -878,7 +882,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
FLAGS_cudnn_exhaustive_search || (ctx.HasAttr("exhaustive_search") &&
ctx.Attr<bool>("exhaustive_search"));
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic, false,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册