未验证 提交 981d1a10 编写于 作者: Z zhangbo9674 提交者: GitHub

Refine shape op lanch method for standalone executor (#47843)

* refine shape op in new_exe

* Revert "refine shape op in new_exe"

This reverts commit 0e0336ddc5eede3da019b348a0bcc0ef0f3be64e.

* refine shape op in new_exe

* refine shape expected_kernel_type

* add SelectedRows check for shape op

* refine code
上级 9a6465ca
...@@ -321,6 +321,9 @@ OpFuncType AnalyseOpFuncType(const OpFuncNode& op_func_node, ...@@ -321,6 +321,9 @@ OpFuncType AnalyseOpFuncType(const OpFuncNode& op_func_node,
return OpFuncType::kQueueSync; return OpFuncType::kQueueSync;
} }
if (op->Type() == "shape") {
return OpFuncType::kQueueSync;
}
return OpFuncType::kQueueAsync; return OpFuncType::kQueueAsync;
} }
......
...@@ -90,6 +90,13 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds( ...@@ -90,6 +90,13 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
return false; return false;
}; };
auto is_shape_op = [](std::string op_name) {
if (op_name == "shape") {
return true;
}
return false;
};
bool is_memcpy = bool is_memcpy =
interpreter::IsMemcpyOp(cur_instr) || interpreter::IsMemcpyOp(next_instr); interpreter::IsMemcpyOp(cur_instr) || interpreter::IsMemcpyOp(next_instr);
...@@ -97,7 +104,7 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds( ...@@ -97,7 +104,7 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
for (auto& item : next_instr.Inputs()) { for (auto& item : next_instr.Inputs()) {
for (auto var_id : item.second) { for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0) { if (unique_var_ids.count(var_id) > 0) {
if (is_memcpy) { if (is_memcpy || is_shape_op(next_instr.OpBase()->Type())) {
if (next_instr.NoDataTransformVars().count(var_id)) { if (next_instr.NoDataTransformVars().count(var_id)) {
VLOG(4) << "Skip inserting event at variable " << item.first VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type() << " of operator " << next_instr.OpBase()->Type()
...@@ -239,6 +246,11 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -239,6 +246,11 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
*/ */
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) { const Instruction& next_instr) {
if ((cur_instr.KernelType() == OpFuncType::kQueueSync) &&
(next_instr.KernelType() == OpFuncType::kQueueSync)) {
return true;
}
if (cur_instr.KernelType() == next_instr.KernelType() && if (cur_instr.KernelType() == next_instr.KernelType() &&
(&cur_instr.DeviceContext() == &next_instr.DeviceContext())) { (&cur_instr.DeviceContext() == &next_instr.DeviceContext())) {
return true; return true;
......
...@@ -2553,13 +2553,6 @@ void OperatorWithKernel::ParseInputDataType( ...@@ -2553,13 +2553,6 @@ void OperatorWithKernel::ParseInputDataType(
} }
} }
if (t != nullptr) { if (t != nullptr) {
PADDLE_ENFORCE_EQ(t->IsInitialized(),
true,
platform::errors::InvalidArgument(
"The %s Op's Input Variable `%s` "
"contains uninitialized phi::DenseTensor.",
Type(),
name));
*data_type = paddle::framework::TransToProtoVarType(t->dtype()); *data_type = paddle::framework::TransToProtoVarType(t->dtype());
} }
} }
......
...@@ -429,61 +429,6 @@ REGISTER_OP_CPU_KERNEL( ...@@ -429,61 +429,6 @@ REGISTER_OP_CPU_KERNEL(
indicate_other_data_type_test, indicate_other_data_type_test,
paddle::framework::EmptyTestKernel<phi::CPUContext, int>); paddle::framework::EmptyTestKernel<phi::CPUContext, int>);
TEST(IndicateVarDataTypeTest, lodtensor) {
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_lod_tensor_data_type_test");
BuildVar("phi::DenseTensor", {"lodtensor_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lodtensor_1");
var->GetMutable<phi::DenseTensor>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find(
"The indicate_lod_tensor_data_type_test Op's Input Variable "
"`phi::DenseTensor` contains uninitialized phi::DenseTensor.") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, selectedrows) {
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_selected_rows_data_type_test");
BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("selected_rows_1");
var->GetMutable<phi::SelectedRows>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find("The indicate_selected_rows_data_type_test Op's "
"Input Variable `SelectedRows` contains uninitialized "
"phi::DenseTensor.") != std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, other) { TEST(IndicateVarDataTypeTest, other) {
paddle::framework::InitDevices(); paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc; paddle::framework::proto::OpDesc op_desc;
......
...@@ -60,6 +60,8 @@ Return the shape of the input. ...@@ -60,6 +60,8 @@ Return the shape of the input.
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ShapeNoNeedBufferVarsInferer, "Input");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -76,4 +78,5 @@ REGISTER_OPERATOR( ...@@ -76,4 +78,5 @@ REGISTER_OPERATOR(
ops::ShapeOpMaker, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::ShapeNoNeedBufferVarsInferer,
ShapeInferShapeFunctor); ShapeInferShapeFunctor);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册