From 73461a7ae63bdab5ed04ce6bf204c0b8a169e31a Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 11 Dec 2019 03:55:13 -0600 Subject: [PATCH] Make OperatorWithKernel::InferShape abstract (#21633) * make OperatorWithKernel::InferShape virtual, test=develop * fix test_prepare_op by relu, test=develop --- paddle/fluid/framework/operator.h | 4 +- paddle/fluid/imperative/tests/CMakeLists.txt | 2 +- .../fluid/imperative/tests/test_prepare_op.cc | 50 +++++++++++-------- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 18e56f25b5..f30620a0a7 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -488,9 +488,7 @@ class OperatorWithKernel : public OperatorBase { }); } - virtual void InferShape(InferShapeContext* ctx) const { - Info().infer_shape_(ctx); - } + virtual void InferShape(InferShapeContext* ctx) const = 0; void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const override; diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index 67e6294f8b..a128a774d0 100644 --- a/paddle/fluid/imperative/tests/CMakeLists.txt +++ b/paddle/fluid/imperative/tests/CMakeLists.txt @@ -7,5 +7,5 @@ endif(WIN32) cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator) cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy) -cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split assign_op place) +cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index 48065eafa7..4304376a9e 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -103,15 +103,18 @@ TEST(test_prepare_op, test_prepare_op) { imperative::NameVarBaseMap outs = {out_pair}; framework::AttributeMap split_attr_map; const auto& info = framework::OpInfoMap::Instance().Get("split"); + if (info.Checker()) info.Checker()->Check(&split_attr_map); framework::VariableNameMap var_in_map = CreateVarNameMap(info, "split", ins, true); framework::VariableNameMap var_out_map = CreateVarNameMap(info, "split", outs, false); - framework::OperatorWithKernel op("split", var_in_map, var_out_map, - split_attr_map); + auto op = framework::OpRegistry::CreateOp("split", var_in_map, var_out_map, + split_attr_map); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( - ins, outs, op, place, &split_attr_map)); + ins, outs, + dynamic_cast(*op), + place, &split_attr_map)); } const framework::Tensor* GetTensorFromVar(const framework::Variable& var); @@ -147,19 +150,22 @@ TEST(test_prepare_op, test_prepare_data) { var_pair out_pair = var_pair("Out", vb_vector(1, vout)); imperative::NameVarBaseMap ins = {x_pair}; imperative::NameVarBaseMap outs = {out_pair}; - framework::AttributeMap assign_attr_map; - const auto& info = framework::OpInfoMap::Instance().Get("assign"); + const std::string op_type = "relu"; + framework::AttributeMap attr_map; + const auto& info = framework::OpInfoMap::Instance().Get(op_type); + if (info.Checker()) info.Checker()->Check(&attr_map); framework::VariableNameMap var_in_map = - CreateVarNameMap(info, "assign", ins, true); + CreateVarNameMap(info, op_type, ins, true); framework::VariableNameMap var_out_map = - CreateVarNameMap(info, "assign", outs, false); - framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map, - assign_attr_map); + CreateVarNameMap(info, op_type, outs, false); + auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map, + attr_map); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); // test if it can be transformed to GPU place - PreparedOp prepared_op = - PreparedOp::Prepare(ins, outs, assign_op, gpu_place, &assign_attr_map); + PreparedOp prepared_op = PreparedOp::Prepare( + ins, outs, dynamic_cast(*op), gpu_place, + &attr_map); for (const auto& name_pair : ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( @@ -191,19 +197,23 @@ TEST(test_prepare_op, test_prepare_data_same_place) { var_pair out_pair = var_pair("Out", vb_vector(1, vout)); imperative::NameVarBaseMap ins = {x_pair}; imperative::NameVarBaseMap outs = {out_pair}; - framework::AttributeMap assign_attr_map; - const auto& info = framework::OpInfoMap::Instance().Get("assign"); + framework::AttributeMap attr_map; + const std::string op_type = "relu"; + const auto& info = framework::OpInfoMap::Instance().Get(op_type); + if (info.Checker()) info.Checker()->Check(&attr_map); framework::VariableNameMap var_in_map = - CreateVarNameMap(info, "assign", ins, true); + CreateVarNameMap(info, op_type, ins, true); framework::VariableNameMap var_out_map = - CreateVarNameMap(info, "assign", outs, false); - framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map, - assign_attr_map); + CreateVarNameMap(info, op_type, outs, false); + + auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map, + attr_map); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); // test if it never transfered on GPU place - PreparedOp prepared_op = - PreparedOp::Prepare(ins, outs, assign_op, cpu_place, &assign_attr_map); + PreparedOp prepared_op = PreparedOp::Prepare( + ins, outs, dynamic_cast(*op), cpu_place, + &attr_map); for (const auto& name_pair : ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( @@ -215,4 +225,4 @@ TEST(test_prepare_op, test_prepare_data_same_place) { } // namespace paddle USE_OP(split); -USE_OP(assign); +USE_OP(relu); -- GitLab