Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2f34fc7a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2f34fc7a
编写于
11月 17, 2022
作者:
H
huangjiyi
提交者:
GitHub
11月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm "paddle/fluid/framework/convert_utils.h" in phi (#48001)
上级
f3650201
变更
20
显示空白变更内容
内联
并排
Showing
20 changed file
with
138 addition
and
161 deletion
+138
-161
paddle/fluid/framework/convert_utils.cc
paddle/fluid/framework/convert_utils.cc
+0
-35
paddle/fluid/framework/convert_utils.h
paddle/fluid/framework/convert_utils.h
+3
-1
paddle/fluid/operators/prune_gate_by_capacity_op.cu
paddle/fluid/operators/prune_gate_by_capacity_op.cu
+1
-1
paddle/phi/core/utils/data_type.h
paddle/phi/core/utils/data_type.h
+45
-0
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+5
-9
paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
+9
-13
paddle/phi/kernels/cpu/index_sample_kernel.cc
paddle/phi/kernels/cpu/index_sample_kernel.cc
+9
-13
paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
+5
-6
paddle/phi/kernels/cpu/put_along_axis_kernel.cc
paddle/phi/kernels/cpu/put_along_axis_kernel.cc
+8
-9
paddle/phi/kernels/cpu/take_along_axis_kernel.cc
paddle/phi/kernels/cpu/take_along_axis_kernel.cc
+4
-5
paddle/phi/kernels/funcs/math_function.h
paddle/phi/kernels/funcs/math_function.h
+0
-1
paddle/phi/kernels/funcs/unique_functor.h
paddle/phi/kernels/funcs/unique_functor.h
+9
-13
paddle/phi/kernels/gpu/fill_diagonal_kernel.cu
paddle/phi/kernels/gpu/fill_diagonal_kernel.cu
+0
-1
paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
+9
-13
paddle/phi/kernels/gpu/index_sample_kernel.cu
paddle/phi/kernels/gpu/index_sample_kernel.cu
+9
-13
paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
+5
-6
paddle/phi/kernels/gpu/put_along_axis_kernel.cu
paddle/phi/kernels/gpu/put_along_axis_kernel.cu
+8
-9
paddle/phi/kernels/gpu/sync_batch_norm_utils.h
paddle/phi/kernels/gpu/sync_batch_norm_utils.h
+1
-3
paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
+4
-5
paddle/phi/kernels/gpu/take_along_axis_kernel.cu
paddle/phi/kernels/gpu/take_along_axis_kernel.cu
+4
-5
未找到文件。
paddle/fluid/framework/convert_utils.cc
浏览文件 @
2f34fc7a
...
@@ -162,40 +162,5 @@ DataType String2DataType(const std::string& str) {
...
@@ -162,40 +162,5 @@ DataType String2DataType(const std::string& str) {
}
}
}
}
std
::
string
DataType2String
(
DataType
dtype
)
{
switch
(
dtype
)
{
case
DataType
::
BOOL
:
return
"bool"
;
case
DataType
::
INT8
:
return
"int8"
;
case
DataType
::
UINT8
:
return
"uint8"
;
case
DataType
::
INT16
:
return
"int16"
;
case
DataType
::
INT32
:
return
"int32"
;
case
DataType
::
INT64
:
return
"int64"
;
case
DataType
::
FLOAT16
:
return
"float16"
;
case
DataType
::
FLOAT32
:
return
"float32"
;
case
DataType
::
FLOAT64
:
return
"float64"
;
case
DataType
::
COMPLEX64
:
return
"complex64"
;
case
DataType
::
COMPLEX128
:
return
"complex128"
;
case
DataType
::
PSTRING
:
return
"pstring"
;
case
DataType
::
BFLOAT16
:
return
"bfloat16"
;
default:
PADDLE_THROW
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"Unknow phi::DataType, the int value = %d."
,
static_cast
<
int
>
(
dtype
)));
return
""
;
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/convert_utils.h
浏览文件 @
2f34fc7a
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/utils/data_type.h"
// TODO(chenweihang): this file may need to be removed
// TODO(chenweihang): this file may need to be removed
...
@@ -37,7 +38,8 @@ paddle::framework::proto::VarType::Type TransToProtoVarType(
...
@@ -37,7 +38,8 @@ paddle::framework::proto::VarType::Type TransToProtoVarType(
size_t
DataTypeSize
(
DataType
dtype
);
size_t
DataTypeSize
(
DataType
dtype
);
DataType
String2DataType
(
const
std
::
string
&
str
);
DataType
String2DataType
(
const
std
::
string
&
str
);
std
::
string
DataType2String
(
DataType
dtype
);
using
phi
::
DataType2String
;
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/prune_gate_by_capacity_op.cu
浏览文件 @
2f34fc7a
...
@@ -121,7 +121,7 @@ class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
...
@@ -121,7 +121,7 @@ class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
framework
::
TensorCopy
(
*
expert_count
,
context
.
GetPlace
(),
&
expert_count_out
);
framework
::
TensorCopy
(
*
expert_count
,
context
.
GetPlace
(),
&
expert_count_out
);
PruneGateByCapacityFunctor
<
DeviceContext
,
T
>
functor
(
PruneGateByCapacityFunctor
<
DeviceContext
,
T
>
functor
(
context
,
gate_idx
,
&
expert_count_out
,
new_gate_idx_data
);
context
,
gate_idx
,
&
expert_count_out
,
new_gate_idx_data
);
VisitDataType
(
expert_count
->
type
(),
functor
);
::
paddle
::
operators
::
VisitDataType
(
expert_count
->
type
(),
functor
);
}
}
};
};
...
...
paddle/phi/core/utils/data_type.h
浏览文件 @
2f34fc7a
...
@@ -41,6 +41,14 @@ static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
...
@@ -41,6 +41,14 @@ static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
{
6
,
phi
::
DataType
::
FLOAT64
},
{
6
,
phi
::
DataType
::
FLOAT64
},
{
20
,
phi
::
DataType
::
UINT8
}};
{
20
,
phi
::
DataType
::
UINT8
}};
static
std
::
map
<
phi
::
DataType
,
int
>
map_to_var_type
{{
phi
::
DataType
::
INT16
,
1
},
{
phi
::
DataType
::
INT32
,
2
},
{
phi
::
DataType
::
INT64
,
3
},
{
phi
::
DataType
::
FLOAT16
,
4
},
{
phi
::
DataType
::
FLOAT32
,
5
},
{
phi
::
DataType
::
FLOAT64
,
6
},
{
phi
::
DataType
::
UINT8
,
20
}};
#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
callback(cpp_type, data_type);
callback(cpp_type, data_type);
...
@@ -129,4 +137,41 @@ inline DataType ToRealType(const DataType& type) {
...
@@ -129,4 +137,41 @@ inline DataType ToRealType(const DataType& type) {
type
));
type
));
}
}
}
}
inline
std
::
string
DataType2String
(
DataType
dtype
)
{
switch
(
dtype
)
{
case
DataType
::
BOOL
:
return
"bool"
;
case
DataType
::
INT8
:
return
"int8"
;
case
DataType
::
UINT8
:
return
"uint8"
;
case
DataType
::
INT16
:
return
"int16"
;
case
DataType
::
INT32
:
return
"int32"
;
case
DataType
::
INT64
:
return
"int64"
;
case
DataType
::
FLOAT16
:
return
"float16"
;
case
DataType
::
FLOAT32
:
return
"float32"
;
case
DataType
::
FLOAT64
:
return
"float64"
;
case
DataType
::
COMPLEX64
:
return
"complex64"
;
case
DataType
::
COMPLEX128
:
return
"complex128"
;
case
DataType
::
PSTRING
:
return
"pstring"
;
case
DataType
::
BFLOAT16
:
return
"bfloat16"
;
default:
PADDLE_THROW
(
errors
::
InvalidArgument
(
"Unknow phi::DataType, the int value = %d."
,
static_cast
<
int
>
(
dtype
)));
return
""
;
}
}
}
// namespace phi
}
// namespace phi
paddle/phi/infermeta/unary.cc
浏览文件 @
2f34fc7a
...
@@ -17,11 +17,11 @@ limitations under the License. */
...
@@ -17,11 +17,11 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <set>
#include <set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
...
@@ -133,12 +133,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
...
@@ -133,12 +133,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
phi
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]"
,
"received [%s]"
,
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
DataType
::
INT32
),
paddle
::
framework
::
proto
::
VarType
::
INT32
),
phi
::
DataType2String
(
DataType
::
INT64
),
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
var_type_map
[
dtype
])));
paddle
::
framework
::
proto
::
VarType
::
INT64
),
paddle
::
framework
::
DataTypeToString
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
))));
if
(
!
config
.
is_runtime
&&
axis
.
FromTensor
())
{
if
(
!
config
.
is_runtime
&&
axis
.
FromTensor
())
{
std
::
vector
<
int64_t
>
vec
;
std
::
vector
<
int64_t
>
vec
;
...
@@ -180,11 +177,10 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
...
@@ -180,11 +177,10 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
if
(
int_axis
<
0
)
int_axis
+=
x_rank
;
if
(
int_axis
<
0
)
int_axis
+=
x_rank
;
if
(
config
.
is_runtime
)
{
if
(
config
.
is_runtime
)
{
if
(
dtype
==
paddle
::
framework
::
proto
::
VarType
::
INT32
)
{
if
(
dtype
==
map_to_var_type
[
DataType
::
INT32
]
)
{
int64_t
all_element_num
=
0
;
int64_t
all_element_num
=
0
;
if
(
flatten
)
{
if
(
flatten
)
{
all_element_num
=
phi
::
product
(
x_dims
);
all_element_num
=
phi
::
product
(
x_dims
);
}
else
{
}
else
{
all_element_num
=
x_dims
[
int_axis
];
all_element_num
=
x_dims
[
int_axis
];
}
}
...
...
paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
浏览文件 @
2f34fc7a
...
@@ -14,11 +14,11 @@
...
@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/index_sample_grad_kernel.h"
#include "paddle/phi/kernels/index_sample_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
,
typename
IndexT
=
int
>
template
<
typename
T
,
typename
Context
,
typename
IndexT
=
int
>
void
IndexSampleGradInner
(
const
Context
&
context
,
void
IndexSampleGradInner
(
const
Context
&
context
,
...
@@ -76,18 +76,14 @@ void IndexSampleGradKernel(const Context& ctx,
...
@@ -76,18 +76,14 @@ void IndexSampleGradKernel(const Context& ctx,
auto
index_type
=
index
.
dtype
();
auto
index_type
=
index
.
dtype
();
bool
index_type_match
=
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
index_type_match
,
true
,
true
,
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
index_type
),
paddle
::
framework
::
TransToProtoVarType
(
index_type
)),
phi
::
DataType2String
(
DataType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
DataType
::
INT64
)));
paddle
::
framework
::
TransToProtoVarType
(
DataType
::
INT32
)),
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
TransToProtoVarType
((
DataType
::
INT64
)))));
if
(
index_type
==
DataType
::
INT32
)
{
if
(
index_type
==
DataType
::
INT32
)
{
IndexSampleGradInner
<
T
,
Context
,
int
>
(
ctx
,
out_grad
,
index
,
x_grad
);
IndexSampleGradInner
<
T
,
Context
,
int
>
(
ctx
,
out_grad
,
index
,
x_grad
);
}
else
if
(
index_type
==
DataType
::
INT64
)
{
}
else
if
(
index_type
==
DataType
::
INT64
)
{
...
...
paddle/phi/kernels/cpu/index_sample_kernel.cc
浏览文件 @
2f34fc7a
...
@@ -21,11 +21,11 @@
...
@@ -21,11 +21,11 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
,
typename
IndexT
=
int
>
template
<
typename
T
,
typename
Context
,
typename
IndexT
=
int
>
void
IndexSampleInner
(
const
Context
&
context
,
void
IndexSampleInner
(
const
Context
&
context
,
...
@@ -89,18 +89,14 @@ void IndexSampleKernel(const Context &ctx,
...
@@ -89,18 +89,14 @@ void IndexSampleKernel(const Context &ctx,
auto
index_type
=
index
.
dtype
();
auto
index_type
=
index
.
dtype
();
bool
index_type_match
=
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
index_type_match
,
true
,
true
,
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
index_type
),
paddle
::
framework
::
TransToProtoVarType
(
index_type
)),
phi
::
DataType2String
(
DataType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
DataType
::
INT64
)));
paddle
::
framework
::
TransToProtoVarType
(
DataType
::
INT32
)),
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
TransToProtoVarType
((
DataType
::
INT64
)))));
if
(
index_type
==
DataType
::
INT32
)
{
if
(
index_type
==
DataType
::
INT32
)
{
IndexSampleInner
<
T
,
Context
,
int
>
(
ctx
,
x
,
index
,
out
);
IndexSampleInner
<
T
,
Context
,
int
>
(
ctx
,
x
,
index
,
out
);
}
else
if
(
index_type
==
DataType
::
INT64
)
{
}
else
if
(
index_type
==
DataType
::
INT64
)
{
...
...
paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
浏览文件 @
2f34fc7a
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
...
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
...
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
true
,
true
,
errors
::
PreconditionNotMet
(
"PutAlongAxisGradOpKernel only runs on CPU."
));
errors
::
PreconditionNotMet
(
"PutAlongAxisGradOpKernel only runs on CPU."
));
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
if
(
x_grad
)
{
if
(
x_grad
)
{
phi
::
Copy
(
dev_ctx
,
out_grad
,
dev_ctx
.
GetPlace
(),
false
,
x_grad
);
phi
::
Copy
(
dev_ctx
,
out_grad
,
dev_ctx
.
GetPlace
(),
false
,
x_grad
);
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
cpu_scatter_input_grad_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
cpu_scatter_input_grad_kernel
<
T
,
int32_t
>
(
// Here passing an unused argument out_grad, because it's
// Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the
// convenient to instantiate a bunch of template function with the
...
@@ -60,10 +59,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
...
@@ -60,10 +59,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if
(
value_grad
)
{
if
(
value_grad
)
{
value_grad
->
Resize
(
index
.
dims
());
value_grad
->
Resize
(
index
.
dims
());
value_grad
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
value_grad
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int32_t
>
(
out_grad
,
axis
,
index
,
*
value_grad
,
dev_ctx
);
out_grad
,
axis
,
index
,
*
value_grad
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int64_t
>
(
out_grad
,
axis
,
index
,
*
value_grad
,
dev_ctx
);
out_grad
,
axis
,
index
,
*
value_grad
,
dev_ctx
);
}
}
...
...
paddle/phi/kernels/cpu/put_along_axis_kernel.cc
浏览文件 @
2f34fc7a
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
...
@@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
...
@@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
errors
::
PreconditionNotMet
(
"PutAlongAxisOpKernel only runs on CPU."
));
errors
::
PreconditionNotMet
(
"PutAlongAxisOpKernel only runs on CPU."
));
phi
::
Copy
(
dev_ctx
,
x
,
dev_ctx
.
GetPlace
(),
false
,
out
);
phi
::
Copy
(
dev_ctx
,
x
,
dev_ctx
.
GetPlace
(),
false
,
out
);
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
if
(
reduce
==
"add"
)
{
if
(
reduce
==
"add"
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
cpu_scatter_add_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
cpu_scatter_add_kernel
<
T
,
int32_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
cpu_scatter_add_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
cpu_scatter_add_kernel
<
T
,
int64_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
}
}
else
if
(
reduce
==
"multiply"
||
reduce
==
"mul"
)
{
}
else
if
(
reduce
==
"multiply"
||
reduce
==
"mul"
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
cpu_scatter_mul_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
cpu_scatter_mul_kernel
<
T
,
int32_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
cpu_scatter_mul_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
cpu_scatter_mul_kernel
<
T
,
int64_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
}
}
else
if
(
reduce
==
"assign"
)
{
}
else
if
(
reduce
==
"assign"
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
cpu_scatter_assign_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
cpu_scatter_assign_kernel
<
T
,
int32_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
cpu_scatter_assign_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
cpu_scatter_assign_kernel
<
T
,
int64_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
}
...
...
paddle/phi/kernels/cpu/take_along_axis_kernel.cc
浏览文件 @
2f34fc7a
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
...
@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
...
@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
out
->
Resize
(
index
.
dims
());
out
->
Resize
(
index
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
out
);
dev_ctx
.
template
Alloc
<
T
>(
out
);
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
if
(
index_type
==
DataType
::
INT32
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
VarType
::
INT32
)
{
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int32_t
>
(
x
,
axis
,
index
,
*
out
,
dev_ctx
);
x
,
axis
,
index
,
*
out
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
cpu_gather_kernel
<
T
,
int64_t
>
(
x
,
axis
,
index
,
*
out
,
dev_ctx
);
x
,
axis
,
index
,
*
out
,
dev_ctx
);
}
}
...
...
paddle/phi/kernels/funcs/math_function.h
浏览文件 @
2f34fc7a
...
@@ -17,7 +17,6 @@ limitations under the License. */
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
...
...
paddle/phi/kernels/funcs/unique_functor.h
浏览文件 @
2f34fc7a
...
@@ -13,8 +13,8 @@
...
@@ -13,8 +13,8 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
@@ -77,18 +77,14 @@ struct UniqueOpFunctor {
...
@@ -77,18 +77,14 @@ struct UniqueOpFunctor {
const
auto
&
index_type
=
index_
->
dtype
();
const
auto
&
index_type
=
index_
->
dtype
();
bool
index_type_match
=
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
index_type_match
,
true
,
true
,
phi
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, "
"Index holds the wrong type, it holds %s, "
"but desires to be %s or %s"
,
"but desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
index_type
),
paddle
::
framework
::
TransToProtoVarType
(
index_type
)),
phi
::
DataType2String
(
DataType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
DataType
::
INT64
)));
paddle
::
framework
::
TransToProtoVarType
(
DataType
::
INT32
)),
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
TransToProtoVarType
(
DataType
::
INT64
))));
if
(
index_type
==
DataType
::
INT32
)
{
if
(
index_type
==
DataType
::
INT32
)
{
for
(
auto
i
=
0
;
i
<
in_
->
numel
();
++
i
)
{
for
(
auto
i
=
0
;
i
<
in_
->
numel
();
++
i
)
{
...
...
paddle/phi/kernels/gpu/fill_diagonal_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -17,7 +17,6 @@
...
@@ -17,7 +17,6 @@
#include <algorithm>
#include <algorithm>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
...
...
paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
#include <algorithm>
#include <algorithm>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
namespace
phi
{
...
@@ -70,18 +70,14 @@ void IndexSampleGradKernel(const Context& ctx,
...
@@ -70,18 +70,14 @@ void IndexSampleGradKernel(const Context& ctx,
auto
index_type
=
index
.
dtype
();
auto
index_type
=
index
.
dtype
();
bool
index_type_match
=
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
index_type_match
,
true
,
true
,
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
index_type
),
paddle
::
framework
::
TransToProtoVarType
(
index_type
)),
phi
::
DataType2String
(
DataType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
DataType
::
INT64
)));
paddle
::
framework
::
TransToProtoVarType
(
DataType
::
INT32
)),
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
TransToProtoVarType
((
DataType
::
INT64
)))));
auto
stream
=
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
ctx
).
stream
();
auto
stream
=
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
ctx
).
stream
();
auto
input_num
=
x
.
numel
();
auto
input_num
=
x
.
numel
();
...
...
paddle/phi/kernels/gpu/index_sample_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -17,10 +17,10 @@
...
@@ -17,10 +17,10 @@
#include <algorithm>
#include <algorithm>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
namespace
phi
{
...
@@ -59,18 +59,14 @@ void IndexSampleKernel(const Context& ctx,
...
@@ -59,18 +59,14 @@ void IndexSampleKernel(const Context& ctx,
auto
index_type
=
index
.
dtype
();
auto
index_type
=
index
.
dtype
();
bool
index_type_match
=
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
index_type_match
,
true
,
true
,
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
index_type
),
paddle
::
framework
::
TransToProtoVarType
(
index_type
)),
phi
::
DataType2String
(
DataType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
phi
::
DataType2String
(
DataType
::
INT64
)));
paddle
::
framework
::
TransToProtoVarType
(
DataType
::
INT32
)),
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
TransToProtoVarType
((
DataType
::
INT64
)))));
const
T
*
in_data
=
x
.
data
<
T
>
();
const
T
*
in_data
=
x
.
data
<
T
>
();
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
auto
stream
=
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
ctx
).
stream
();
auto
stream
=
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
ctx
).
stream
();
...
...
paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
namespace
phi
{
...
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
...
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
errors
::
PreconditionNotMet
(
errors
::
PreconditionNotMet
(
"PutAlongAxisGradOpCUDAKernel only runs on GPU."
));
"PutAlongAxisGradOpCUDAKernel only runs on GPU."
));
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
if
(
x_grad
)
{
if
(
x_grad
)
{
phi
::
Copy
(
dev_ctx
,
out_grad
,
dev_ctx
.
GetPlace
(),
false
,
x_grad
);
phi
::
Copy
(
dev_ctx
,
out_grad
,
dev_ctx
.
GetPlace
(),
false
,
x_grad
);
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
gpu_scatter_input_grad_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_scatter_input_grad_kernel
<
T
,
int32_t
>
(
out_grad
,
axis
,
index
,
*
x_grad
,
dev_ctx
);
out_grad
,
axis
,
index
,
*
x_grad
,
dev_ctx
);
}
else
{
}
else
{
...
@@ -52,14 +51,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
...
@@ -52,14 +51,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if
(
value_grad
)
{
if
(
value_grad
)
{
value_grad
->
Resize
(
index
.
dims
());
value_grad
->
Resize
(
index
.
dims
());
value_grad
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
value_grad
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int32_t
>
(
out_grad
,
out_grad
,
axis
,
axis
,
index
,
index
,
*
value_grad
,
*
value_grad
,
dev_ctx
);
// the gradient of scatter is gather
dev_ctx
);
// the gradient of scatter is gather
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int64_t
>
(
out_grad
,
axis
,
index
,
*
value_grad
,
dev_ctx
);
out_grad
,
axis
,
index
,
*
value_grad
,
dev_ctx
);
}
}
...
...
paddle/phi/kernels/gpu/put_along_axis_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
namespace
phi
{
...
@@ -36,31 +36,30 @@ void PutAlongAxisKernel(const Context& dev_ctx,
...
@@ -36,31 +36,30 @@ void PutAlongAxisKernel(const Context& dev_ctx,
errors
::
PreconditionNotMet
(
errors
::
PreconditionNotMet
(
"PutAlongAxisCUDAKernel only runs on GPU device."
));
"PutAlongAxisCUDAKernel only runs on GPU device."
));
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
phi
::
Copy
(
dev_ctx
,
x
,
dev_ctx
.
GetPlace
(),
false
,
out
);
phi
::
Copy
(
dev_ctx
,
x
,
dev_ctx
.
GetPlace
(),
false
,
out
);
if
(
reduce
==
"add"
)
{
if
(
reduce
==
"add"
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int32_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int64_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
}
}
else
if
(
reduce
==
"multiply"
||
reduce
==
"mul"
)
{
}
else
if
(
reduce
==
"multiply"
||
reduce
==
"mul"
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
gpu_scatter_mul_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_scatter_mul_kernel
<
T
,
int32_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
gpu_scatter_mul_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
gpu_scatter_mul_kernel
<
T
,
int64_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
}
}
else
if
(
reduce
==
"assign"
)
{
}
else
if
(
reduce
==
"assign"
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
gpu_scatter_assign_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_scatter_assign_kernel
<
T
,
int32_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
gpu_scatter_assign_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
gpu_scatter_assign_kernel
<
T
,
int64_t
>
(
*
out
,
axis
,
index
,
value
,
dev_ctx
);
*
out
,
axis
,
index
,
value
,
dev_ctx
);
}
}
...
...
paddle/phi/kernels/gpu/sync_batch_norm_utils.h
浏览文件 @
2f34fc7a
...
@@ -30,7 +30,6 @@ namespace cub = hipcub;
...
@@ -30,7 +30,6 @@ namespace cub = hipcub;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#endif
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
...
@@ -431,8 +430,7 @@ void SyncBatchNormGradFunctor(
...
@@ -431,8 +430,7 @@ void SyncBatchNormGradFunctor(
}
}
if
(
comm
)
{
if
(
comm
)
{
int
dtype
=
paddle
::
platform
::
ToNCCLDataType
(
int
dtype
=
paddle
::
platform
::
ToNCCLDataType
(
scale
.
dtype
());
paddle
::
framework
::
TransToProtoVarType
(
scale
.
dtype
()));
// In-place operation
// In-place operation
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
ncclAllReduce
(
stats
,
phi
::
dynload
::
ncclAllReduce
(
stats
,
...
...
paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -14,11 +14,11 @@
...
@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
namespace
phi
{
...
@@ -43,17 +43,16 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
...
@@ -43,17 +43,16 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
// Set to zero tensor.
// Set to zero tensor.
phi
::
funcs
::
SetConstant
<
Context
,
T
>
functor
;
phi
::
funcs
::
SetConstant
<
Context
,
T
>
functor
;
functor
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
functor
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT32
)
{
if
(
index_type
==
Data
Type
::
INT32
)
{
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int32_t
>
(
*
x_grad
,
*
x_grad
,
axis
,
axis
,
index
,
index
,
out_grad
,
out_grad
,
dev_ctx
);
// the gradient of gather is scatter
dev_ctx
);
// the gradient of gather is scatter
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
gpu_scatter_add_kernel
<
T
,
int64_t
>
(
*
x_grad
,
axis
,
index
,
out_grad
,
dev_ctx
);
*
x_grad
,
axis
,
index
,
out_grad
,
dev_ctx
);
}
}
...
...
paddle/phi/kernels/gpu/take_along_axis_kernel.cu
浏览文件 @
2f34fc7a
...
@@ -14,11 +14,11 @@
...
@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
namespace
phi
{
...
@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
...
@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
out
->
Resize
(
index
.
dims
());
out
->
Resize
(
index
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
out
);
dev_ctx
.
template
Alloc
<
T
>(
out
);
const
auto
&
index_type
=
const
auto
&
index_type
=
index
.
dtype
();
paddle
::
framework
::
TransToProtoVarType
(
index
.
dtype
());
if
(
index_type
==
DataType
::
INT32
)
{
if
(
index_type
==
paddle
::
framework
::
proto
::
VarType
::
INT32
)
{
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int32_t
>
(
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int32_t
>
(
x
,
axis
,
index
,
*
out
,
dev_ctx
);
x
,
axis
,
index
,
*
out
,
dev_ctx
);
}
else
if
(
index_type
==
paddle
::
framework
::
proto
::
Var
Type
::
INT64
)
{
}
else
if
(
index_type
==
Data
Type
::
INT64
)
{
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int64_t
>
(
paddle
::
operators
::
gpu_gather_kernel
<
T
,
int64_t
>
(
x
,
axis
,
index
,
*
out
,
dev_ctx
);
x
,
axis
,
index
,
*
out
,
dev_ctx
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录