未验证 提交 d78dd7ea 编写于 作者: H HongyuJia 提交者: GitHub

[MKLDNN] Delete mkldnn hard code of prior_box (#47068)

* remove prior_box mkldnn hard code

* add header file

* simplify PD_VISIT_TYPE

* decouple dependency between prior_box and density_prior_box

* fix pragma omp parallel error

* bypass #pragma omp_parallel_for error

* polish code

* remove visit_type headerfile

* polish codestyle

* polish codestyle

* try fix CI error

* add testcase with datatype=float64

* reset test_prior_box testcase

* add dtype check to DenseTensor

* update template name

* call prior_box with macro expand
上级 40ce7f4a
...@@ -34,26 +34,6 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -34,26 +34,6 @@ class PriorBoxOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_input_type = auto input_input_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input"); OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_input_type)) {
auto input_image_type = framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Image")->dtype());
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_image_type == framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kPriorBoxFLOAT;
} else if (input_image_type ==
framework::DataTypeTrait<double>::DataType()) {
customized_type_value = kPriorBoxDOUBLE;
}
return framework::OpKernelType(input_input_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(input_input_type, ctx.GetPlace()); return framework::OpKernelType(input_input_type, ctx.GetPlace());
} }
}; };
...@@ -209,44 +189,10 @@ REGISTER_OPERATOR( ...@@ -209,44 +189,10 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
PriorBoxInferShapeFunctor); PriorBoxInferShapeFunctor);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, REGISTER_OP_KERNEL(prior_box,
MKLDNN,
::paddle::platform::CPUPlace,
FF,
ops::kPriorBoxFLOAT,
ops::PriorBoxOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box,
MKLDNN,
::paddle::platform::CPUPlace,
DD,
ops::kPriorBoxDOUBLE,
ops::PriorBoxOpKernel<double, double>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box,
MKLDNN,
::paddle::platform::CPUPlace,
U8F,
ops::kPriorBoxFLOAT,
ops::PriorBoxOpKernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box,
MKLDNN,
::paddle::platform::CPUPlace,
S8F,
ops::kPriorBoxFLOAT,
ops::PriorBoxOpKernel<int8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box,
MKLDNN,
::paddle::platform::CPUPlace,
U8D,
ops::kPriorBoxDOUBLE,
ops::PriorBoxOpKernel<uint8_t, double>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box,
MKLDNN, MKLDNN,
::paddle::platform::CPUPlace, ::paddle::platform::CPUPlace,
S8D, ops::PriorBoxOpKernel<float>,
ops::kPriorBoxDOUBLE, ops::PriorBoxOpKernel<double>,
ops::PriorBoxOpKernel<int8_t, double>); ops::PriorBoxOpKernel<uint8_t>,
ops::PriorBoxOpKernel<int8_t>);
...@@ -18,14 +18,12 @@ limitations under the License. */ ...@@ -18,14 +18,12 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr int kPriorBoxFLOAT = 1;
constexpr int kPriorBoxDOUBLE = 2;
inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior, inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
bool flip, bool flip,
std::vector<float>* output_aspect_ratior) { std::vector<float>* output_aspect_ratior) {
...@@ -50,10 +48,19 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior, ...@@ -50,10 +48,19 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
} }
} }
template <typename T, typename K> template <typename T>
class PriorBoxOpKernel : public framework::OpKernel<T> { class PriorBoxOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* image = ctx.Input<phi::DenseTensor>("Image");
PD_VISIT_FLOATING_TYPES(image->dtype(), "PriorBoxOpHandler", ([&] {
PriorBoxOpHandler<data_t>(ctx);
}));
}
template <typename K>
void PriorBoxOpHandler(const framework::ExecutionContext& ctx) const {
auto* input = ctx.Input<phi::DenseTensor>("Input"); auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* image = ctx.Input<phi::DenseTensor>("Image"); auto* image = ctx.Input<phi::DenseTensor>("Image");
auto* boxes = ctx.Output<phi::DenseTensor>("Boxes"); auto* boxes = ctx.Output<phi::DenseTensor>("Boxes");
...@@ -200,7 +207,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -200,7 +207,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
} }
vars->Resize(var_dim); vars->Resize(var_dim);
} }
}; // namespace operators };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -54,10 +54,7 @@ static const std::unordered_set<std::string> mkldnn_white_list = { ...@@ -54,10 +54,7 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
"flatten", "flatten",
"flatten_grad", "flatten_grad",
"flatten2", "flatten2",
"flatten2_grad", "flatten2_grad"};
// NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
// need to analysis and solve them one-by-one.
"prior_box"};
inline bool in_mkldnn_white_list(const std::string& op_name) { inline bool in_mkldnn_white_list(const std::string& op_name) {
return mkldnn_white_list.find(op_name) != mkldnn_white_list.end(); return mkldnn_white_list.find(op_name) != mkldnn_white_list.end();
......
...@@ -139,8 +139,10 @@ const T* DenseTensor::data() const { ...@@ -139,8 +139,10 @@ const T* DenseTensor::data() const {
dtype(), dtype(),
paddle::experimental::CppTypeToDataType<T>::Type(), paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The type of data we are trying to retrieve does not match the " "The type of data we are trying to retrieve (%s) does not match the "
"type of data currently contained in the container.")); "type of data (%s) currently contained in the container.",
paddle::experimental::CppTypeToDataType<T>::Type(),
dtype()));
return static_cast<const T*>(data()); return static_cast<const T*>(data());
} }
...@@ -150,8 +152,10 @@ T* DenseTensor::data() { ...@@ -150,8 +152,10 @@ T* DenseTensor::data() {
PADDLE_ENFORCE( PADDLE_ENFORCE(
(dtype() == paddle::experimental::CppTypeToDataType<T>::Type()), (dtype() == paddle::experimental::CppTypeToDataType<T>::Type()),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The type of data we are trying to retrieve does not match the " "The type of data we are trying to retrieve (%s) does not match the "
"type of data currently contained in the container.")); "type of data (%s) currently contained in the container.",
paddle::experimental::CppTypeToDataType<T>::Type(),
dtype()));
return ret; return ret;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册