Unverified commit 26c824db, authored by HongyuJia, committed by GitHub

[0D-Tensor] Support elementwise_add (#53955)

* [0D-Tensor] Support elementwise_add

* support elementwise_add ZeroDim2&3

Parent 6fde2056
@@ -94,13 +94,8 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
       feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
     }
-    PADDLE_ENFORCE_NE(
-        feed_map[feed_name].shape.size(),
-        0UL,
-        platform::errors::PreconditionNotMet(
-            "The input variable %s's tensor shape cannot be empty,"
-            "we need the variable's dtype and shape from tensor.",
-            feed_name.c_str()));
+    VLOG_IF(4, feed_map[feed_name].shape.size() == 0UL)
+        << "Shape is empty, Create 0D-Tensor for " << feed_name;
   }
   return feed_map;
 }
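The behavioral change in this hunk: an empty feed shape is no longer a hard precondition failure; it is merely logged at verbosity 4 and flows on as a 0D-Tensor. A minimal standalone sketch of that guard-to-log relaxation (FeedInfo here is a hypothetical stand-in, not the real Paddle type):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for a feed-info entry; only the shape matters here.
struct FeedInfo {
  std::vector<int64_t> shape;
};

int main() {
  const std::string feed_name = "x";
  FeedInfo info{{}};  // rank-0 feed: previously a hard error, now accepted
  // Before this commit: abort via PADDLE_ENFORCE_NE when the shape is empty.
  // After: only report it (VLOG_IF at verbosity 4) and keep the 0D shape.
  if (info.shape.empty()) {
    std::cout << "Shape is empty, Create 0D-Tensor for " << feed_name << "\n";
  }
  return 0;
}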
@@ -58,6 +58,40 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
     }
   }
+
+  // CINN ops in this white list support 0D-Tensor
+  const std::unordered_set<std::string> white_op_list{"elementwise_add"};
+  std::unordered_set<std::string> white_tensor_name;
+  // enable white_op_list only when graph_node_size = 1, which means a
+  // single-op test
+  int graph_node_size = 0;
+  for (const ir::Node* n : graph->Nodes()) {
+    if (n->IsOp()) {
+      graph_node_size++;
+      VLOG(6) << "Graph has op node " << n->Op()->Type();
+      if (white_op_list.find(n->Op()->Type()) != white_op_list.end()) {
+        for (const ir::Node* var : n->inputs) {
+          white_tensor_name.insert(var->Var()->Name());
+          std::vector<int64_t> shape = var->Var()->GetShape();
+          if (shape.empty()) {
+            VLOG(6) << "input var " << var->Name()
+                    << " dims is empty, keep its 0D-Tensor status";
+          }
+        }
+        for (const ir::Node* var : n->outputs) {
+          white_tensor_name.insert(var->Var()->Name());
+          std::vector<int64_t> shape = var->Var()->GetShape();
+          if (shape.empty()) {
+            VLOG(6) << "output var " << var->Name()
+                    << " dims is empty, keep its 0D-Tensor status";
+          }
+        }
+      }
+    }
+  }
+  VLOG(6) << "Graph has " << graph_node_size << " op nodes";
+
   for (const ir::Node* n : graph->Nodes()) {
     if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) {
       if (n->Op()->HasAttr("shape")) {
@@ -85,6 +119,11 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
     }
     if (n->IsVar()) {
       if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
+        if (graph_node_size == 1 && white_tensor_name.find(n->Var()->Name()) !=
+                                        white_tensor_name.end()) {
+          VLOG(6) << "Keep 0D-Tensor status of var " << n->Var()->Name();
+          continue;
+        }
         std::vector<int64_t> shape = n->Var()->GetShape();
         if (shape.empty()) {
           shape.push_back(1);
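The pass's gating rule, distilled: a variable keeps its 0D status only when the whole graph is a single op and that op appears in white_op_list; otherwise the existing trick of padding empty shapes to [1] still applies. A runnable sketch under those assumptions (keep_zero_dim is a hypothetical helper, not a Paddle API):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// A var keeps its 0D status only when the graph holds exactly one op and
// that op is in the 0D white list, mirroring the two conditions above.
bool keep_zero_dim(const std::vector<std::string>& graph_ops,
                   const std::string& op_of_var) {
  static const std::unordered_set<std::string> white_op_list{
      "elementwise_add"};
  return graph_ops.size() == 1 && white_op_list.count(op_of_var) != 0;
}

int main() {
  // Single-op graph with a white-listed op: 0D is preserved (prints 1).
  std::cout << keep_zero_dim({"elementwise_add"}, "elementwise_add") << "\n";
  // Multi-op graph: the pass still pads 0D shapes to [1] (prints 0).
  std::cout << keep_zero_dim({"elementwise_add", "relu"}, "elementwise_add")
            << "\n";
  return 0;
}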
@@ -27,7 +27,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
   ir::Layers layers;
   auto* x = layers.data("x", {});
   auto* y = layers.data("y", {3, 4});
-  auto* add_out_0 = layers.elementwise_add(x, y, nullptr, 0);
+  auto* add_out_0 = layers.mul(x, y, nullptr, 0);
   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   auto pass = ir::PassRegistry::Instance().Get("cinn_zero_tensor_trick_pass");
   VLOG(3) << DebugString(graph);
@@ -43,7 +43,7 @@ TEST(CinnZeroTensorTrickPass, basic) {
         shape.empty(),
         false,
         platform::errors::PreconditionNotMet(
-            "The shape of elementwise_add should not be empty after fuse"));
+            "The shape of mul should not be empty after fuse"));
     }
   }
 }
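The test swaps elementwise_add for mul because elementwise_add is now white-listed: in this single-op graph the pass would keep the 0D input of x, so the assertion that shapes are non-empty after the pass would no longer hold. mul is not white-listed, so its empty input shape is still padded to [1], as sketched below (pad_zero_dim is a hypothetical free function mirroring the shape.push_back(1) branch in ApplyImpl):

#include <cassert>
#include <cstdint>
#include <vector>

// What the pass does to vars that are NOT exempted by the white list:
// an empty (rank-0) shape becomes the 1D shape [1].
std::vector<int64_t> pad_zero_dim(std::vector<int64_t> shape) {
  if (shape.empty()) {
    shape.push_back(1);  // 0D-Tensor becomes a 1D-Tensor of shape [1]
  }
  return shape;
}

int main() {
  // x = layers.data("x", {}) in the test: the empty shape gets padded,
  // so the "shape not empty after fuse" assertion holds for mul.
  assert(pad_zero_dim({}) == (std::vector<int64_t>{1}));
  // Non-empty shapes pass through unchanged.
  assert(pad_zero_dim({3, 4}) == (std::vector<int64_t>{3, 4}));
  return 0;
}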
@@ -257,14 +257,29 @@ void CinnLaunchContext::CheckTensorEquivalent(
   // check dimension
   auto cinn_tensor = GetCinnTensorOfVar(var_name);
   auto cinn_dims = phi::make_ddim(cinn_tensor->shape().data());
-  PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
-                    cinn_dims,
-                    platform::errors::PreconditionNotMet(
-                        "Tensors' shape in variable(%s) are not equivalent, "
-                        "paddle is = [%s], but cinn is = [%s].",
-                        var_name,
-                        paddle_tensor.dims(),
-                        cinn_dims));
+  if (paddle_tensor.dims().size() == 0) {
+    // VLOG when paddle inputs 0D-Tensor
+    VLOG(4) << "Paddle inputs 0D-Tensor, CINN changes 0D-Tensor " << var_name
+            << " to 1D-Tensor";
+    PADDLE_ENFORCE_EQ(phi::make_ddim({1}),
+                      cinn_dims,
+                      phi::errors::PreconditionNotMet(
+                          "Tensor's shape of variable(%s) are not consistent, "
+                          "paddle inputs 0D-Tensor, cinn should get 1D-Tensor "
+                          "instead of [%s].",
+                          var_name,
+                          paddle_tensor.dims(),
+                          cinn_dims));
+  } else {
+    PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
+                      cinn_dims,
+                      phi::errors::PreconditionNotMet(
+                          "Tensor's shape of variable(%s) are not equivalent, "
+                          "paddle is = [%s], but cinn is = [%s].",
+                          var_name,
+                          paddle_tensor.dims(),
+                          cinn_dims));
+  }
 
   auto cinn_dtype =
       framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());
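The runtime check now special-cases rank-0: a 0D Paddle tensor must correspond to a CINN tensor of shape [1], while every other rank must match exactly. A self-contained sketch of that rule (dims_consistent is a hypothetical helper operating on plain vectors, not the real phi::DDim API):

#include <cassert>
#include <cstdint>
#include <vector>

// Paddle-side rank-0 dims must map to a CINN-side shape of exactly [1];
// any other rank must match element-wise, as in the diff above.
bool dims_consistent(const std::vector<int64_t>& paddle_dims,
                     const std::vector<int64_t>& cinn_dims) {
  if (paddle_dims.empty()) {  // 0D on the Paddle side
    return cinn_dims == std::vector<int64_t>{1};  // CINN must hold [1]
  }
  return paddle_dims == cinn_dims;  // otherwise shapes must match exactly
}

int main() {
  assert(dims_consistent({}, {1}));   // 0D-Tensor maps to 1D shape [1]
  assert(!dims_consistent({}, {2}));  // any other CINN shape is rejected
  assert(dims_consistent({3, 4}, {3, 4}));
  return 0;
}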
@@ -116,9 +116,6 @@ class TestElementwiseAddOp_ZeroDim1(TestElementwiseAddOp):
         self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
         self.out = np.add(self.x, self.y)
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp_ZeroDim1):
     def init_input_output(self):