提交 76550d87 编写于 作者: K Krzysztof Binias

Reformat code

上级 c4107748
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
limitations under the License. */ limitations under the License. */
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "mkldnn_activation_op.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -46,14 +46,18 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -46,14 +46,18 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
// create memory description // create memory description
auto data_md = src_tz.size() == 2 auto data_md = src_tz.size() == 2
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc) mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); mkldnn::memory::format::nchw);
// create memory primitives // create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data); auto src_memory =
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(src_data)));
auto dst_memory =
mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(dst_data)));
auto forward_desc = mkldnn::eltwise_forward::desc( auto forward_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta); mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
...@@ -94,17 +98,20 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -94,17 +98,20 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
// create memory description // create memory description
auto data_md = src_tz.size() == 2 auto data_md = src_tz.size() == 2
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc) mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); mkldnn::memory::format::nchw);
// create memory primitives // create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src); auto src_memory = mkldnn::memory(
{data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
auto diff_src_memory = auto diff_src_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(diff_src)));
auto diff_dst_memory = auto diff_dst_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(diff_dst)));
auto backward_desc = auto backward_desc =
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta); mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册