未验证 提交 75088bbf 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse] Unified api args name (#47529) (#47627)

Unified api args name
上级 d4bf8b1a
......@@ -325,7 +325,7 @@
func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
- backward_op : sparse_coo_tensor_grad
forward : sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out)
forward : sparse_coo_tensor(Tensor values, Tensor indices, int64_t[] shape) -> Tensor(out)
args : (Tensor indices, Tensor out_grad)
output : Tensor(values_grad)
infer_meta :
......
......@@ -269,7 +269,7 @@
backward : softmax_grad
- op : sparse_coo_tensor
args : (Tensor values, Tensor indices, IntArray dense_shape)
args : (Tensor values, Tensor indices, int64_t[] shape={})
output : Tensor(out)
infer_meta :
func : sparse::SparseCooTensorInferMeta
......
......@@ -136,9 +136,9 @@ void Pool3dInferMeta(const MetaTensor& x,
void SparseCooTensorInferMeta(const MetaTensor& values,
const MetaTensor& indices,
const IntArray& dense_shape,
const std::vector<int64_t>& shape,
MetaTensor* out) {
out->set_dims(phi::make_ddim(dense_shape.GetData()));
out->set_dims(phi::make_ddim(shape));
out->set_dtype(values.dtype());
out->set_layout(values.layout());
}
......
......@@ -45,7 +45,7 @@ void Pool3dInferMeta(const MetaTensor& x,
void SparseCooTensorInferMeta(const MetaTensor& values,
const MetaTensor& indices,
const IntArray& dense_shape,
const std::vector<int64_t>& shape,
MetaTensor* out);
} // namespace sparse
......
......@@ -168,10 +168,9 @@ template <typename T, typename Context>
void SparseCooTensorKernel(const Context& dev_ctx,
const DenseTensor& values,
const DenseTensor& indices,
const IntArray& dense_shape,
const std::vector<int64_t>& shape,
SparseCooTensor* out) {
*out =
SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
*out = SparseCooTensor(indices, values, phi::make_ddim(shape));
}
} // namespace sparse
......
......@@ -180,7 +180,7 @@ def sparse_coo_tensor(
inputs = {'values': values, 'indices': indices}
if shape[0] is None:
shape[0] = -1
attrs = {'dense_shape': shape}
attrs = {'shape': shape}
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(dtype)
helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册