未验证 提交 7ee727a8 编写于 作者: C CtfGo 提交者: GitHub

cinn_launch_op: remove the check on extracting temporary variables (#36997)

cinn_launch_op: remove the check on extracting temporary variables
上级 09d407b0
......@@ -137,9 +137,7 @@ std::vector<std::string> SeperateTempVar(
[](const auto& name_view) { return std::string(name_view.data()); });
auto exclude_fn = [&all_cinn_names](const auto& cinn_name) {
PADDLE_ENFORCE_EQ(all_cinn_names.erase(cinn_name), 1,
platform::errors::NotFound(
"Variable(%s) not found in cinn scope", cinn_name));
all_cinn_names.erase(cinn_name);
};
std::for_each(input_cinn_names.begin(), input_cinn_names.end(), exclude_fn);
......
......@@ -178,7 +178,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
tensor);
}
VLOG(4) << "Prepare outnput argument-" << i << ":"
VLOG(4) << "Prepare output argument-" << i << ":"
<< "name(" << var_name << "->" << cinn_name << "), "
<< "tensor(type:" << tensor->type() << ","
<< "dims:" << tensor->dims() << ").";
......@@ -187,9 +187,12 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
hold_buffers.emplace_back(std::move(buffer));
}
// 3.3 Prepare temporary variables: Create a temporary scope
// to keep temporary variables needed by compiled runtime program
// in addition, they directly use the names from CinnScope.
// 3.3 Prepare internal or temporary variables: Create a temporary
// scope to keep internal variables within graph or temporary
// variables needed by the compiled runtime program in addition.
// Here we directly use the names from CinnScope as Paddle variable
// names, because they will not be used outside the graph
// and should be destructed after computation finished.
auto temp_variable_names = details::SeperateTempVar(
cinn_scope, input_cinn_names, output_cinn_names);
auto temp_scope = scope.NewTmpScope();
......
......@@ -263,9 +263,6 @@ TEST(CinnLaunchOpHelperTest, TestSeperateTempVar) {
SeperateTempVar(cinn_scope, {"cinn_var1", "cinn_var2"}, {"cinn_var4"});
ASSERT_EQ(temp_names.size(), 1);
EXPECT_EQ(temp_names.front(), "cinn_var3");
ASSERT_THROW(
SeperateTempVar(cinn_scope, {"cinn_var1", "not_exist"}, {"cinn_var4"}),
paddle::platform::EnforceNotMet);
}
TEST(CinnLaunchOpHelperTest, TestShareTensorWithCinnBuffer) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册