Unverified commit 4a4f3f80, authored by Sławomir Siwek, committed by GitHub

migrate convs (#47658)

Parent: ca4bed7b
@@ -776,247 +776,6 @@ class ConvMKLDNNHandlerT
} // anonymous namespace
#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
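// ---------------------------------------------------------------------------
// Editorial illustration (not part of this commit): a minimal sketch of how
// the visitor macro above is typically used. It assumes, as with Paddle's
// other PD_VISIT_* macros, that PD_PRIVATE_CASE_TYPE expands each case into a
// block that aliases the matched C++ type as `data_t` and then invokes the
// supplied lambda. `ElementSizeOf` is a hypothetical helper, not a Paddle
// function.
inline size_t ElementSizeOf(::paddle::DataType dtype) {
  size_t size = 0;
  PD_VISIT_FLOAT_AND_INT8_TYPES(dtype, "ElementSizeOf", ([&] {
    // Inside the lambda, `data_t` is float for FLOAT32 and int8_t for INT8;
    // any other dtype falls into the default branch and throws via PD_THROW.
    size = sizeof(data_t);
  }));
  return size;
}
// ---------------------------------------------------------------------------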
template <typename T>
class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true,
platform::errors::PreconditionNotMet(
"Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
auto residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto dst_dt = GetDstType(is_INT8,
is_BFLOAT16,
force_fp32_output,
fuse_activation,
fuse_residual_conn,
residual_param);
if (!is_INT8) {
if (dst_dt == dnnl::memory::data_type::f32) {
ComputeFP32<float>(ctx);
} else if (dst_dt == dnnl::memory::data_type::bf16) {
ComputeFP32<platform::bfloat16>(ctx);
}
} else {
if (dst_dt == dnnl::memory::data_type::f32) {
ComputeINT8<float>(ctx);
} else if (dst_dt == dnnl::memory::data_type::u8) {
ComputeINT8<uint8_t>(ctx);
} else if (dst_dt == dnnl::memory::data_type::s8) {
ComputeINT8<int8_t>(ctx);
}
}
}
template <typename T_out>
void ComputeFP32(const framework::ExecutionContext& ctx) const {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
bool is_test = ctx.Attr<bool>("is_test");
const auto& strides = ctx.Attr<std::vector<int>>("strides");
bool is_conv3d = strides.size() == 3UL;
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
const auto* input = ctx.Input<phi::DenseTensor>("Input");
const auto* filter = ctx.Input<phi::DenseTensor>("Filter");
const auto* bias =
ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
auto* output = ctx.Output<phi::DenseTensor>("Output");
PD_VISIT_FLOAT_AND_INT8_TYPES(
filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
ConvMKLDNNHandlerT<T, data_t, T_out> handler(
ctx,
dev_ctx,
mkldnn_engine,
ctx.GetPlace(),
input,
filter,
bias,
output,
ctx.InputName("Input") + ctx.InputName("Filter"));
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, is_conv3d, is_test);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
dst_memory_p =
handler.AcquireDstMemoryWithResidual(output, residual_param);
} else {
dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
}
auto conv_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, is_test);
args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
conv_p->execute(astream, args);
astream.wait();
output->set_mem_desc(dst_memory_p->get_desc());
}));
}
template <typename T_out>
void ComputeINT8(const framework::ExecutionContext& ctx) const {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const std::string& fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
bool unsigned_output =
(fuse_activation == "relu" || fuse_activation == "relu6");
bool need_s8_to_u8 = false;
PADDLE_ENFORCE_NE(
is_conv3d,
true,
platform::errors::Unimplemented(
"OneDNN int8 convolution does not support 3D inputs currently"));
PADDLE_ENFORCE_EQ(
fuse_residual_conn && force_fp32_output,
false,
platform::errors::Unimplemented(
"residual fusion does not support force output with fp32"));
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* filter = ctx.Input<phi::DenseTensor>("Filter");
auto* bias =
ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
auto* output = ctx.Output<phi::DenseTensor>("Output");
PD_VISIT_FLOAT_AND_INT8_TYPES(
filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
ConvMKLDNNHandlerT<T, data_t, T_out> handler(
ctx,
dev_ctx,
mkldnn_engine,
ctx.GetPlace(),
input,
filter,
bias,
output,
ctx.InputName("Input") + ctx.InputName("Filter"));
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
const auto& scale_weights_data =
ctx.Attr<std::vector<float>>("Scale_weights");
const bool is_multi_channel = scale_weights_data.size() > 1;
const int& groups = ctx.Attr<int>("groups");
int mask_reorder =
is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
: 0;
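// Editorial note (illustrative, not part of the diff): the reorder mask is a
// oneDNN bitmask over the weight tensor's dimensions. 0 requests one common
// scale; 1 << 0 requests per-output-channel scales; for grouped weights the
// group dimension is prepended, so both bits ((1 << 1) + (1 << 0)) are set to
// scale per group and per output channel.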
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, false, true, scale_weights_data, mask_reorder);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
PADDLE_ENFORCE_EQ(
output->dims(),
residual_param->dims(),
platform::errors::InvalidArgument(
"Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d"
" and residual param's dimension =%d .",
output->dims().size(),
residual_param->dims().size()));
dst_memory_p =
handler.AcquireDstMemoryWithResidual(output, residual_param);
need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
dnnl::memory::data_type::s8) &&
unsigned_output;
} else {
dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
}
auto conv_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
std::vector<float> bias_scales;
auto p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
if (ctx.HasAttr("Bias_scales")) {
bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
} else {
p_scales_tuple = handler.get_int8_bias_scales(ctx);
}
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias,
true,
std::get<1>(*p_scales_tuple),
std::get<0>(*p_scales_tuple));
args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
conv_p->execute(astream, args);
astream.wait();
if (need_s8_to_u8) {
output->mutable_data<uint8_t>(ctx.GetPlace());
}
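// Editorial note (illustrative, not part of the diff): with residual fusion
// the destination reuses the residual tensor's s8 buffer, so when the fused
// activation (relu/relu6) guarantees a non-negative result the buffer is
// re-declared as uint8_t here before the memory descriptor is attached to
// the output.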
output->set_mem_desc(dst_memory_p->get_desc());
}));
}
};
#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
@@ -1184,25 +943,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(depthwise_conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConvMKLDNNOpKernel<float>,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
ops::ConvMKLDNNOpKernel<uint8_t>,
ops::ConvMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(depthwise_conv2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConvMKLDNNGradOpKernel<float>,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(conv3d,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConvMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(conv3d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
......
@@ -424,6 +424,52 @@ void ConvKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void DepthwiseConvKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
ConvKernel<T, Context>(dev_ctx,
input,
filter,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
out);
}
template <typename T, typename Context>
void Conv3DKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
ConvKernel<T, Context>(dev_ctx,
input,
filter,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(conv2d,
@@ -434,3 +480,14 @@ PD_REGISTER_KERNEL(conv2d,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
PD_REGISTER_KERNEL(depthwise_conv2d,
OneDNN,
ONEDNN,
phi::DepthwiseConvKernel,
float,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
PD_REGISTER_KERNEL(conv3d, OneDNN, ONEDNN, phi::Conv3DKernel, float) {}