diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 8ba66eb4da36fa47a716db918d8582eb9d01541e..ab14732e4d6eab9dd15364da02b436c10ed68a19 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -97,6 +97,8 @@ TEST(OperatorBase, all) { namespace paddle { namespace framework { +static int special_type_value = 1; + class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: void Make() { @@ -211,15 +213,12 @@ REGISTER_OP_WITHOUT_GRADIENT( op_with_kernel, paddle::framework::OpWithKernelTest, paddle::framework::OpKernelTestProtoAndCheckerMaker); -// REGISTER_OP_CPU_KERNEL(op_with_kernel, -// paddle::framework::CPUKernelTest); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - op_with_kernel, CPU, paddle::platform::CPUPlace, DEFAULT_TYPE, 0, - paddle::framework::CPUKernelTest); +REGISTER_OP_CPU_KERNEL(op_with_kernel, + paddle::framework::CPUKernelTest); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - op_with_kernel, CPU, paddle::platform::CPUPlace, SPECIAL, 1, + op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME, + paddle::framework::special_type_value, paddle::framework::CPUKernel2Test); // test with single input @@ -241,6 +240,7 @@ TEST(OpKernel, all) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_place); + // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called. ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0); @@ -250,6 +250,7 @@ TEST(OpKernel, all) { attr->set_i(1); auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc); op2->Run(scope, cpu_place); + // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called. ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1); } diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 9a5dc740348d19484ced1b50c56f81de78f2048b..7455b9492f054b32ee7fb1fc90b1a344367ceb81 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -345,6 +345,8 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready std::string data_format = ctx.Attr("data_format"); @@ -360,12 +362,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; + customized_type_value = kConvMKLDNNFP32; } #endif return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_); + layout_, library_, customized_type_value); } } // namespace operators