Unverified · Commit eded6013, authored by Chen Weihang, committed by GitHub

Simplify conv_mkldnn op registration (#46907)

* simplify conv_mkldnn op registration

* remove custom type value in conv grad op
Parent 2010bdc3
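The heart of the change: instead of registering separate MKLDNN conv kernels per quantization variant via `REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE`, the kernel now dispatches on the filter's runtime dtype through visitor macros (`PD_VISIT_FLOAT_AND_INT8_TYPES` and `PD_VISIT_FLOAT_AND_BF16_TYPES` in the diff below). A minimal, self-contained sketch of that dispatch pattern follows; the names `DType`, `VISIT_FLOAT_AND_INT8`, and `conv_kernel` are illustrative stand-ins, not Paddle APIs:

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Illustrative stand-in for paddle::DataType.
enum class DType { FLOAT32, INT8 };

// Visitor macro in the style of PD_VISIT_FLOAT_AND_INT8_TYPES: it maps the
// runtime dtype to a compile-time alias `data_t`, then runs the body.  The
// lambda body is pasted textually after `using data_t = ...;`, so `data_t`
// is visible inside it.
#define VISIT_FLOAT_AND_INT8(dtype, ...)               \
  [&] {                                                \
    switch (dtype) {                                   \
      case DType::FLOAT32: {                           \
        using data_t = float;                          \
        return __VA_ARGS__();                          \
      }                                                \
      case DType::INT8: {                              \
        using data_t = int8_t;                         \
        return __VA_ARGS__();                          \
      }                                                \
      default:                                         \
        throw std::runtime_error("unsupported dtype"); \
    }                                                  \
  }()

// One kernel template now covers all weight dtypes; the template parameter
// that used to be fixed at registration time is chosen at run time.
template <typename T>
void conv_kernel(DType filter_dtype) {
  VISIT_FLOAT_AND_INT8(filter_dtype, ([&] {
    // In Paddle this is where ConvMKLDNNHandlerT<T, data_t, T_out> is built.
    std::cout << "dispatched with sizeof(data_t) = " << sizeof(data_t) << "\n";
  }));
}

int main() {
  conv_kernel<float>(DType::FLOAT32);  // data_t = float
  conv_kernel<float>(DType::INT8);     // data_t = int8_t
}

Because the weight type is resolved inside the kernel, the registry no longer needs a "customized type value" to tell the FP32, INT8, and INT8-with-S8-weights variants apart.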
@@ -511,13 +511,6 @@ function(op_library TARGET)
     # Append first implemented MKLDNN activation operator
     if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
-    elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
-      file(APPEND ${pybind_file}
-           "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
-      file(APPEND ${pybind_file}
-           "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
-      file(APPEND ${pybind_file}
-           "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
     elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
       file(APPEND ${pybind_file}
            "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
...
@@ -229,31 +229,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
 #ifdef PADDLE_WITH_MKLDNN
   if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-    int customized_type_value =
-        (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
-         input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
-            ? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
-                      framework::DataTypeTrait<int8_t>::DataType()
-                  ? kConvMKLDNNINT8WS8
-                  : kConvMKLDNNINT8
-            : kConvMKLDNNFP32;
     return framework::OpKernelType(input_data_type,
                                    ctx.GetPlace(),
                                    framework::DataLayout::kMKLDNN,
-                                   framework::LibraryType::kMKLDNN,
-                                   customized_type_value);
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
-  // #ifndef PADDLE_WITH_ASCEND_CL
-  //   if (input_data_type == framework::proto::VarType::FP16) {
-  //     PADDLE_ENFORCE_EQ(
-  //         library, framework::LibraryType::kCUDNN,
-  //         platform::errors::InvalidArgument(
-  //             "float16 can only be used when CUDNN or NPU is used"));
-  //   }
-  // #endif
   return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }
@@ -517,8 +499,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     return framework::OpKernelType(data_type,
                                    ctx.GetPlace(),
                                    framework::DataLayout::kMKLDNN,
-                                   framework::LibraryType::kMKLDNN,
-                                   kConvMKLDNNFP32);
+                                   framework::LibraryType::kMKLDNN);
 }
 #endif
...
@@ -20,6 +20,8 @@
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"

+#include "paddle/phi/core/visit_type.h"
+
 namespace paddle {
 namespace operators {
 namespace {
@@ -774,7 +776,22 @@ class ConvMKLDNNHandlerT
 }  // anonymous namespace

-template <typename T, typename K>
+#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...)                    \
+  [&] {                                                                   \
+    const auto& __dtype__ = TYPE;                                         \
+    switch (__dtype__) {                                                  \
+      PD_PRIVATE_CASE_TYPE(                                               \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)          \
+      PD_PRIVATE_CASE_TYPE(                                               \
+          NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__)            \
+      default:                                                            \
+        PD_THROW("function " #NAME " is not implemented for data type `", \
+                 __dtype__,                                               \
+                 "`");                                                    \
+    }                                                                     \
+  }()
+
+template <typename T>
 class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
@@ -828,7 +845,9 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
     auto* output = ctx.Output<phi::DenseTensor>("Output");

-    ConvMKLDNNHandlerT<T, K, T_out> handler(
+    PD_VISIT_FLOAT_AND_INT8_TYPES(
+        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
+          ConvMKLDNNHandlerT<T, data_t, T_out> handler(
               ctx,
               dev_ctx,
               mkldnn_engine,
@@ -861,7 +880,8 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
               {DNNL_ARG_DST, *dst_memory_p}};

           if (bias) {
-            auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
+            auto bias_memory_p =
+                handler.AcquireBiasMemoryWithReorder(bias, is_test);
             args.insert({DNNL_ARG_BIAS, *bias_memory_p});
           }
@@ -870,6 +890,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
           astream.wait();

           output->set_mem_desc(dst_memory_p->get_desc());
+        }));
   }

   template <typename T_out>
@@ -905,7 +926,9 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
     auto* output = ctx.Output<phi::DenseTensor>("Output");

-    ConvMKLDNNHandlerT<T, K, T_out> handler(
+    PD_VISIT_FLOAT_AND_INT8_TYPES(
+        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
+          ConvMKLDNNHandlerT<T, data_t, T_out> handler(
               ctx,
               dev_ctx,
               mkldnn_engine,
@@ -923,7 +946,8 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
           const bool is_multi_channel = scale_weights_data.size() > 1;
           const int& groups = ctx.Attr<int>("groups");
           int mask_reorder =
-              is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
+              is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
+                               : 0;
           auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
               filter, groups, false, true, scale_weights_data, mask_reorder);
@@ -959,17 +983,19 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
             std::vector<float> bias_scales;
             auto p_scales_tuple =
                 std::make_shared<std::tuple<float, std::vector<float>>>(
-                    std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
+                    std::make_tuple(static_cast<float>(mask_reorder),
+                                    bias_scales));
             if (ctx.HasAttr("Bias_scales")) {
               bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
               p_scales_tuple =
                   std::make_shared<std::tuple<float, std::vector<float>>>(
-                      std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
+                      std::make_tuple(static_cast<float>(mask_reorder),
+                                      bias_scales));
             } else {
               p_scales_tuple = handler.get_int8_bias_scales(ctx);
             }
-            auto bias_memory_p =
-                handler.AcquireBiasMemoryWithReorder(bias,
+            auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
+                bias,
                 true,
                 std::get<1>(*p_scales_tuple),
                 std::get<0>(*p_scales_tuple));
@@ -985,10 +1011,28 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
           }

           output->set_mem_desc(dst_memory_p->get_desc());
+        }));
   }
 };

-template <typename T, typename K>
+#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...)                    \
+  [&] {                                                                   \
+    const auto& __dtype__ = TYPE;                                         \
+    switch (__dtype__) {                                                  \
+      PD_PRIVATE_CASE_TYPE(                                               \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)          \
+      PD_PRIVATE_CASE_TYPE(NAME,                                          \
+                           ::paddle::DataType::BFLOAT16,                  \
+                           ::phi::dtype::bfloat16,                        \
+                           __VA_ARGS__)                                   \
+      default:                                                            \
+        PD_THROW("function " #NAME " is not implemented for data type `", \
+                 __dtype__,                                               \
+                 "`");                                                    \
+    }                                                                     \
+  }()
+
+template <typename T>
 class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
@@ -1013,8 +1057,10 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
     if (!input_grad && !filter_grad) return;

+    PD_VISIT_FLOAT_AND_BF16_TYPES(
+        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
           // TODO(jczaja): Are all tensors really needed?
-    ConvMKLDNNHandlerT<T, K, T> handler(
+          ConvMKLDNNHandlerT<T, data_t, T> handler(
               ctx,
               dev_ctx,
               ctx.GetPlace(),
@@ -1053,16 +1099,17 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
           astream.wait();

           // For convolution with groups convert from blocked to NCHW
-          // otherwise there will be problems in next operators working on this data
+          // otherwise there will be problems in next operators working on
+          // this data
           if (g > 1) {
-            // in OneDNN groups in convolution are treated as separate dimension
-            // which is not the case in paddlepaddle
+            // in OneDNN groups in convolution are treated as separate
+            // dimension which is not the case in paddlepaddle
             dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
                 framework::TransToProtoVarType(filter->dtype()));
-            // for 3d conv with groups (six dimensional data reorder to goidhw)
-            // for 2d conv with groups (five dimensional data reorder to goihw)
-            // auto weights_tz = phi::vectorize(filter->dims());
+            // for 3d conv with groups (six dimensional data reorder to
+            // goidhw) for 2d conv with groups (five dimensional data reorder
+            // to goihw) auto weights_tz = phi::vectorize(filter->dims());
             auto weights_tz = diff_weights_memory_p->get_desc().dims();

             dnnl::memory::format_tag out_format =
@@ -1073,11 +1120,11 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
                 framework::TransToProtoVarType(filter->dtype()),
                 in_type,
                 mkldnn_engine);
-            auto reorder_dst_memory_p =
-                handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());
+            auto reorder_dst_memory_p = handler.AcquireDstMemory(
+                filter_grad, out_format, ctx.GetPlace());

-            auto reorder_p =
-                handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p);
+            auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p,
+                                                    diff_weights_memory_p);
             {
               platform::RecordEvent record_reorder(
@@ -1090,14 +1137,14 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
               astream.wait();
             }

-            // So here we have a data in goihw , which can be interpreted as OIHW
-            // (OIDHW for conv3d)
-            // because filter_grad shape is set for OIHW (OIDHW for conv3d)
+            // So here we have a data in goihw , which can be interpreted as
+            // OIHW (OIDHW for conv3d) because filter_grad shape is set for
+            // OIHW (OIDHW for conv3d)
             dnnl::memory::format_tag target_format =
                 weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
                                        : dnnl::memory::format_tag::oihw;
-            filter_grad->set_mem_desc(
-                dnnl::memory::desc(phi::vectorize<int64_t>(filter_grad->dims()),
+            filter_grad->set_mem_desc(dnnl::memory::desc(
+                phi::vectorize<int64_t>(filter_grad->dims()),
                 in_type,
                 target_format));
           } else {
@@ -1126,6 +1173,7 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
             input_grad->set_mem_desc(diff_src_memory_p->get_desc());
           }
+        }));
   }
 };
@@ -1134,119 +1182,40 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;

-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    conv2d,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    U8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<uint8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    U8WS8,
-                                    ops::kConvMKLDNNINT8WS8,
-                                    ops::ConvMKLDNNOpKernel<uint8_t, int8_t>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    S8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<int8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    S8WS8,
-                                    ops::kConvMKLDNNINT8WS8,
-                                    ops::ConvMKLDNNOpKernel<int8_t, int8_t>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNGradOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    conv2d_grad,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16,
-                                paddle::platform::bfloat16>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    depthwise_conv2d,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    U8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<uint8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    S8,
-                                    ops::kConvMKLDNNINT8,
-                                    ops::ConvMKLDNNOpKernel<int8_t, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d_grad,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNGradOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
-    depthwise_conv2d_grad,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    BF16,
-    ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNOpKernel<float, float>);
-
-REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad,
-                                    MKLDNN,
-                                    ::paddle::platform::CPUPlace,
-                                    FP32,
-                                    ops::kConvMKLDNNFP32,
-                                    ops::ConvMKLDNNGradOpKernel<float, float>);
+REGISTER_OP_KERNEL(conv2d,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNOpKernel<float>,
+                   ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
+                   ops::ConvMKLDNNOpKernel<uint8_t>,
+                   ops::ConvMKLDNNOpKernel<int8_t>);
+
+REGISTER_OP_KERNEL(conv2d_grad,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNGradOpKernel<float>,
+                   ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
+
+REGISTER_OP_KERNEL(depthwise_conv2d,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNOpKernel<float>,
+                   ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
+                   ops::ConvMKLDNNOpKernel<uint8_t>,
+                   ops::ConvMKLDNNOpKernel<int8_t>);
+
+REGISTER_OP_KERNEL(depthwise_conv2d_grad,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNGradOpKernel<float>,
+                   ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
+
+REGISTER_OP_KERNEL(conv3d,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL(conv3d_grad,
+                   MKLDNN,
+                   ::paddle::platform::CPUPlace,
+                   ops::ConvMKLDNNGradOpKernel<float>);
...
@@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
 USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
 USE_OP_ITSELF(conv2d);
-USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
+USE_OP_DEVICE_KERNEL(conv2d, MKLDNN);

 namespace paddle {
 namespace operators {
...