From 82d68281c04f4fd1078cb7e8be72bd536b9366be Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 5 Dec 2018 15:13:56 +0800 Subject: [PATCH] follow comments test=develop --- paddle/fluid/framework/operator_test.cc | 15 ++++++++------- paddle/fluid/operators/conv_op.cc | 5 ++++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 8ba66eb4d..ab14732e4 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -97,6 +97,8 @@ TEST(OperatorBase, all) { namespace paddle { namespace framework { +static int special_type_value = 1; + class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: void Make() { @@ -211,15 +213,12 @@ REGISTER_OP_WITHOUT_GRADIENT( op_with_kernel, paddle::framework::OpWithKernelTest, paddle::framework::OpKernelTestProtoAndCheckerMaker); -// REGISTER_OP_CPU_KERNEL(op_with_kernel, -// paddle::framework::CPUKernelTest); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - op_with_kernel, CPU, paddle::platform::CPUPlace, DEFAULT_TYPE, 0, - paddle::framework::CPUKernelTest); +REGISTER_OP_CPU_KERNEL(op_with_kernel, + paddle::framework::CPUKernelTest); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - op_with_kernel, CPU, paddle::platform::CPUPlace, SPECIAL, 1, + op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME, + paddle::framework::special_type_value, paddle::framework::CPUKernel2Test); // test with single input @@ -241,6 +240,7 @@ TEST(OpKernel, all) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_place); + // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called. ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0); @@ -250,6 +250,7 @@ TEST(OpKernel, all) { attr->set_i(1); auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc); op2->Run(scope, cpu_place); + // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called. ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1); } diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 9a5dc7403..7455b9492 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -345,6 +345,8 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready std::string data_format = ctx.Attr("data_format"); @@ -360,12 +362,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; + customized_type_value = kConvMKLDNNFP32; } #endif return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_); + layout_, library_, customized_type_value); } } // namespace operators -- GitLab