未验证 提交 9c2dae1a 编写于 作者: H hong 提交者: GitHub

Fix output vector type bug (#54865)

* add fetch kernel

* support fetch var in new ir

* fix bug

* polish code

* change array equal to np.testing

* support feed in new ir

* fix bug

* try to hack combine op

* add scope guard

* revert atan2 op

* polish code

* fix vector type bug

* modify feed data type
上级 99c593bc
...@@ -40,7 +40,12 @@ phi::KernelKey GetKernelKey( ...@@ -40,7 +40,12 @@ phi::KernelKey GetKernelKey(
const phi::Place& place, const phi::Place& place,
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair) { const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair) {
if (op->name() == "pd.feed") { if (op->name() == "pd.feed") {
return {phi::Backend::CPU, phi::DataLayout::ANY, phi::DataType::FLOAT32}; // NOTE, for now feed op don't need a kernel, so the data type from Op
// Result the next op use base program datatype
return {phi::Backend::CPU,
phi::DataLayout::ANY,
TransToPhiDataType(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
} }
phi::Backend kernel_backend = phi::Backend::UNDEFINED; phi::Backend kernel_backend = phi::Backend::UNDEFINED;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED; phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
...@@ -223,23 +228,27 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -223,23 +228,27 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
result_type.dyn_cast<dialect::DenseTensorType>()); result_type.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype); op_output_types.push_back(allocated_dense_tensor_dtype);
} else if (result_type.isa<ir::VectorType>()) { } else if (result_type.isa<ir::VectorType>()) {
auto pos1 = result_type.dyn_cast<ir::VectorType>().data()[0]; std::vector<ir::Type> vec_inner_types;
auto base_types = result_type.dyn_cast<ir::VectorType>().data();
if (pos1.isa<dialect::DenseTensorType>()) { for (size_t j = 0; j < base_types.size(); ++j) {
auto allocated_dense_tensor_dtype = if (base_types[j].isa<dialect::DenseTensorType>()) {
paddle::dialect::AllocatedDenseTensorType::get( auto allocated_dense_tensor_dtype =
ctx, paddle::dialect::AllocatedDenseTensorType::get(
phi::TransToPhiPlace(kernel_key.backend()), ctx,
pos1.dyn_cast<dialect::DenseTensorType>()); phi::TransToPhiPlace(kernel_key.backend()),
op_output_types.push_back(allocated_dense_tensor_dtype); base_types[j].dyn_cast<dialect::DenseTensorType>());
} else { vec_inner_types.push_back(allocated_dense_tensor_dtype);
PADDLE_THROW(phi::errors::Unimplemented( } else {
"only support dense tensor in vector type for now")); PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor in vector type for now"));
}
} }
ir::Type t1 = ir::VectorType::get(ctx, op_output_types); ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.clear();
op_output_types.push_back(t1); op_output_types.push_back(t1);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Result type only support DenseTensorType and VectorType"));
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册