diff --git a/paddle/fluid/ir/dialect/pd_op.yaml b/paddle/fluid/ir/dialect/pd_op.yaml index 0fe59e00610fb3a46ed20e0dacd44b1269f894a3..1c20c409df1eb95daeeaaf752e7191006c876a65 100644 --- a/paddle/fluid/ir/dialect/pd_op.yaml +++ b/paddle/fluid/ir/dialect/pd_op.yaml @@ -1,3 +1,34 @@ +- name: assign_value + inputs: [] + attrs: + - {typename: 'int[]', name: shape} + - {typename: DataType, name: dtype} + - {typename: 'Scalar[]', name: values, data_type: 'std::vector'} + - {typename: Place, name: place, default_value: '{}'} + outputs: + - {typename: Tensor, name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: AssignValueInferMeta + param: [shape, dtype] + kernel: + func: [assign_value] + param: [shape, dtype, values] + backend: + ordered: true + candidates: [place] + layout: null + data_type: + ordered: false + candidates: [dtype] + to_complex_flag: [false] + dispatch: {assign_value: null} + force_backend: null + inplace: null + view: null + backward: null + - name: feed inputs: [] attrs: diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 9d496fa76c62e3aa39f2ed1449eed172dd75888a..04cc99075bd92a15d5e354383edb16d310494beb 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -761,11 +761,11 @@ struct IncrementOpTranscriber : public OpTranscriber { // python/paddle/tensor/creation.py::assign(x, output) struct AssignValueOpTranscriber : public OpTranscriber { ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { - std::string target_op_name = "pd.assign_value_"; + std::string target_op_name = "pd.assign_value"; const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW( - "Op assign_value should have corresponding OpInfo pd.assign_value_"); + "Op assign_value should have corresponding OpInfo pd.assign_value"); } return op_info; @@ -836,20 +836,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { VLOG(10) << "[op assign_value] attribute translation done"; - std::vector src_shape = - paddle::get>(op_desc.GetAttr("shape")); - std::vector target_shape(src_shape.begin(), src_shape.end()); - - ir::Builder builder(ctx, program->block()); - dialect::FullOp full_op = builder.Build( - target_shape, - 0.0f, - attr_dtype.dyn_cast().data(), - phi::CPUPlace()); - - std::vector op_inputs = {full_op->result(0)}; - - VLOG(10) << "[op assign_value] insert a full op to get input"; + std::vector op_inputs = {}; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types;