未验证 提交 75088bbf 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse] Unified api args name (#47529) (#47627)

Unified api args name
上级 d4bf8b1a
...@@ -325,7 +325,7 @@ ...@@ -325,7 +325,7 @@
func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr} func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
- backward_op : sparse_coo_tensor_grad - backward_op : sparse_coo_tensor_grad
forward : sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out) forward : sparse_coo_tensor(Tensor values, Tensor indices, int64_t[] shape) -> Tensor(out)
args : (Tensor indices, Tensor out_grad) args : (Tensor indices, Tensor out_grad)
output : Tensor(values_grad) output : Tensor(values_grad)
infer_meta : infer_meta :
......
...@@ -269,7 +269,7 @@ ...@@ -269,7 +269,7 @@
backward : softmax_grad backward : softmax_grad
- op : sparse_coo_tensor - op : sparse_coo_tensor
args : (Tensor values, Tensor indices, IntArray dense_shape) args : (Tensor values, Tensor indices, int64_t[] shape={})
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : sparse::SparseCooTensorInferMeta func : sparse::SparseCooTensorInferMeta
......
...@@ -136,9 +136,9 @@ void Pool3dInferMeta(const MetaTensor& x, ...@@ -136,9 +136,9 @@ void Pool3dInferMeta(const MetaTensor& x,
void SparseCooTensorInferMeta(const MetaTensor& values, void SparseCooTensorInferMeta(const MetaTensor& values,
const MetaTensor& indices, const MetaTensor& indices,
const IntArray& dense_shape, const std::vector<int64_t>& shape,
MetaTensor* out) { MetaTensor* out) {
out->set_dims(phi::make_ddim(dense_shape.GetData())); out->set_dims(phi::make_ddim(shape));
out->set_dtype(values.dtype()); out->set_dtype(values.dtype());
out->set_layout(values.layout()); out->set_layout(values.layout());
} }
......
...@@ -45,7 +45,7 @@ void Pool3dInferMeta(const MetaTensor& x, ...@@ -45,7 +45,7 @@ void Pool3dInferMeta(const MetaTensor& x,
void SparseCooTensorInferMeta(const MetaTensor& values, void SparseCooTensorInferMeta(const MetaTensor& values,
const MetaTensor& indices, const MetaTensor& indices,
const IntArray& dense_shape, const std::vector<int64_t>& shape,
MetaTensor* out); MetaTensor* out);
} // namespace sparse } // namespace sparse
......
...@@ -168,10 +168,9 @@ template <typename T, typename Context> ...@@ -168,10 +168,9 @@ template <typename T, typename Context>
void SparseCooTensorKernel(const Context& dev_ctx, void SparseCooTensorKernel(const Context& dev_ctx,
const DenseTensor& values, const DenseTensor& values,
const DenseTensor& indices, const DenseTensor& indices,
const IntArray& dense_shape, const std::vector<int64_t>& shape,
SparseCooTensor* out) { SparseCooTensor* out) {
*out = *out = SparseCooTensor(indices, values, phi::make_ddim(shape));
SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
} }
} // namespace sparse } // namespace sparse
......
...@@ -180,7 +180,7 @@ def sparse_coo_tensor( ...@@ -180,7 +180,7 @@ def sparse_coo_tensor(
inputs = {'values': values, 'indices': indices} inputs = {'values': values, 'indices': indices}
if shape[0] is None: if shape[0] is None:
shape[0] = -1 shape[0] = -1
attrs = {'dense_shape': shape} attrs = {'shape': shape}
helper = LayerHelper(op_type) helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(dtype) out = helper.create_sparse_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册