From 826eaf729db76fb43094ae84f4f3f4ebc31f726e Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 9 Aug 2023 10:49:03 +0800 Subject: [PATCH] [IR] Fix assign_value_op bug (#56082) * fix bug * fix bug --- paddle/fluid/ir/dialect/pd_op.yaml | 31 +++++++++++++++++++ .../ir_adaptor/translator/op_translator.cc | 19 ++---------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_op.yaml b/paddle/fluid/ir/dialect/pd_op.yaml index 0fe59e00610..1c20c409df1 100644 --- a/paddle/fluid/ir/dialect/pd_op.yaml +++ b/paddle/fluid/ir/dialect/pd_op.yaml @@ -1,3 +1,34 @@ +- name: assign_value + inputs: [] + attrs: + - {typename: 'int[]', name: shape} + - {typename: DataType, name: dtype} + - {typename: 'Scalar[]', name: values, data_type: 'std::vector'} + - {typename: Place, name: place, default_value: '{}'} + outputs: + - {typename: Tensor, name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: AssignValueInferMeta + param: [shape, dtype] + kernel: + func: [assign_value] + param: [shape, dtype, values] + backend: + ordered: true + candidates: [place] + layout: null + data_type: + ordered: false + candidates: [dtype] + to_complex_flag: [false] + dispatch: {assign_value: null} + force_backend: null + inplace: null + view: null + backward: null + - name: feed inputs: [] attrs: diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 9d496fa76c6..04cc99075bd 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -761,11 +761,11 @@ struct IncrementOpTranscriber : public OpTranscriber { // python/paddle/tensor/creation.py::assign(x, output) struct AssignValueOpTranscriber : public OpTranscriber { ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { - std::string target_op_name = "pd.assign_value_"; + std::string target_op_name = "pd.assign_value"; const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW( - "Op assign_value should have corresponding OpInfo pd.assign_value_"); + "Op assign_value should have corresponding OpInfo pd.assign_value"); } return op_info; @@ -836,20 +836,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { VLOG(10) << "[op assign_value] attribute translation done"; - std::vector src_shape = - paddle::get>(op_desc.GetAttr("shape")); - std::vector target_shape(src_shape.begin(), src_shape.end()); - - ir::Builder builder(ctx, program->block()); - dialect::FullOp full_op = builder.Build( - target_shape, - 0.0f, - attr_dtype.dyn_cast().data(), - phi::CPUPlace()); - - std::vector op_inputs = {full_op->result(0)}; - - VLOG(10) << "[op assign_value] insert a full op to get input"; + std::vector op_inputs = {}; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; -- GitLab