未验证 提交 981d1a10 编写于 作者: Z zhangbo9674 提交者: GitHub

Refine shape op lanch method for standalone executor (#47843)

* refine shape op in new_exe

* Revert "refine shape op in new_exe"

This reverts commit 0e0336ddc5eede3da019b348a0bcc0ef0f3be64e.

* refine shape op in new_exe

* refine shape expected_kernel_type

* add SelectedRows check for shape op

* refine code
上级 9a6465ca
......@@ -321,6 +321,9 @@ OpFuncType AnalyseOpFuncType(const OpFuncNode& op_func_node,
return OpFuncType::kQueueSync;
}
if (op->Type() == "shape") {
return OpFuncType::kQueueSync;
}
return OpFuncType::kQueueAsync;
}
......
......@@ -90,6 +90,13 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
return false;
};
auto is_shape_op = [](std::string op_name) {
if (op_name == "shape") {
return true;
}
return false;
};
bool is_memcpy =
interpreter::IsMemcpyOp(cur_instr) || interpreter::IsMemcpyOp(next_instr);
......@@ -97,7 +104,7 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
for (auto& item : next_instr.Inputs()) {
for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0) {
if (is_memcpy) {
if (is_memcpy || is_shape_op(next_instr.OpBase()->Type())) {
if (next_instr.NoDataTransformVars().count(var_id)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
......@@ -239,6 +246,11 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
*/
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
if ((cur_instr.KernelType() == OpFuncType::kQueueSync) &&
(next_instr.KernelType() == OpFuncType::kQueueSync)) {
return true;
}
if (cur_instr.KernelType() == next_instr.KernelType() &&
(&cur_instr.DeviceContext() == &next_instr.DeviceContext())) {
return true;
......
......@@ -2553,13 +2553,6 @@ void OperatorWithKernel::ParseInputDataType(
}
}
if (t != nullptr) {
PADDLE_ENFORCE_EQ(t->IsInitialized(),
true,
platform::errors::InvalidArgument(
"The %s Op's Input Variable `%s` "
"contains uninitialized phi::DenseTensor.",
Type(),
name));
*data_type = paddle::framework::TransToProtoVarType(t->dtype());
}
}
......
......@@ -429,61 +429,6 @@ REGISTER_OP_CPU_KERNEL(
indicate_other_data_type_test,
paddle::framework::EmptyTestKernel<phi::CPUContext, int>);
TEST(IndicateVarDataTypeTest, lodtensor) {
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_lod_tensor_data_type_test");
BuildVar("phi::DenseTensor", {"lodtensor_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lodtensor_1");
var->GetMutable<phi::DenseTensor>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find(
"The indicate_lod_tensor_data_type_test Op's Input Variable "
"`phi::DenseTensor` contains uninitialized phi::DenseTensor.") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, selectedrows) {
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_selected_rows_data_type_test");
BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("selected_rows_1");
var->GetMutable<phi::SelectedRows>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find("The indicate_selected_rows_data_type_test Op's "
"Input Variable `SelectedRows` contains uninitialized "
"phi::DenseTensor.") != std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, other) {
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
......
......@@ -60,6 +60,8 @@ Return the shape of the input.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ShapeNoNeedBufferVarsInferer, "Input");
} // namespace operators
} // namespace paddle
......@@ -76,4 +78,5 @@ REGISTER_OPERATOR(
ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::ShapeNoNeedBufferVarsInferer,
ShapeInferShapeFunctor);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册