提交 bf95ac36 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: further reformatting

上级 cbe122ae
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
......
......@@ -303,7 +303,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
int groups = ctx.Attr<int>("groups");
// TODO: add support for dilation
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
......@@ -386,21 +386,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
T* output_data = nullptr;
if (fuse_eltwise) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(residual_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output_data = output->mutable_data<T>(ctx.GetPlace());
output->ShareDataWith(*residual_param);
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
}
// create reorder primitive if the input format is not the preferred one
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册