提交 9ecd8ee7 编写于 作者: L lidanqing 提交者: Tao Luo

change ComputeINT8 to template version to remove checking dst_datatype code (#18756)

* change INT8 to template so that checking dst_dt with if-else could be removed. CI will be enabled after fixing reviews

* reverse user_residual_memory_p and user_bias_memory_p declaration scope
test=develop
上级 d9e7b5b5
...@@ -69,6 +69,26 @@ inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format, ...@@ -69,6 +69,26 @@ inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
} }
} }
static mkldnn::memory::data_type GetDstType(bool is_int8,
bool force_fp32_output,
bool fuse_relu, bool fuse_brelu,
bool fuse_residual_conn,
const Tensor* residual_param) {
auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float
if (is_int8) {
dst_dt = (fuse_relu || fuse_brelu) ? mkldnn::memory::data_type::u8
: mkldnn::memory::data_type::s8;
if (force_fp32_output) {
dst_dt = mkldnn::memory::data_type::f32;
}
if (fuse_residual_conn && residual_param) {
auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
if (dst_dt != residual_dt) dst_dt = residual_dt;
}
}
return dst_dt;
}
template <typename T, typename K> template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -80,7 +100,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -80,7 +100,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (!is_INT8) { if (!is_INT8) {
ComputeFP32(ctx); ComputeFP32(ctx);
} else { } else {
ComputeINT8(ctx); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto dst_dt = GetDstType(true, force_fp32_output, fuse_relu, fuse_brelu,
fuse_residual_conn, residual_param);
if (dst_dt == mkldnn::memory::data_type::f32) {
ComputeINT8<float>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::u8) {
ComputeINT8<uint8_t>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::s8) {
ComputeINT8<int8_t>(ctx);
}
} }
} }
...@@ -287,7 +320,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -287,7 +320,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
template <typename T_out>
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -328,10 +361,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -328,10 +361,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold"); float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool unsigned_output = fuse_relu || fuse_brelu; bool unsigned_output = fuse_relu || fuse_brelu;
if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true, PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output,
"residual fusion does not support force output with fp32"); "residual fusion does not support force output with fp32");
}
bool is_conv3d = strides.size() == 3U; bool is_conv3d = strides.size() == 3U;
// TODO(tpatejko): add support for dilation // TODO(tpatejko): add support for dilation
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -356,23 +389,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -356,23 +389,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = unsigned_output
? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType())
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType());
if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<float>::DataType());
}
if (fuse_residual_conn) {
auto residual = ctx.Input<Tensor>("ResidualData");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
if (dst_dt != residual_dt) dst_dt = residual_dt;
}
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
std::string key; std::string key;
key.reserve(MaxKeyLength); key.reserve(MaxKeyLength);
...@@ -453,28 +469,35 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -453,28 +469,35 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, chosen_memory_format); weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_md = auto dst_md = platform::MKLDNNMemDesc(
platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
handler.reset(
new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
// TODO(lidanqing): We use relu post-op instead of brelu post-op cause // TODO(lidanqing): We use relu post-op instead of brelu post-op cause
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when // mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
// v0.20 is enabled // v0.20 is enabled
std::shared_ptr<memory::desc> bias_md_p; auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
bias_md_p = std::make_shared<memory::desc>(platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
bias_tz, memory::data_type::s32, memory::format::x)); mkldnn::memory::format::x);
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
propagation, output_shift_scale, sum_scale);
} else {
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
propagation, output_shift_scale, sum_scale);
} }
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, bias_md_p, dst_md, strides, paddings,
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
output_shift_scale, sum_scale, is_test);
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
user_src_memory_p = user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
...@@ -502,38 +525,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -502,38 +525,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (residual_param->format() != handler->GetDstFormat()) { if (residual_param->format() != handler->GetDstFormat()) {
auto residual_data_tz = auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims()); paddle::framework::vectorize2int(residual_param->dims());
auto user_residual_md = platform::MKLDNNMemDesc( auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_dt, residual_param->format()); residual_data_tz, residual_dt, residual_param->format());
dst_memory_p = platform::SetDstMemory<T_out>(
if (residual_dt == mkldnn::memory::data_type::u8) { ctx, output, residual_param, user_residual_md, handler,
dst_memory_p = platform::SetDstMemory<uint8_t>( &pipeline);
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
} else {
need_s8_to_u8 = unsigned_output;
dst_memory_p = platform::SetDstMemory<int8_t>(
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
}
} else { } else {
output->ShareDataWith(*residual_param); output->ShareDataWith(*residual_param);
if (residual_dt == mkldnn::memory::data_type::u8) { dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
dst_memory_p =
platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
need_s8_to_u8 = unsigned_output;
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
}
}
} else if (!force_fp32_output) {
if (unsigned_output) {
dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
} }
need_s8_to_u8 =
(platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
unsigned_output;
} else { } else {
dst_memory_p = platform::SetDstMemory<float>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
} }
// create convolution op primitive // create convolution op primitive
...@@ -564,7 +569,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -564,7 +569,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p); dst_memory_p);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
} else { } else {
...@@ -592,29 +596,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -592,29 +596,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_residual_conn) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_dt =
paddle::framework::ToMKLDNNDataType(residual_param->type());
output->ShareDataWith(*residual_param); output->ShareDataWith(*residual_param);
if (residual_dt == mkldnn::memory::data_type::u8) { need_s8_to_u8 =
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler, (platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
&dst_memory_p); unsigned_output;
} else {
need_s8_to_u8 = unsigned_output;
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
&dst_memory_p);
}
} else if (!force_fp32_output) {
if (unsigned_output) {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p);
} else {
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
&dst_memory_p);
}
} else {
platform::SetDstMemoryHandler<float>(ctx, output, handler,
&dst_memory_p);
} }
platform::SetDstMemoryHandler<T_out>(ctx, output, handler, dst_memory_p);
if (src_memory_reorder_p) { if (src_memory_reorder_p) {
pipeline.push_back(*src_memory_reorder_p); pipeline.push_back(*src_memory_reorder_p);
...@@ -625,87 +612,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -625,87 +612,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (residual_reorder_p) { if (residual_reorder_p) {
pipeline.push_back(*residual_reorder_p); pipeline.push_back(*residual_reorder_p);
} }
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
if (need_s8_to_u8) { if (need_s8_to_u8) {
output->mutable_data<uint8_t>(ctx.GetPlace()); output->mutable_data<uint8_t>(ctx.GetPlace());
} }
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
private:
mkldnn::primitive_attr CreatePostOps(
bool fuse_relu, bool fuse_residual_conn,
const std::vector<float>& output_shift_scale, float sum_scale,
bool fuse_brelu, float fuse_brelu_threshold) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
if (fuse_brelu) {
constexpr float scale = 1.0f;
constexpr float placeholder = 0.0f; // beta
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_brelu_threshold, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const std::shared_ptr<memory::desc> bias_md_p,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float>& output_shift_scale,
const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto conv_desc =
(bias_md_p != nullptr)
? mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights,
(*bias_md_p), dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero)
: mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights, dst,
stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
sum_scale, fuse_brelu, fuse_brelu_threshold);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
}; };
template <typename T> template <typename T>
......
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -82,22 +81,24 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { ...@@ -82,22 +81,24 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
template <typename Type> template <typename Type>
mkldnn::memory::data_type MKLDNNGetDataType() { mkldnn::memory::data_type MKLDNNGetDataType() {
return mkldnn::memory::data_undef; return mkldnn::memory::data_type::data_undef;
} }
template <> template <>
inline mkldnn::memory::data_type MKLDNNGetDataType<float>() { inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
return mkldnn::memory::f32; return mkldnn::memory::data_type::f32;
}
template <>
inline mkldnn::memory::data_type MKLDNNGetDataType<int32_t>() {
return mkldnn::memory::data_type::s32;
} }
template <> template <>
inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() { inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
return mkldnn::memory::s8; return mkldnn::memory::data_type::s8;
} }
template <> template <>
inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() { inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
return mkldnn::memory::u8; return mkldnn::memory::data_type::u8;
} }
inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) { inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
......
...@@ -1160,18 +1160,24 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1160,18 +1160,24 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
scale_data, mask); scale_data, mask);
} }
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn, mkldnn::primitive_attr CreatePostOps(
bool fuse_brelu, bool fuse_relu, bool fuse_residual_conn, bool fuse_brelu,
float fuse_brelu_threshold) const { float fuse_brelu_threshold,
const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
if (output_shift_scale.size() > 0) {
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
}
// Fusion with Elementwise layer relies on adding a sum post-operation with // Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is // the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual // true, the output tensor contains the data coming from residual
// connection. The result of this post_op is: // connection. The result of this post_op is:
// Output = scale * Output + Conv_Out. // Output = scale * Output + Conv_Out.
if (fuse_residual_conn) { if (fuse_residual_conn) {
post_operations.append_sum(1.0f); post_operations.append_sum(sum_scale);
} }
// Fusion with ReLU layer is executed through the PostOps feature. Create a // Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation. // PostOps object and configure it to execute an eltwise relu operation.
...@@ -1202,7 +1208,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1202,7 +1208,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
const std::vector<int>& paddings, const mkldnn::engine& engine, const std::vector<int>& paddings, const mkldnn::engine& engine,
const bool fuse_relu, const bool fuse_residual_conn, const bool fuse_relu, const bool fuse_residual_conn,
const bool fuse_brelu, const float fuse_brelu_threshold, const bool fuse_brelu, const float fuse_brelu_threshold,
mkldnn::prop_kind fwd_prop_kind) { mkldnn::prop_kind fwd_prop_kind,
const std::vector<float> output_shift_scale = {},
const float sum_scale = 1.0f) {
// Conv PD has to be passed to Grad op that // Conv PD has to be passed to Grad op that
// may be exxecuted by diffrent thread, hence // may be exxecuted by diffrent thread, hence
// for that one we use key that does not contain TID // for that one we use key that does not contain TID
...@@ -1232,8 +1240,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1232,8 +1240,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
src, weights, dst, stride_dims, padding_dims, src, weights, dst, stride_dims, padding_dims,
padding_dims, mkldnn::padding_kind::zero); padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps( mkldnn::primitive_attr conv_attr =
fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold); CreatePostOps(fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, output_shift_scale, sum_scale);
conv_pd_.reset(new typename forward_t::primitive_desc( conv_pd_.reset(new typename forward_t::primitive_desc(
conv_desc, conv_attr, engine)); conv_desc, conv_attr, engine));
...@@ -1393,10 +1402,10 @@ template <typename T> ...@@ -1393,10 +1402,10 @@ template <typename T>
static void SetDstMemoryHandler( static void SetDstMemoryHandler(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler, const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::shared_ptr<mkldnn::memory>* dst_memory_p) { std::shared_ptr<mkldnn::memory> dst_memory_p) {
T* output_data = T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
(*dst_memory_p)->set_data_handle(to_void_cast<T>(output_data)); dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
} }
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册