未验证 提交 9c2dae1a 编写于 作者: H hong 提交者: GitHub

Fix output vector type bug (#54865)

* add fetch kernel

* support fetch var in new ir

* fix bug

* polish code

* change array equal to np.testing

* support feed in new ir

* fix bug

* try to hack combine op

* add scope guard

* revert atan2 op

* polish code

* fix vector type bug

* modify feed data type
上级 99c593bc
......@@ -40,7 +40,12 @@ phi::KernelKey GetKernelKey(
const phi::Place& place,
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair) {
if (op->name() == "pd.feed") {
return {phi::Backend::CPU, phi::DataLayout::ANY, phi::DataType::FLOAT32};
// NOTE, for now feed op don't need a kernel, so the data type from Op
// Result the next op use base program datatype
return {phi::Backend::CPU,
phi::DataLayout::ANY,
TransToPhiDataType(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}
phi::Backend kernel_backend = phi::Backend::UNDEFINED;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
......@@ -223,23 +228,27 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
result_type.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype);
} else if (result_type.isa<ir::VectorType>()) {
auto pos1 = result_type.dyn_cast<ir::VectorType>().data()[0];
if (pos1.isa<dialect::DenseTensorType>()) {
std::vector<ir::Type> vec_inner_types;
auto base_types = result_type.dyn_cast<ir::VectorType>().data();
for (size_t j = 0; j < base_types.size(); ++j) {
if (base_types[j].isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
phi::TransToPhiPlace(kernel_key.backend()),
pos1.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype);
base_types[j].dyn_cast<dialect::DenseTensorType>());
vec_inner_types.push_back(allocated_dense_tensor_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor in vector type for now"));
}
}
ir::Type t1 = ir::VectorType::get(ctx, op_output_types);
op_output_types.clear();
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Result type only support DenseTensorType and VectorType"));
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册