未验证 提交 e9288340 编写于 作者: A Adam Osewski 提交者: GitHub

[OneDNN] Conv op refactor. (#36252)

* Remove unused header.

* Use ConvMKLDNNHandlerT for conv2d INT8.

* Use absolute module path to import.
上级 dc4d5719
...@@ -23,7 +23,6 @@ limitations under the License. */ ...@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -531,7 +531,13 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) { ...@@ -531,7 +531,13 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) { inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32"; return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
} }
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP }; enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
template <typename T>
bool constexpr is_int8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient from paddle.fluid.tests.unittests.op_test import (
OpTest, convert_float_to_uint16, get_numeric_gradient)
from paddle.fluid.tests.unittests.testsuite import create_op from paddle.fluid.tests.unittests.testsuite import create_op
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
......
...@@ -22,7 +22,7 @@ import paddle.nn as nn ...@@ -22,7 +22,7 @@ import paddle.nn as nn
paddle.enable_static() paddle.enable_static()
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
def conv2dtranspose_forward_naive(input_, filter_, attrs): def conv2dtranspose_forward_naive(input_, filter_, attrs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册