提交 4cf499c0 编写于 作者: B bingyanghuang 提交者: Tao Luo

cherry-pick PR#20640 to release 1.6, test=release/1.6 (#20706)

上级 33a58e58
...@@ -29,32 +29,34 @@ using mkldnn::stream; ...@@ -29,32 +29,34 @@ using mkldnn::stream;
using platform::to_void_cast; using platform::to_void_cast;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
constexpr int same_scale_mask = 0; inline void GetWeightsTz(std::vector<int>& weights_tz, int groups, // NOLINT
constexpr int o_slice_mask = 1 << 0; // 1 bool is_conv3d) {
constexpr int g_slice_mask = 1 << 1; // 2
constexpr int g_o_slice_mask = g_slice_mask | o_slice_mask; // 3
static int ComputeMask(bool is_multi_channel, int multi_channel_mask) {
return is_multi_channel ? multi_channel_mask : same_scale_mask;
}
static int ComputeWeightsMask(int is_multi_channel, int g) {
int multi_channel_mask = g > 1 ? g_o_slice_mask : o_slice_mask;
return ComputeMask(is_multi_channel, multi_channel_mask);
}
static int ComputeBiasMask(int is_multi_channel) {
return ComputeMask(is_multi_channel, o_slice_mask);
}
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups) { // NOLINT
if (groups > 1) { if (groups > 1) {
// if (is_conv3d) [o, i, dimension, h, w]->[g, o/g, i, dimension, h, w] if (is_conv3d) {
// else [o, i, h, w] -> [g, o/g, i, h, w] int output = weights_tz[0];
weights_tz.push_back(0); int input = weights_tz[1];
std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end()); int dimension = weights_tz[2];
int height = weights_tz[3];
int width = weights_tz[4];
weights_tz.resize(6);
weights_tz[0] = groups; weights_tz[0] = groups;
weights_tz[1] = weights_tz[1] / groups; weights_tz[1] = output / groups;
weights_tz[2] = input;
weights_tz[3] = dimension;
weights_tz[4] = height;
weights_tz[5] = width;
} else {
int output = weights_tz[0];
int input = weights_tz[1];
int height = weights_tz[2];
int width = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = groups;
weights_tz[1] = output / groups;
weights_tz[2] = input;
weights_tz[3] = height;
weights_tz[4] = width;
}
} }
} }
...@@ -67,59 +69,28 @@ inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format, ...@@ -67,59 +69,28 @@ inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
} }
} }
static std::vector<float> ComputeOutputShiftScale(
const float scale_out_data, const float scale_in_data,
const std::vector<float>& scale_weights_data) {
int count = scale_weights_data.size();
std::vector<float> output_shift_scale(count);
#pragma omp parallel for
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0) {
output_shift_scale[i] = scale_out_data;
} else {
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
}
return output_shift_scale;
}
static std::vector<float> ComputeBiasScale(
const float scale_in_data, const std::vector<float>& scale_weights_data) {
int count = scale_weights_data.size();
std::vector<float> scale_bias_data(count);
#pragma omp parallel for if (count > 1)
for (int i = 0; i < count; i++) {
scale_bias_data[i] = scale_in_data * scale_weights_data[i];
}
return scale_bias_data;
}
static mkldnn::memory::data_type GetDstType(bool is_int8, static mkldnn::memory::data_type GetDstType(bool is_int8,
bool force_fp32_output, bool force_fp32_output,
std::string fuse_activation, std::string fuse_activation,
bool fuse_residual_conn, bool fuse_residual_conn,
const Tensor* residual_param) { const Tensor* residual_param) {
auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float
if (is_int8 && !force_fp32_output) { if (is_int8) {
if (fuse_residual_conn && residual_param) {
// when residual exists, dst_dt will follow the residual_param type,
// but output will to be set to u8 if relu exists
auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
dst_dt = residual_dt;
} else {
// when residual does not exist, if (b)relu exist s8 else s8
dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
? mkldnn::memory::data_type::u8 ? mkldnn::memory::data_type::u8
: mkldnn::memory::data_type::s8; : mkldnn::memory::data_type::s8;
if (force_fp32_output) {
dst_dt = mkldnn::memory::data_type::f32;
}
if (fuse_residual_conn && residual_param) {
auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
if (dst_dt != residual_dt) dst_dt = residual_dt;
} }
} }
return dst_dt; return dst_dt;
} }
template <typename T> template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
...@@ -215,7 +186,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -215,7 +186,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int>(input->dims());
auto weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g); GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int>(output->dims());
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
...@@ -297,8 +268,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -297,8 +268,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>(); auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
residual_param_data != nullptr, residual_param_data, nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion"); "Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
...@@ -358,7 +329,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -358,7 +329,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
template <typename T_out> template <typename T_out>
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -417,11 +387,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -417,11 +387,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool unsigned_output = bool unsigned_output =
(fuse_activation == "relu" || fuse_activation == "relu6"); (fuse_activation == "relu" || fuse_activation == "relu6");
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output, PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output,
"residual fusion does not support force output with fp32"); "residual fusion does not support force output with fp32");
...@@ -442,7 +407,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -442,7 +407,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int>(input->dims());
auto weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g);
GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int>(output->dims());
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
...@@ -451,27 +417,71 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -451,27 +417,71 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key = platform::CreateKey( std::string key = platform::CreateKey(
src_tz, src_dt, ctx.op().Input("Input") + ctx.op().Input("Filter")); src_tz, src_dt, ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p; std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> user_src_memory_p; std::shared_ptr<mkldnn::memory> user_src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd; std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p; std::shared_ptr<platform::ConvMKLDNNHandler> handler;
// This is workaround for hacky implementation
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::ThreadIDasStr();
}
const float* filter_data = filter->data<float>(); auto prim_key = key + key_tid + "@conv_p";
bool is_multi_channel = scale_weights_data.size() > 1; auto dst_key = key + key_tid + "@dst_mem_p";
auto src_key = key + key_tid + "@src_mem_p";
auto user_src_key = key + key_tid + "@user_src_mem_p";
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
auto output_shift_scale = ComputeOutputShiftScale( conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
scale_out_data, scale_in_data, scale_weights_data); dev_ctx.GetBlob(prim_key));
float scale_residual = if (conv_p == nullptr || !is_test) {
const K* filter_data = filter->data<K>();
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f; fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1;
int count = is_multi_channel ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0]
: (weights_tz)[0])
: 1;
std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 1)
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
output_shift_scale[i] =
scale_out_data; // weights data will contain 0
// in some models, then weights
// scale couldn't be calculated
else
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
auto user_src_md = auto user_src_md =
platform::MKLDNNMemDesc({src_tz}, src_dt, input->format()); platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
auto user_weights_md = platform::MKLDNNMemDesc( auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<float>(), {weights_tz}, platform::MKLDNNGetDataType<K>(),
((g) == 1) ? mkldnn::memory::format::oihw ((g) == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);
: mkldnn::memory::format::goihw);
/* create memory descriptor for convolution without specified format /* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
...@@ -481,113 +491,155 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -481,113 +491,155 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
auto src_md = platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); std::vector<int> bias_tz;
auto weights_md = platform::MKLDNNMemDesc(weights_tz, memory::data_type::s8,
chosen_memory_format); auto src_md =
platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); handler.reset(
new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
// create a conv primitive descriptor and save it for usage in backward
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
std::vector<int> bias_tz;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize<int>(bias->dims()); bias_tz = paddle::framework::vectorize<int>(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
mkldnn::memory::format::x); MKLDNNMemoryFormat::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, src_md, weights_md, bias_md, dst_md, strides, paddings,
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn, mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
propagation, output_shift_scale, scale_residual); fuse_residual_conn, propagation, output_shift_scale, sum_scale);
} else { } else {
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings, src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, propagation, output_shift_scale, scale_residual); fuse_residual_conn, propagation, output_shift_scale, sum_scale);
} }
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
user_src_memory_p = user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<float>(filter_data)); user_weights_md, to_void_cast<K>(filter_data));
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
src_memory_p = src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
std::shared_ptr<mkldnn::memory> weights_memory_p; std::shared_ptr<mkldnn::memory> weights_memory_p;
int mask_reorder =
int mask_reorder = ComputeWeightsMask(is_multi_channel, g); is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, true, scale_weights_data, user_weights_memory_p, pipeline, is_test, true, scale_weights_data,
mask_reorder); mask_reorder);
if (fuse_residual_conn) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T_out>();
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
"same dimension sizes"); "same dimension sizes");
auto residual_dt = auto residual_dt =
paddle::framework::ToMKLDNNDataType(residual_param->type()); paddle::framework::ToMKLDNNDataType(residual_param->type());
if (residual_param->format() != handler.GetDstFormat()) { if (residual_param->format() != handler->GetDstFormat()) {
auto residual_data_tz = auto residual_data_tz =
paddle::framework::vectorize<int>(residual_param->dims()); paddle::framework::vectorize<int>(residual_param->dims());
auto user_residual_md = platform::MKLDNNMemDesc( auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_dt, residual_param->format()); residual_data_tz, residual_dt, residual_param->format());
dst_memory_p = platform::SetDstMemory<T_out>(
user_residual_memory_p = handler.AcquireResidualDataMemory( ctx, output, residual_param, user_residual_md, handler,
user_residual_md, to_void_cast<T_out>(residual_param_data)); &pipeline);
T_out* output_data = output->mutable_data<T_out>(ctx.GetPlace());
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T_out>(output_data), pipeline);
} else { } else {
output->ShareDataWith(*residual_param); output->ShareDataWith(*residual_param);
auto output_data = output->mutable_data<T_out>(ctx.GetPlace()); dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
to_void_cast<T_out>(output_data));
} }
need_s8_to_u8 =
(platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
unsigned_output;
} else { } else {
T_out* output_data = output->mutable_data<T_out>( dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
ctx.GetPlace(), handler.GetDstMemorySize());
dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
to_void_cast<T_out>(output_data));
} }
// create convolution op primitive // create convolution op primitive
auto scale_bias_key = key + "@scale_bias";
if (bias) { if (bias) {
const float* bias_data = bias->data<float>(); const K* bias_data = bias->data<K>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::x);
auto user_bias_memory_p = handler.AcquireBiasMemory( auto user_bias_memory_p = handler->AcquireBiasMemory(
user_bias_md, to_void_cast<float>(bias_data)); user_bias_md, to_void_cast<K>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p; std::shared_ptr<mkldnn::memory> bias_memory_p;
int mask_reorder = is_multi_channel ? 1 << 0 : 1;
auto scale_bias_data = int count =
ComputeBiasScale(scale_in_data, scale_weights_data); is_multi_channel
int mask_bias_reorder = ComputeBiasMask(is_multi_channel); ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
bias_memory_p = handler.AcquireBiasMemoryFromPrimitive( : 1;
std::vector<float> scale_bias_data(count);
#pragma omp parallel for if (count > 1)
for (int i = 0; i < count; i++) {
scale_bias_data[i] = scale_in_data * scale_weights_data[i];
}
bias_memory_p = handler->AcquireBiasMemoryFromPrimitive(
user_bias_memory_p, pipeline, is_test, true, scale_bias_data, user_bias_memory_p, pipeline, is_test, true, scale_bias_data,
mask_bias_reorder); mask_reorder);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); bias_memory_p, dst_memory_p);
} else { } else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p); dst_memory_p);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
} else {
auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(src_reorder_key));
src_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
if (src_memory_reorder_p) {
user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(user_src_key));
user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
} else if (src_memory_p) {
src_memory_p->set_data_handle(to_void_cast<T>(input_data));
}
dst_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
if (conv_pd) {
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key));
}
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
output->ShareDataWith(*residual_param);
need_s8_to_u8 =
(platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
unsigned_output;
}
platform::SetDstMemoryHandler<T_out>(ctx, output, handler, dst_memory_p);
if (src_memory_reorder_p) {
pipeline.push_back(*src_memory_reorder_p);
}
auto residual_reorder_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(residual_reorder_key));
if (residual_reorder_p) {
pipeline.push_back(*residual_reorder_p);
}
pipeline.push_back(*conv_p);
}
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
if (platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8 && if (need_s8_to_u8) {
unsigned_output) {
output->mutable_data<uint8_t>(ctx.GetPlace()); output->mutable_data<uint8_t>(ctx.GetPlace());
} }
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
...@@ -649,7 +701,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -649,7 +701,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int>(input->dims());
auto weights_tz = paddle::framework::vectorize<int>(filter->dims()); auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g); GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize<int>(output_grad->dims()); auto dst_tz = paddle::framework::vectorize<int>(output_grad->dims());
auto src_format = input->format(); auto src_format = input->format();
MKLDNNMemoryFormat weights_format = MKLDNNMemoryFormat weights_format =
...@@ -704,7 +756,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -704,7 +756,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto conv_pd = auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>( std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd)); dev_ctx.GetBlob(key_conv_pd));
PADDLE_ENFORCE(conv_pd != nullptr, PADDLE_ENFORCE_NE(conv_pd, nullptr,
"Fail to find conv_pd in device context"); "Fail to find conv_pd in device context");
// create backward convolution weights primitive descriptor // create backward convolution weights primitive descriptor
...@@ -786,6 +838,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -786,6 +838,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -794,17 +847,17 @@ namespace ops = paddle::operators; ...@@ -794,17 +847,17 @@ namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
::paddle::platform::CPUPlace, FP32, ::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32, ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float>); ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
::paddle::platform::CPUPlace, U8, ::paddle::platform::CPUPlace, U8,
ops::kConvMKLDNNINT8, ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<uint8_t>); ops::ConvMKLDNNOpKernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
::paddle::platform::CPUPlace, S8, ::paddle::platform::CPUPlace, S8,
ops::kConvMKLDNNINT8, ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<int8_t>); ops::ConvMKLDNNOpKernel<int8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32, ::paddle::platform::CPUPlace, FP32,
...@@ -814,7 +867,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, ...@@ -814,7 +867,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
::paddle::platform::CPUPlace, FP32, ::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32, ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float>); ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32, ::paddle::platform::CPUPlace, FP32,
......
...@@ -816,6 +816,15 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -816,6 +816,15 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
mkldnn::engine engine, const std::string& base_key) mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {} : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
// TODO(jczaja): remove after conv int8 is adapted
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
conv_pd_ = conv_pd;
}
ConvMKLDNNTemplateHandler( ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd, std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
std::shared_ptr<typename backward_data_t::primitive_desc> std::shared_ptr<typename backward_data_t::primitive_desc>
...@@ -1136,6 +1145,47 @@ using ConvTransposeMKLDNNHandler = ...@@ -1136,6 +1145,47 @@ using ConvTransposeMKLDNNHandler =
mkldnn::deconvolution_backward_data, mkldnn::deconvolution_backward_data,
mkldnn::deconvolution_backward_weights>; mkldnn::deconvolution_backward_weights>;
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler) {
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
return dst_memory_p;
}
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const framework::Tensor* residual_param,
const mkldnn::memory::desc& user_residual_md,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::vector<mkldnn::primitive>* pipeline) {
const T* residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
std::shared_ptr<mkldnn::memory> user_residual_memory_p =
handler->AcquireResidualDataMemory(user_residual_md,
to_void_cast<T>(residual_param_data));
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), *pipeline);
return dst_memory_p;
}
template <typename T>
static void SetDstMemoryHandler(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
}
template <typename T> template <typename T>
static void SetDstMemoryQuantized( static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册