Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ad037caa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ad037caa
编写于
3月 11, 2022
作者:
Jeffrey Chen
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PHI] Migrate shard_index op (#40254)
上级
8cabb9f3
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
268 addition
and
207 deletion
+268
-207
paddle/fluid/operators/shard_index_op.cc
paddle/fluid/operators/shard_index_op.cc
+11
-26
paddle/fluid/operators/shard_index_op.cu
paddle/fluid/operators/shard_index_op.cu
+0
-96
paddle/fluid/operators/shard_index_op.h
paddle/fluid/operators/shard_index_op.h
+0
-84
paddle/fluid/operators/shard_index_op_npu.cc
paddle/fluid/operators/shard_index_op_npu.cc
+1
-1
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+28
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+8
-0
paddle/phi/kernels/cpu/shard_index_kernel.cc
paddle/phi/kernels/cpu/shard_index_kernel.cc
+91
-0
paddle/phi/kernels/gpu/shard_index_kernel.cu
paddle/phi/kernels/gpu/shard_index_kernel.cu
+99
-0
paddle/phi/kernels/shard_index_kernel.h
paddle/phi/kernels/shard_index_kernel.h
+30
-0
未找到文件。
paddle/fluid/operators/shard_index_op.cc
浏览文件 @
ad037caa
...
...
@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -20,27 +23,6 @@ namespace operators {
class
ShardIndexOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"ShardIndex"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"ShardIndex"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_GE
(
x_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Rank of Input(X) should be at least 2, "
"but the value given is %d."
,
x_dims
.
size
()));
if
(
ctx
->
IsRuntime
()
||
x_dims
[
x_dims
.
size
()
-
1
]
>
0
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
x_dims
.
size
()
-
1
],
1U
,
platform
::
errors
::
InvalidArgument
(
"The last dimension of Input(X) should be 1, "
"but the value given is %d."
,
x_dims
[
x_dims
.
size
()
-
1
]));
}
ctx
->
SetOutputDim
(
"Out"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
/* --> */
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
...
...
@@ -114,7 +96,10 @@ Examples:
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
shard_index
,
ops
::
ShardIndexOp
,
ops
::
ShardIndexOpMaker
);
REGISTER_OP_CPU_KERNEL
(
shard_index
,
ops
::
ShardIndexCPUKernel
<
int
>
,
ops
::
ShardIndexCPUKernel
<
int64_t
>
);
DECLARE_INFER_SHAPE_FUNCTOR
(
shard_index
,
ShardIndexInferShapeFunctor
,
PD_INFER_META
(
phi
::
ShardIndexInferMeta
));
REGISTER_OPERATOR
(
shard_index
,
ops
::
ShardIndexOp
,
ops
::
ShardIndexOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
ShardIndexInferShapeFunctor
);
paddle/fluid/operators/shard_index_op.cu
已删除
100644 → 0
浏览文件 @
8cabb9f3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shard_index_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
>
__global__
void
ShardIndexInner
(
const
T
*
in_data
,
T
*
out_data
,
const
int64_t
numel
,
const
int
index_num
,
const
int
nshards
,
const
int
shard_id
,
const
int
ignore_value
)
{
int
shard_size
=
(
index_num
+
nshards
-
1
)
/
nshards
;
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
numel
)
{
assert
(
in_data
[
idx
]
>=
0
&&
in_data
[
idx
]
<
index_num
);
if
(
in_data
[
idx
]
/
shard_size
==
shard_id
)
{
out_data
[
idx
]
=
in_data
[
idx
]
%
shard_size
;
}
else
{
out_data
[
idx
]
=
ignore_value
;
}
}
}
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
class
ShardIndexCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int
index_num
=
context
.
Attr
<
int
>
(
"index_num"
);
int
nshards
=
context
.
Attr
<
int
>
(
"nshards"
);
int
shard_id
=
context
.
Attr
<
int
>
(
"shard_id"
);
int
ignore_value
=
context
.
Attr
<
int
>
(
"ignore_value"
);
PADDLE_ENFORCE_GT
(
index_num
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d."
,
index_num
));
PADDLE_ENFORCE_GT
(
nshards
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d."
,
nshards
));
PADDLE_ENFORCE_GE
(
shard_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d."
,
shard_id
));
PADDLE_ENFORCE_LT
(
shard_id
,
nshards
,
platform
::
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d."
,
nshards
,
shard_id
));
out
->
Resize
(
in
->
dims
());
out
->
set_lod
(
in
->
lod
());
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
numel
=
in
->
numel
();
auto
stream
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
();
ShardIndexInner
<<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
numel
,
index_num
,
nshards
,
shard_id
,
ignore_value
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
shard_index
,
ops
::
ShardIndexCUDAKernel
<
int
>
,
ops
::
ShardIndexCUDAKernel
<
int64_t
>
);
paddle/fluid/operators/shard_index_op.h
已删除
100644 → 0
浏览文件 @
8cabb9f3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
class
ShardIndexCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int
index_num
=
context
.
Attr
<
int
>
(
"index_num"
);
int
nshards
=
context
.
Attr
<
int
>
(
"nshards"
);
int
shard_id
=
context
.
Attr
<
int
>
(
"shard_id"
);
int
ignore_value
=
context
.
Attr
<
int
>
(
"ignore_value"
);
PADDLE_ENFORCE_GT
(
index_num
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d."
,
index_num
));
PADDLE_ENFORCE_GT
(
nshards
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d."
,
nshards
));
PADDLE_ENFORCE_GE
(
shard_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d."
,
shard_id
));
PADDLE_ENFORCE_LT
(
shard_id
,
nshards
,
platform
::
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d."
,
nshards
,
shard_id
));
int
shard_size
=
(
index_num
+
nshards
-
1
)
/
nshards
;
out
->
Resize
(
in
->
dims
());
out
->
set_lod
(
in
->
lod
());
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
numel
=
in
->
numel
();
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
PADDLE_ENFORCE_GE
(
in_data
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The input_index for Op(shard_index) must be "
"greater or equal to 0, but the value given is %d."
,
in_data
[
i
]));
PADDLE_ENFORCE_LT
(
in_data
[
i
],
index_num
,
platform
::
errors
::
InvalidArgument
(
"The input_index for Op(shard_index) must be less "
"than index_num (%d), but the value given is %d."
,
index_num
,
in_data
[
i
]));
if
(
in_data
[
i
]
/
shard_size
==
shard_id
)
{
out_data
[
i
]
=
in_data
[
i
]
%
shard_size
;
}
else
{
out_data
[
i
]
=
ignore_value
;
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/shard_index_op_npu.cc
浏览文件 @
ad037caa
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/
operators/shard_index_op
.h"
#include "paddle/fluid/
framework/op_registry
.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace
paddle
{
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
ad037caa
...
...
@@ -1312,6 +1312,34 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
out
->
set_dtype
(
DataType
::
INT64
);
}
void
ShardIndexInferMeta
(
const
MetaTensor
&
in
,
int
index_num
,
int
nshards
,
int
shard_id
,
int
ignore_value
,
MetaTensor
*
out
,
MetaConfig
config
)
{
auto
x_dims
=
in
.
dims
();
PADDLE_ENFORCE_GE
(
x_dims
.
size
(),
2
,
phi
::
errors
::
InvalidArgument
(
"Rank of Input(X) should be at least 2, "
"but the value given is %d."
,
x_dims
.
size
()));
if
(
config
.
is_runtime
||
x_dims
[
x_dims
.
size
()
-
1
]
>
0
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
x_dims
.
size
()
-
1
],
1U
,
phi
::
errors
::
InvalidArgument
(
"The last dimension of Input(X) should be 1, "
"but the value given is %d."
,
x_dims
[
x_dims
.
size
()
-
1
]));
}
out
->
set_dims
(
x_dims
);
out
->
share_lod
(
in
);
out
->
set_dtype
(
in
.
dtype
());
}
}
// namespace phi
PD_REGISTER_INFER_META_FN
(
copy_to
,
phi
::
CopyToInferMeta
);
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
ad037caa
...
...
@@ -190,4 +190,12 @@ void EighInferMeta(const MetaTensor& x,
void
WhereIndexInferMeta
(
const
MetaTensor
&
condition
,
MetaTensor
*
out
);
void
ShardIndexInferMeta
(
const
MetaTensor
&
in
,
int
index_num
,
int
nshards
,
int
shard_id
,
int
ignore_value
,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
}
// namespace phi
paddle/phi/kernels/cpu/shard_index_kernel.cc
0 → 100644
浏览文件 @
ad037caa
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/shard_index_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
ShardIndexKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
in
,
int
index_num
,
int
nshards
,
int
shard_id
,
int
ignore_value
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_GT
(
index_num
,
0
,
errors
::
InvalidArgument
(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d."
,
index_num
));
PADDLE_ENFORCE_GT
(
nshards
,
0
,
errors
::
InvalidArgument
(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d."
,
nshards
));
PADDLE_ENFORCE_GE
(
shard_id
,
0
,
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d."
,
shard_id
));
PADDLE_ENFORCE_LT
(
shard_id
,
nshards
,
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d."
,
nshards
,
shard_id
));
int
shard_size
=
(
index_num
+
nshards
-
1
)
/
nshards
;
out
->
Resize
(
in
.
dims
());
out
->
set_lod
(
in
.
lod
());
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
int64_t
numel
=
in
.
numel
();
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
PADDLE_ENFORCE_GE
(
in_data
[
i
],
0
,
errors
::
InvalidArgument
(
"The input_index for Op(shard_index) must be "
"greater or equal to 0, but the value given is %d."
,
in_data
[
i
]));
PADDLE_ENFORCE_LT
(
in_data
[
i
],
index_num
,
errors
::
InvalidArgument
(
"The input_index for Op(shard_index) must be less "
"than index_num (%d), but the value given is %d."
,
index_num
,
in_data
[
i
]));
if
(
in_data
[
i
]
/
shard_size
==
shard_id
)
{
out_data
[
i
]
=
in_data
[
i
]
%
shard_size
;
}
else
{
out_data
[
i
]
=
ignore_value
;
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
shard_index
,
CPU
,
ALL_LAYOUT
,
phi
::
ShardIndexKernel
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/shard_index_kernel.cu
0 → 100644
浏览文件 @
ad037caa
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/shard_index_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
>
__global__
void
ShardIndexInner
(
const
T
*
in_data
,
T
*
out_data
,
const
int64_t
numel
,
const
int
index_num
,
const
int
nshards
,
const
int
shard_id
,
const
int
ignore_value
)
{
int
shard_size
=
(
index_num
+
nshards
-
1
)
/
nshards
;
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
numel
)
{
assert
(
in_data
[
idx
]
>=
0
&&
in_data
[
idx
]
<
index_num
);
if
(
in_data
[
idx
]
/
shard_size
==
shard_id
)
{
out_data
[
idx
]
=
in_data
[
idx
]
%
shard_size
;
}
else
{
out_data
[
idx
]
=
ignore_value
;
}
}
}
template
<
typename
T
,
typename
Context
>
void
ShardIndexKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
in
,
int
index_num
,
int
nshards
,
int
shard_id
,
int
ignore_value
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_GT
(
index_num
,
0
,
phi
::
errors
::
InvalidArgument
(
"The value 'index_num' for Op(shard_index) must be greater than 0, "
"but the value given is %d."
,
index_num
));
PADDLE_ENFORCE_GT
(
nshards
,
0
,
phi
::
errors
::
InvalidArgument
(
"The value 'nshard' for Op(shard_index) must be "
"greater than 0, but the value given is %d."
,
nshards
));
PADDLE_ENFORCE_GE
(
shard_id
,
0
,
phi
::
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be greater or "
"equal to 0, but the value given is %d."
,
shard_id
));
PADDLE_ENFORCE_LT
(
shard_id
,
nshards
,
phi
::
errors
::
InvalidArgument
(
"The value 'shard_id' for Op(shard_index) must be less than "
"nshards (%d), but the value given is %d."
,
nshards
,
shard_id
));
out
->
Resize
(
in
.
dims
());
out
->
set_lod
(
in
.
lod
());
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
int64_t
numel
=
in
.
numel
();
auto
stream
=
dev_ctx
.
stream
();
ShardIndexInner
<
T
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
numel
,
index_num
,
nshards
,
shard_id
,
ignore_value
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
shard_index
,
GPU
,
ALL_LAYOUT
,
phi
::
ShardIndexKernel
,
int
,
int64_t
)
{}
paddle/phi/kernels/shard_index_kernel.h
0 → 100644
浏览文件 @
ad037caa
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
ShardIndexKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
in
,
int
index_num
,
int
nshards
,
int
shard_id
,
int
ignore_value
,
DenseTensor
*
out
);
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录