未验证 提交 73461a7a 编写于 作者: Z Zeng Jinle 提交者: GitHub

Make OperatorWithKernel::InferShape abstract (#21633)

* make OperatorWithKernel::InferShape virtual, test=develop

* fix test_prepare_op by relu, test=develop
上级 686f0ecb
...@@ -488,9 +488,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -488,9 +488,7 @@ class OperatorWithKernel : public OperatorBase {
}); });
} }
virtual void InferShape(InferShapeContext* ctx) const { virtual void InferShape(InferShapeContext* ctx) const = 0;
Info().infer_shape_(ctx);
}
void RuntimeInferShape(const Scope& scope, const platform::Place& place, void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override; const RuntimeContext& ctx) const override;
......
...@@ -7,5 +7,5 @@ endif(WIN32) ...@@ -7,5 +7,5 @@ endif(WIN32)
cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator) cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator)
cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy) cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split assign_op place) cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
...@@ -103,15 +103,18 @@ TEST(test_prepare_op, test_prepare_op) { ...@@ -103,15 +103,18 @@ TEST(test_prepare_op, test_prepare_op) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap split_attr_map; framework::AttributeMap split_attr_map;
const auto& info = framework::OpInfoMap::Instance().Get("split"); const auto& info = framework::OpInfoMap::Instance().Get("split");
if (info.Checker()) info.Checker()->Check(&split_attr_map);
framework::VariableNameMap var_in_map = framework::VariableNameMap var_in_map =
CreateVarNameMap(info, "split", ins, true); CreateVarNameMap(info, "split", ins, true);
framework::VariableNameMap var_out_map = framework::VariableNameMap var_out_map =
CreateVarNameMap(info, "split", outs, false); CreateVarNameMap(info, "split", outs, false);
framework::OperatorWithKernel op("split", var_in_map, var_out_map, auto op = framework::OpRegistry::CreateOp("split", var_in_map, var_out_map,
split_attr_map); split_attr_map);
framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs, op, place, &split_attr_map)); ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op),
place, &split_attr_map));
} }
const framework::Tensor* GetTensorFromVar(const framework::Variable& var); const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
...@@ -147,19 +150,22 @@ TEST(test_prepare_op, test_prepare_data) { ...@@ -147,19 +150,22 @@ TEST(test_prepare_op, test_prepare_data) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout)); var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {x_pair}; imperative::NameVarBaseMap ins = {x_pair};
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap assign_attr_map; const std::string op_type = "relu";
const auto& info = framework::OpInfoMap::Instance().Get("assign"); framework::AttributeMap attr_map;
const auto& info = framework::OpInfoMap::Instance().Get(op_type);
if (info.Checker()) info.Checker()->Check(&attr_map);
framework::VariableNameMap var_in_map = framework::VariableNameMap var_in_map =
CreateVarNameMap(info, "assign", ins, true); CreateVarNameMap(info, op_type, ins, true);
framework::VariableNameMap var_out_map = framework::VariableNameMap var_out_map =
CreateVarNameMap(info, "assign", outs, false); CreateVarNameMap(info, op_type, outs, false);
framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map, auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map,
assign_attr_map); attr_map);
framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
// test if it can be transformed to GPU place // test if it can be transformed to GPU place
PreparedOp prepared_op = PreparedOp prepared_op = PreparedOp::Prepare(
PreparedOp::Prepare(ins, outs, assign_op, gpu_place, &assign_attr_map); ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
&attr_map);
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) { for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place( ASSERT_TRUE(platform::is_same_place(
...@@ -191,19 +197,23 @@ TEST(test_prepare_op, test_prepare_data_same_place) { ...@@ -191,19 +197,23 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout)); var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {x_pair}; imperative::NameVarBaseMap ins = {x_pair};
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap assign_attr_map; framework::AttributeMap attr_map;
const auto& info = framework::OpInfoMap::Instance().Get("assign"); const std::string op_type = "relu";
const auto& info = framework::OpInfoMap::Instance().Get(op_type);
if (info.Checker()) info.Checker()->Check(&attr_map);
framework::VariableNameMap var_in_map = framework::VariableNameMap var_in_map =
CreateVarNameMap(info, "assign", ins, true); CreateVarNameMap(info, op_type, ins, true);
framework::VariableNameMap var_out_map = framework::VariableNameMap var_out_map =
CreateVarNameMap(info, "assign", outs, false); CreateVarNameMap(info, op_type, outs, false);
framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map,
assign_attr_map); auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map,
attr_map);
framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
// test if it never transfered on GPU place // test if it never transfered on GPU place
PreparedOp prepared_op = PreparedOp prepared_op = PreparedOp::Prepare(
PreparedOp::Prepare(ins, outs, assign_op, cpu_place, &assign_attr_map); ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
&attr_map);
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) { for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place( ASSERT_TRUE(platform::is_same_place(
...@@ -215,4 +225,4 @@ TEST(test_prepare_op, test_prepare_data_same_place) { ...@@ -215,4 +225,4 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
} // namespace paddle } // namespace paddle
USE_OP(split); USE_OP(split);
USE_OP(assign); USE_OP(relu);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册