未验证 提交 73461a7a 编写于 作者: Z Zeng Jinle 提交者: GitHub

Make OperatorWithKernel::InferShape abstract (#21633)

* make OperatorWithKernel::InferShape virtual, test=develop

* fix test_prepare_op by relu, test=develop
上级 686f0ecb
......@@ -488,9 +488,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual void InferShape(InferShapeContext* ctx) const {
Info().infer_shape_(ctx);
}
virtual void InferShape(InferShapeContext* ctx) const = 0;
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
......
......@@ -7,5 +7,5 @@ endif(WIN32)
cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator)
cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split assign_op place)
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
......@@ -103,15 +103,18 @@ TEST(test_prepare_op, test_prepare_op) {
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap split_attr_map;
const auto& info = framework::OpInfoMap::Instance().Get("split");
if (info.Checker()) info.Checker()->Check(&split_attr_map);
framework::VariableNameMap var_in_map =
CreateVarNameMap(info, "split", ins, true);
framework::VariableNameMap var_out_map =
CreateVarNameMap(info, "split", outs, false);
framework::OperatorWithKernel op("split", var_in_map, var_out_map,
auto op = framework::OpRegistry::CreateOp("split", var_in_map, var_out_map,
split_attr_map);
framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs, op, place, &split_attr_map));
ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op),
place, &split_attr_map));
}
const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
......@@ -147,19 +150,22 @@ TEST(test_prepare_op, test_prepare_data) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {x_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap assign_attr_map;
const auto& info = framework::OpInfoMap::Instance().Get("assign");
const std::string op_type = "relu";
framework::AttributeMap attr_map;
const auto& info = framework::OpInfoMap::Instance().Get(op_type);
if (info.Checker()) info.Checker()->Check(&attr_map);
framework::VariableNameMap var_in_map =
CreateVarNameMap(info, "assign", ins, true);
CreateVarNameMap(info, op_type, ins, true);
framework::VariableNameMap var_out_map =
CreateVarNameMap(info, "assign", outs, false);
framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map,
assign_attr_map);
CreateVarNameMap(info, op_type, outs, false);
auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map,
attr_map);
framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
// test if it can be transformed to GPU place
PreparedOp prepared_op =
PreparedOp::Prepare(ins, outs, assign_op, gpu_place, &assign_attr_map);
PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
&attr_map);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
......@@ -191,19 +197,23 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {x_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap assign_attr_map;
const auto& info = framework::OpInfoMap::Instance().Get("assign");
framework::AttributeMap attr_map;
const std::string op_type = "relu";
const auto& info = framework::OpInfoMap::Instance().Get(op_type);
if (info.Checker()) info.Checker()->Check(&attr_map);
framework::VariableNameMap var_in_map =
CreateVarNameMap(info, "assign", ins, true);
CreateVarNameMap(info, op_type, ins, true);
framework::VariableNameMap var_out_map =
CreateVarNameMap(info, "assign", outs, false);
framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map,
assign_attr_map);
CreateVarNameMap(info, op_type, outs, false);
auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map,
attr_map);
framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
// test if it never transfered on GPU place
PreparedOp prepared_op =
PreparedOp::Prepare(ins, outs, assign_op, cpu_place, &assign_attr_map);
PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
&attr_map);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
......@@ -215,4 +225,4 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
} // namespace paddle
USE_OP(split);
USE_OP(assign);
USE_OP(relu);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册