未验证 提交 30adea0a 编写于 作者: 石晓伟 提交者: GitHub

tensor_array_to_tensor_op.cc, test=develop (#19289)

上级 0436efd6
......@@ -105,15 +105,7 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
auto out_inx_dim = out_inx.dims();
out_inx_dim[0] = inx.size();
out_inx.Resize(out_inx_dim);
auto &local_scope = scope.NewScope();
std::string var_name = "out_index";
framework::Variable *tmp_index_var = local_scope.Var(var_name);
auto &tmp_index_tensor =
*(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
tmp_index_tensor.Resize(out_inx_dim);
int *tmp_index_data =
tmp_index_tensor.mutable_data<int>(platform::CPUPlace());
int *tmp_index_data = out_inx.mutable_data<int>(platform::CPUPlace());
auto out_dims = inx[0].dims();
size_t out_dim_sum = 0;
......@@ -122,18 +114,17 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
out_dim_sum += inx_dims[axis];
tmp_index_data[index] = inx_dims[axis];
}
out_inx.ShareDataWith(tmp_index_tensor);
// get input array items' dims
out_dims[axis] = out_dim_sum;
out.Resize(out_dims);
LodTensorArray2LodTensorVector(local_scope, base_name, Input("X"), &names);
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
// Invoke concat Op
auto concat_op = framework::OpRegistry::CreateOp(
"concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
concat_op->Run(local_scope, place);
concat_op->Run(scope, place);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册