未验证 提交 7ee727a8 编写于 作者: C CtfGo 提交者: GitHub

cinn_launch_op: remove the check on extracting temporary variables (#36997)

cinn_launch_op: remove the check on extracting temporary variables
上级 09d407b0
...@@ -137,9 +137,7 @@ std::vector<std::string> SeperateTempVar( ...@@ -137,9 +137,7 @@ std::vector<std::string> SeperateTempVar(
[](const auto& name_view) { return std::string(name_view.data()); }); [](const auto& name_view) { return std::string(name_view.data()); });
auto exclude_fn = [&all_cinn_names](const auto& cinn_name) { auto exclude_fn = [&all_cinn_names](const auto& cinn_name) {
PADDLE_ENFORCE_EQ(all_cinn_names.erase(cinn_name), 1, all_cinn_names.erase(cinn_name);
platform::errors::NotFound(
"Variable(%s) not found in cinn scope", cinn_name));
}; };
std::for_each(input_cinn_names.begin(), input_cinn_names.end(), exclude_fn); std::for_each(input_cinn_names.begin(), input_cinn_names.end(), exclude_fn);
......
...@@ -178,7 +178,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -178,7 +178,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
tensor); tensor);
} }
VLOG(4) << "Prepare outnput argument-" << i << ":" VLOG(4) << "Prepare output argument-" << i << ":"
<< "name(" << var_name << "->" << cinn_name << "), " << "name(" << var_name << "->" << cinn_name << "), "
<< "tensor(type:" << tensor->type() << "," << "tensor(type:" << tensor->type() << ","
<< "dims:" << tensor->dims() << ")."; << "dims:" << tensor->dims() << ").";
...@@ -187,9 +187,12 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -187,9 +187,12 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
hold_buffers.emplace_back(std::move(buffer)); hold_buffers.emplace_back(std::move(buffer));
} }
// 3.3 Prepare temporary variables: Create a temporary scope // 3.3 Prepare internal or temporary variables: Create a temporary
// to keep temporary variables needed by compiled runtime program // scope to keep internal variables within graph or temporary
// in addition, they directly use the names from CinnScope. // variables needed by the compiled runtime program in addition.
// Here we directly use the names from CinnScope as Paddle variable
// names, because they will not be used outside the graph
// and should be destructed after computation finished.
auto temp_variable_names = details::SeperateTempVar( auto temp_variable_names = details::SeperateTempVar(
cinn_scope, input_cinn_names, output_cinn_names); cinn_scope, input_cinn_names, output_cinn_names);
auto temp_scope = scope.NewTmpScope(); auto temp_scope = scope.NewTmpScope();
......
...@@ -263,9 +263,6 @@ TEST(CinnLaunchOpHelperTest, TestSeperateTempVar) { ...@@ -263,9 +263,6 @@ TEST(CinnLaunchOpHelperTest, TestSeperateTempVar) {
SeperateTempVar(cinn_scope, {"cinn_var1", "cinn_var2"}, {"cinn_var4"}); SeperateTempVar(cinn_scope, {"cinn_var1", "cinn_var2"}, {"cinn_var4"});
ASSERT_EQ(temp_names.size(), 1); ASSERT_EQ(temp_names.size(), 1);
EXPECT_EQ(temp_names.front(), "cinn_var3"); EXPECT_EQ(temp_names.front(), "cinn_var3");
ASSERT_THROW(
SeperateTempVar(cinn_scope, {"cinn_var1", "not_exist"}, {"cinn_var4"}),
paddle::platform::EnforceNotMet);
} }
TEST(CinnLaunchOpHelperTest, TestShareTensorWithCinnBuffer) { TEST(CinnLaunchOpHelperTest, TestShareTensorWithCinnBuffer) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册