未验证 提交 cec2730e 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] fix elementwise op bridge when x or y are from weight (#2272)

test=develop
上级 dc2b853e
......@@ -38,22 +38,41 @@ node_map_type ElementwiseConverter(const std::shared_ptr<lite::OpLite> op,
// get input, and attributes
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
CHECK(input_nodes.count(x_var_name));
CHECK(input_nodes.count(y_var_name));
auto axis = op_info->GetAttr<int>("axis");
auto x_dims = scope->FindTensor(x_var_name)->dims();
auto y_dims = scope->FindTensor(y_var_name)->dims();
auto x_tensor = scope->FindMutableTensor(x_var_name);
auto y_tensor = scope->FindMutableTensor(y_var_name);
auto x_dims = x_tensor->dims();
auto y_dims = y_tensor->dims();
// create x and y node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (input_nodes.count(x_var_name)) {
x_node = input_nodes.at(x_var_name);
} else {
x_node = std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
x_var_name, lite::xpu::CvtShape(x_dims), ::xtcl::Float(32)));
auto x_const_tensor = lite::xpu::CvtTensor(x_tensor);
graph_ctx->params->emplace(std::make_pair(x_var_name, *x_const_tensor));
}
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (input_nodes.count(y_var_name)) {
y_node = input_nodes.at(y_var_name);
} else {
y_node = std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
y_var_name, lite::xpu::CvtShape(y_dims), ::xtcl::Float(32)));
auto y_const_tensor = lite::xpu::CvtTensor(y_tensor);
graph_ctx->params->emplace(std::make_pair(y_var_name, *y_const_tensor));
}
// create elementwise node and set input, attributes
std::shared_ptr<xtcl::xExpr> elementwise_node = nullptr;
if (y_dims.size() == 1) {
elementwise_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateBiasAdd(
*input_nodes.at(x_var_name), *input_nodes.at(y_var_name), axis));
elementwise_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateBiasAdd(*x_node, *y_node, axis));
} else if (x_dims.size() == y_dims.size()) {
elementwise_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateBinaryOp(
"add", *input_nodes.at(x_var_name), *input_nodes.at(y_var_name)));
elementwise_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateBinaryOp("add", *x_node, *y_node));
} else {
LOG(ERROR) << "XPU elementwise_add only support y of one dimension, or x "
"and y of the same dimension. But recieved x's dimension: "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册