未验证 提交 826eaf72 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Fix assign_value_op bug (#56082)

* fix bug

* fix bug
上级 b0ed082e
- name: assign_value
inputs: []
attrs:
- {typename: 'int[]', name: shape}
- {typename: DataType, name: dtype}
- {typename: 'Scalar[]', name: values, data_type: 'std::vector<Scalar>'}
- {typename: Place, name: place, default_value: '{}'}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: AssignValueInferMeta
param: [shape, dtype]
kernel:
func: [assign_value]
param: [shape, dtype, values]
backend:
ordered: true
candidates: [place]
layout: null
data_type:
ordered: false
candidates: [dtype]
to_complex_flag: [false]
dispatch: {assign_value: null}
force_backend: null
inplace: null
view: null
backward: null
- name: feed - name: feed
inputs: [] inputs: []
attrs: attrs:
......
...@@ -761,11 +761,11 @@ struct IncrementOpTranscriber : public OpTranscriber { ...@@ -761,11 +761,11 @@ struct IncrementOpTranscriber : public OpTranscriber {
// python/paddle/tensor/creation.py::assign(x, output) // python/paddle/tensor/creation.py::assign(x, output)
struct AssignValueOpTranscriber : public OpTranscriber { struct AssignValueOpTranscriber : public OpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
std::string target_op_name = "pd.assign_value_"; std::string target_op_name = "pd.assign_value";
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) { if (!op_info) {
IR_THROW( IR_THROW(
"Op assign_value should have corresponding OpInfo pd.assign_value_"); "Op assign_value should have corresponding OpInfo pd.assign_value");
} }
return op_info; return op_info;
...@@ -836,20 +836,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { ...@@ -836,20 +836,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {
VLOG(10) << "[op assign_value] attribute translation done"; VLOG(10) << "[op assign_value] attribute translation done";
std::vector<int> src_shape = std::vector<ir::OpResult> op_inputs = {};
paddle::get<std::vector<int>>(op_desc.GetAttr("shape"));
std::vector<int64_t> target_shape(src_shape.begin(), src_shape.end());
ir::Builder builder(ctx, program->block());
dialect::FullOp full_op = builder.Build<dialect::FullOp>(
target_shape,
0.0f,
attr_dtype.dyn_cast<dialect::DataTypeAttribute>().data(),
phi::CPUPlace());
std::vector<ir::OpResult> op_inputs = {full_op->result(0)};
VLOG(10) << "[op assign_value] insert a full op to get input";
OpOutputMapping arg_to_idx; OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types; OpOutputTypeList op_output_types;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册