Unverified commit a20051cd authored by H hong, committed by GitHub

Fix fetch op and null type bug (#55027)

* fix_fetch_op_and_null_type_bug

* fix compile bug

* add test case
Parent b918100a
......@@ -18,6 +18,7 @@
data_transform: {}
attrs:
- {typename: str, name: name}
- {typename: int, name: col}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
......
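The op definition above adds an integer "col" attribute, so each fetch op records which column of the global fetch list it writes to. A minimal sketch of building and reading that attribute, using only the calls that appear in the hunks below (ctx is assumed to be the active ir::IrContext):

    // Attach the column index when the op is created.
    ir::AttributeMap attrs = {
        {"col", ir::Int32Attribute::get(ctx, /*col=*/2)},
    };
    // Recover it where the op is interpreted.
    int col = attrs.at("col").dyn_cast<ir::Int32Attribute>().data();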
......@@ -62,8 +62,9 @@ void BuildScope(ir::Block* block,
for (size_t i = 0; i < input_num; ++i) {
auto var = scope->Var("fetch");
auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
// for now only support one fetch
fetch_list->resize(1);
int index =
(*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
fetch_list->resize(index + 1);
}
continue;
}
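BuildScope now sizes the fetch list from the op's column index instead of hard-coding a single slot. One caveat worth noting: a std::vector-style resize also shrinks, so if a fetch op with a smaller col is visited after one with a larger col, earlier slots would be cut off. A standalone sketch of a grow-only guard (a plain std::vector stands in for paddle::framework::FetchList here; this is a suggestion, not what the patch does):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Grow-only resize: never shrink the fetch list when fetch ops are
    // visited out of column order.
    void EnsureFetchSlot(std::vector<float>* fetch_list, int col) {
      std::size_t needed = static_cast<std::size_t>(col) + 1;
      fetch_list->resize(std::max(fetch_list->size(), needed));
    }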
......@@ -148,7 +149,11 @@ void BuildScope(ir::Block* block,
}
auto var = scope->Var(name);
// Only DenseTensor and Vector<DenseTensor> are supported for now
if (ptr.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
if (!ptr.type()) {
var->GetMutable<phi::DenseTensor>();
} else if (ptr.type()
.isa<paddle::dialect::AllocatedDenseTensorType>()) {
var->GetMutable<phi::DenseTensor>();
} else if (ptr.type().isa<ir::VectorType>()) {
auto tensor_array =
......
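The new first branch makes BuildScope tolerate values whose type was never set: they default to a DenseTensor holder instead of reaching the isa<> dispatch, which is presumably what failed on a null type before this fix. A standalone model of the pattern, with a nullable handle standing in for ir::Type:

    #include <string>

    // Nullable handle standing in for ir::Type; converts to false when
    // no type was assigned to the value.
    struct TypeHandle {
      const char* name = nullptr;
      explicit operator bool() const { return name != nullptr; }
    };

    // The null check must come first: only a non-null handle may be inspected.
    std::string PickHolder(TypeHandle t) {
      if (!t) return "DenseTensor";  // default holder for untyped values
      return std::string(t.name) == "VectorType" ? "TensorArray"
                                                 : "DenseTensor";
    }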
......@@ -146,6 +146,19 @@ void BuildPhiContext(
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
if (array_list[0].isa<ir::Int32Attribute>()) {
std::vector<int32_t> vec_res;
for (size_t i = 0; i < array_list.size(); ++i) {
vec_res.push_back(
array_list[i].dyn_cast<ir::Int32Attribute>().data());
}
ctx->EmplaceBackAttr(vec_res);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not supported [%s]",
attr_type_name));
}
} else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
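The ArrayAttribute branch checks only element 0 for its kind, i.e. it assumes the array is homogeneous, and then casts each element individually while unpacking into std::vector<int32_t>. A standalone model of the unpack (Int32Attr stands in for ir::Int32Attribute):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Stand-in for ir::Int32Attribute: data() yields the payload.
    struct Int32Attr {
      int32_t value;
      int32_t data() const { return value; }
    };

    std::vector<int32_t> UnpackInt32Array(const std::vector<Int32Attr>& arr) {
      std::vector<int32_t> vec_res;
      vec_res.reserve(arr.size());
      for (std::size_t i = 0; i < arr.size(); ++i) {
        vec_res.push_back(arr[i].data());  // element i, not element 0
      }
      return vec_res;
    }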
......@@ -166,14 +179,22 @@ void BuildPhiContext(
// process fetch op
auto fetch_var = scope->Var("fetch");
auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(0)));
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(index)));
ctx->EmplaceBackOutput(out_tensor);
} else {
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>()))));
if (out_ptr.type()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>()))));
} else {
phi::DenseTensor* ptr = nullptr;
OutType out_ptr(ptr);
ctx->EmplaceBackOutput(out_ptr);
}
if (output_map != nullptr) {
// only deal with single input for now, [todo] need to support multi input
......
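For results that carry no type, the else branch still emplaces an output slot, just holding a null DenseTensor pointer, so output index i keeps corresponding to result i downstream. A standalone model of keeping the slots aligned:

    #include <vector>

    struct DenseTensor { /* payload elided */ };

    // One output slot per op result; untyped results get a null
    // placeholder so slot i always corresponds to result i.
    void EmplaceOutputs(const std::vector<bool>& result_has_type,
                        std::vector<DenseTensor*>* outputs) {
      static DenseTensor dummy;  // stands in for the scope-owned tensor
      for (bool has_type : result_has_type) {
        outputs->push_back(has_type ? &dummy : nullptr);
      }
    }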
......@@ -218,7 +218,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
if ((*it)->num_results() > 0) {
for (size_t i = 0; i < (*it)->num_results(); ++i) {
auto result_type = (*it)->result(i).type();
if (result_type.isa<dialect::DenseTensorType>()) {
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
......
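The lowering pass gets the same treatment: a null result type is pushed through to the kernel dialect unchanged rather than wrapped in AllocatedDenseTensorType. A standalone model, with std::optional standing in for the nullable ir type handle:

    #include <optional>
    #include <string>
    #include <vector>

    // Null (unset) result types pass through the lowering untouched;
    // real types get wrapped for the kernel dialect.
    std::vector<std::optional<std::string>> LowerResultTypes(
        const std::vector<std::optional<std::string>>& result_types) {
      std::vector<std::optional<std::string>> out;
      for (const auto& t : result_types) {
        out.push_back(t ? std::optional<std::string>("Allocated" + *t)
                        : std::nullopt);
      }
      return out;
    }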
......@@ -916,6 +916,8 @@ struct FetchOpTranscriber : public OpTranscriber {
OpOutputTypeList op_output_types;
ir::AttributeMap attribute_map = {
{"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
{"col",
ir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("col"))},
};
op_output_types.push_back(op_inputs[0].type());
......
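The transcriber reads the column with GetAttrIfExists<int>("col"), which, as the name suggests, falls back to a default-constructed value (0) when a legacy program predates the attribute, so old programs keep fetching into column 0. A standalone model of that lookup-with-default:

    #include <map>
    #include <string>

    // Model of OpDesc::GetAttrIfExists<T>: stored value if present,
    // otherwise a default-constructed T (0 for int).
    int GetAttrIfExistsModel(const std::map<std::string, int>& attrs,
                             const std::string& name) {
      auto it = attrs.find(name);
      return it == attrs.end() ? int{} : it->second;
    }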
......@@ -34,6 +34,11 @@ PD_REGISTER_KERNEL(fetch,
double,
int,
int64_t,
uint8_t,
int8_t,
int16_t,
phi::float16,
phi::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
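The registration list for the fetch kernel grows to cover the remaining common dtypes. Roughly speaking, PD_REGISTER_KERNEL instantiates the kernel once per listed type, which is why fetching, say, a float16 or int8 tensor needed this change. A simplified standalone model of per-dtype instantiation (the copy body is only illustrative):

    #include <cstdint>

    // Simplified model: registration effectively instantiates the kernel
    // template once per listed dtype.
    template <typename T>
    void FetchKernelModel(const T* src, T* dst, int n) {
      for (int i = 0; i < n; ++i) dst[i] = src[i];
    }

    // Instantiations mirroring the dtypes added above.
    template void FetchKernelModel<uint8_t>(const uint8_t*, uint8_t*, int);
    template void FetchKernelModel<int8_t>(const int8_t*, int8_t*, int);
    template void FetchKernelModel<int16_t>(const int16_t*, int16_t*, int);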
......@@ -414,7 +414,7 @@ void BuildProgram(ir::Builder &builder) { // NOLINT
auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out");
builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out", 0);
}
// TODO(wilber): Add a normal test.
......
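With the extra argument in the test above, each FetchOp names its column explicitly; fetching two values would mean giving each its own column, e.g. (a fragment mirroring the builder call above, not a complete program):

    builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out", 0);
    builder.Build<paddle::dialect::FetchOp>(transpose1_op.out(), "out1", 1);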
......@@ -89,5 +89,35 @@ class TestFeedOp(unittest.TestCase):
np.testing.assert_array_equal(out[0], gold_res)
class TestAddGradOp(unittest.TestCase):
def test_with_new_ir(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
main_program = paddle.static.Program()
new_scope = paddle.static.Scope()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.static.data("x", [2, 2], dtype="float32")
y = paddle.static.data("y", [2, 2], dtype="float32")
x.stop_gradient = False
z = x * y
paddle.static.gradients(z, x)
np_a = np.random.rand(2, 2).astype("float32")
np_b = np.random.rand(2, 2).astype("float32")
out = exe.run(
main_program,
feed={"x": np_a, "y": np_b},
fetch_list=[z.name],
)
gold_res = np_a * np_b
np.testing.assert_array_equal(out[0], gold_res)
if __name__ == "__main__":
unittest.main()