Unverified commit 26c824db authored by HongyuJia, committed by GitHub

[0D-Tensor] Support elementwise_add (#53955)

* [0D-Tensor] Support elementwise_add

* support elementwise_add ZeroDim2&3
Parent: 6fde2056
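For orientation, a minimal sketch of the user-visible semantics this commit targets, using the public Paddle API (eager mode; the commit itself wires up the CINN-compiled path):

    import paddle

    # A 0D-Tensor has an empty shape [] and represents a scalar.
    x = paddle.full([], 2.0)
    y = paddle.full([], 3.0)

    out = paddle.add(x, y)  # elementwise_add on two 0D-Tensors
    print(out.shape)        # [] -- the result stays 0D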
@@ -94,13 +94,8 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
       feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
     }
-    PADDLE_ENFORCE_NE(
-        feed_map[feed_name].shape.size(),
-        0UL,
-        platform::errors::PreconditionNotMet(
-            "The input variable %s's tensor shape cannot be empty,"
-            "we need the variable's dtype and shape from tensor.",
-            feed_name.c_str()));
+    VLOG_IF(4, feed_map[feed_name].shape.size() == 0UL)
+        << "Shape is empty, create 0D-Tensor for " << feed_name;
   }
   return feed_map;
 }
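The hunk above relaxes the feed-shape guard: an empty shape no longer aborts symbolization but is logged and carried through as a 0D-Tensor. A minimal sketch of the resulting behavior (Python; hypothetical helper names, the real code fills a C++ FeedInfoMap):

    def register_feed(feed_map, feed_name, tensor_shape):
        # Record the feed's shape; an empty list now means a 0D-Tensor
        # rather than an error.
        feed_map[feed_name] = {"shape": tensor_shape}
        if len(tensor_shape) == 0:
            print(f"Shape is empty, create 0D-Tensor for {feed_name}")  # ~VLOG(4)
        return feed_map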
@@ -58,6 +58,40 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
     }
   }
+  // CINN ops in this white list support 0D-Tensor
+  const std::unordered_set<std::string> white_op_list{"elementwise_add"};
+  std::unordered_set<std::string> white_tensor_name;
+  // enable white_op_list only when graph_node_size == 1, which indicates a
+  // single-op test graph
+  int graph_node_size = 0;
+  for (const ir::Node* n : graph->Nodes()) {
+    if (n->IsOp()) {
+      graph_node_size++;
+      VLOG(6) << "Graph has op node " << n->Op()->Type();
+      if (white_op_list.find(n->Op()->Type()) != white_op_list.end()) {
+        for (const ir::Node* var : n->inputs) {
+          white_tensor_name.insert(var->Var()->Name());
+          std::vector<int64_t> shape = var->Var()->GetShape();
+          if (shape.empty()) {
+            VLOG(6) << "input var " << var->Name()
+                    << " dims is empty, keep its 0D-Tensor status";
+          }
+        }
+        for (const ir::Node* var : n->outputs) {
+          white_tensor_name.insert(var->Var()->Name());
+          std::vector<int64_t> shape = var->Var()->GetShape();
+          if (shape.empty()) {
+            VLOG(6) << "output var " << var->Name()
+                    << " dims is empty, keep its 0D-Tensor status";
+          }
+        }
+      }
+    }
+  }
+  VLOG(6) << "Graph has " << graph_node_size << " op nodes";
   for (const ir::Node* n : graph->Nodes()) {
     if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) {
       if (n->Op()->HasAttr("shape")) {
@@ -85,6 +119,11 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
     }
     if (n->IsVar()) {
       if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
+        if (graph_node_size == 1 && white_tensor_name.find(n->Var()->Name()) !=
+                                        white_tensor_name.end()) {
+          VLOG(6) << "Keep 0D-Tensor status of var " << n->Var()->Name();
+          continue;
+        }
         std::vector<int64_t> shape = n->Var()->GetShape();
         if (shape.empty()) {
           shape.push_back(1);
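Taken together, the pass's shape handling reduces to the following decision rule, shown as a Python sketch (hypothetical function and variable names; the real pass walks ir::Graph nodes in C++):

    WHITE_OP_LIST = {"elementwise_add"}

    def trick_shape(var_name, shape, white_tensor_name, graph_node_size):
        # Single-op graphs whose op is white-listed keep genuine 0D shapes.
        if graph_node_size == 1 and var_name in white_tensor_name:
            return shape
        # Otherwise the legacy trick applies: promote 0D ([]) to 1D ([1]).
        if len(shape) == 0:
            return [1]
        return shape

For example, trick_shape("x", [], {"x"}, 1) keeps [], while the same variable in a multi-op graph becomes [1].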
@@ -27,7 +27,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
   ir::Layers layers;
   auto* x = layers.data("x", {});
   auto* y = layers.data("y", {3, 4});
-  auto* add_out_0 = layers.elementwise_add(x, y, nullptr, 0);
+  auto* add_out_0 = layers.mul(x, y, nullptr, 0);
   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   auto pass = ir::PassRegistry::Instance().Get("cinn_zero_tensor_trick_pass");
   VLOG(3) << DebugString(graph);
@@ -43,7 +43,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
         shape.empty(),
         false,
         platform::errors::PreconditionNotMet(
-            "The shape of elementwise_add should not be empty after fuse"));
+            "The shape of mul should not be empty after fuse"));
     }
   }
 }
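Note the operator switch above: elementwise_add is now on the pass's white list and keeps its 0D shapes, so a test asserting that the pass pads [] to a non-empty shape needs an op outside the white list; hence mul.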
@@ -257,14 +257,29 @@ void CinnLaunchContext::CheckTensorEquivalent(
   // check dimension
   auto cinn_tensor = GetCinnTensorOfVar(var_name);
   auto cinn_dims = phi::make_ddim(cinn_tensor->shape().data());
-  PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
-                    cinn_dims,
-                    platform::errors::PreconditionNotMet(
-                        "Tensors' shape in variable(%s) are not equivalent, "
-                        "paddle is = [%s], but cinn is = [%s].",
-                        var_name,
-                        paddle_tensor.dims(),
-                        cinn_dims));
+  if (paddle_tensor.dims().size() == 0) {
+    // VLOG when paddle inputs 0D-Tensor
+    VLOG(4) << "Paddle inputs 0D-Tensor, CINN changes 0D-Tensor " << var_name
+            << " to 1D-Tensor";
+    PADDLE_ENFORCE_EQ(phi::make_ddim({1}),
+                      cinn_dims,
+                      phi::errors::PreconditionNotMet(
+                          "Tensor's shape of variable(%s) is not consistent, "
+                          "paddle inputs 0D-Tensor, cinn should get 1D-Tensor "
+                          "instead of [%s].",
+                          var_name,
+                          cinn_dims));
+  } else {
+    PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
+                      cinn_dims,
+                      phi::errors::PreconditionNotMet(
+                          "Tensor's shape of variable(%s) is not equivalent, "
+                          "paddle is = [%s], but cinn is = [%s].",
+                          var_name,
+                          paddle_tensor.dims(),
+                          cinn_dims));
+  }
   auto cinn_dtype =
       framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());
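The check above admits exactly one shape mismatch: a Paddle 0D-Tensor paired with a CINN 1D tensor of shape [1]; every other rank must match exactly. A compact restatement as a Python sketch (illustrative names):

    def check_dims_equivalent(var_name, paddle_dims, cinn_dims):
        # A Paddle 0D-Tensor must surface in CINN as a 1D tensor [1];
        # all other shapes must match exactly.
        expected = [1] if len(paddle_dims) == 0 else list(paddle_dims)
        if list(cinn_dims) != expected:
            raise ValueError(
                f"Tensor's shape of variable({var_name}) is not consistent: "
                f"paddle is {paddle_dims}, but cinn is {cinn_dims}."
            )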
@@ -116,9 +116,6 @@ class TestElementwiseAddOp_ZeroDim1(TestElementwiseAddOp):
         self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
         self.out = np.add(self.x, self.y)
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp_ZeroDim1):
     def init_input_output(self):
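Removing if_enable_cinn re-enables CINN for the ZeroDim tests. The ZeroDim2 and ZeroDim3 variants named in the commit message presumably mix a 0D operand with an N-D one; the shapes below are hypothetical, chosen only to illustrate the broadcasting such cases exercise:

    import numpy as np

    x = np.random.uniform(0.1, 1, [])      # 0D operand, shape ()
    y = np.random.uniform(0.1, 1, [3, 4])  # N-D operand (hypothetical shape)
    out = np.add(x, y)                     # the scalar broadcasts over y
    assert out.shape == (3, 4)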