Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8b3bf28c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b3bf28c
编写于
9月 21, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine reduce_op and follow comments
上级
630273d4
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
103 addition
and
114 deletion
+103
-114
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+7
-0
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+65
-82
paddle/operators/reduce_op.h
paddle/operators/reduce_op.h
+31
-32
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
8b3bf28c
...
@@ -62,6 +62,13 @@ function(op_library TARGET)
...
@@ -62,6 +62,13 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(sigmoid);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(sigmoid);
\n
"
)
endif
()
endif
()
# reduce_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"reduce_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(reduce_sum);
\n
"
)
endif
()
# pybind USE_NO_KERNEL_OP
# pybind USE_NO_KERNEL_OP
file
(
READ
${
TARGET
}
.cc TARGET_CONTENT
)
file
(
READ
${
TARGET
}
.cc TARGET_CONTENT
)
string
(
REGEX MATCH
"OperatorWithKernel"
regex_result
"
${
TARGET_CONTENT
}
"
)
string
(
REGEX MATCH
"OperatorWithKernel"
regex_result
"
${
TARGET_CONTENT
}
"
)
...
...
paddle/operators/reduce_op.cc
浏览文件 @
8b3bf28c
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
using
framework
::
Tensor
;
using
framework
::
Tensor
;
using
framework
::
DDim
;
using
framework
::
LoDTensor
;
class
ReduceOp
:
public
framework
::
OperatorWithKernel
{
class
ReduceOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
...
@@ -26,18 +26,19 @@ class ReduceOp : public framework::OperatorWithKernel {
...
@@ -26,18 +26,19 @@ class ReduceOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Out"
),
"Output(Out) of ReduceOp should not be null."
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
PADDLE_ENFORCE_LE
(
x_rank
,
6
,
"Tensors with rank at most 6 are supported"
);
PADDLE_ENFORCE_LE
(
x_rank
,
6
,
"Tensors with rank at most 6 are supported
.
"
);
int
dim
=
ctx
.
Attr
<
int
>
(
"dim"
);
int
dim
=
ctx
.
Attr
<
int
>
(
"dim"
);
if
(
dim
<
0
)
dim
=
x_rank
+
dim
;
if
(
dim
<
0
)
dim
=
x_rank
+
dim
;
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
dim
,
x_rank
,
dim
,
x_rank
,
"The dim should be in the range [-rank(input), rank(input))"
);
"The dim should be in the range [-rank(input), rank(input))."
);
PADDLE_ENFORCE_GE
(
ctx
.
Attr
<
int
>
(
"keep_dim"
),
0
,
"keep_dim must be 0 or 1"
);
bool
keep_dim
=
ctx
.
Attr
<
bool
>
(
"keep_dim"
);
PADDLE_ENFORCE_LE
(
ctx
.
Attr
<
int
>
(
"keep_dim"
),
1
,
"keep_dim must be 0 or 1"
);
bool
keep_dim
=
ctx
.
Attr
<
int
>
(
"keep_dim"
)
==
1
;
auto
dims_vector
=
vectorize
(
x_dims
);
auto
dims_vector
=
vectorize
(
x_dims
);
if
(
keep_dim
||
x_rank
==
1
)
{
if
(
keep_dim
||
x_rank
==
1
)
{
dims_vector
[
dim
]
=
1
;
dims_vector
[
dim
]
=
1
;
...
@@ -45,7 +46,7 @@ class ReduceOp : public framework::OperatorWithKernel {
...
@@ -45,7 +46,7 @@ class ReduceOp : public framework::OperatorWithKernel {
dims_vector
.
erase
(
dims_vector
.
begin
()
+
dim
);
dims_vector
.
erase
(
dims_vector
.
begin
()
+
dim
);
}
}
auto
out_dims
=
framework
::
make_ddim
(
dims_vector
);
auto
out_dims
=
framework
::
make_ddim
(
dims_vector
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
out_dims
);
ctx
.
Output
<
framework
::
LoD
Tensor
>
(
"Out"
)
->
Resize
(
out_dims
);
}
}
};
};
...
@@ -55,119 +56,101 @@ class ReduceGradOp : public framework::OperatorWithKernel {
...
@@ -55,119 +56,101 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null
.
"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null
.
"
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
PADDLE_ENFORCE_LE
(
x_rank
,
6
,
"Tensors with rank at most 6 are supported"
);
PADDLE_ENFORCE_LE
(
x_rank
,
6
,
"Tensors with rank at most 6 are supported
.
"
);
int
dim
=
ctx
.
Attr
<
int
>
(
"dim"
);
int
dim
=
ctx
.
Attr
<
int
>
(
"dim"
);
if
(
dim
<
0
)
dim
=
x_rank
+
dim
;
if
(
dim
<
0
)
dim
=
x_rank
+
dim
;
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
dim
,
x_rank
,
dim
,
x_rank
,
"The dim should be in the range [-rank(input), rank(input))"
);
"The dim should be in the range [-rank(input), rank(input))."
);
auto
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
x_grad
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
x_grad
)
x_grad
->
Resize
(
x_dims
);
if
(
x_grad
)
x_grad
->
Resize
(
x_dims
);
}
}
};
};
class
Reduce
Sum
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
ReduceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
ReduceSumOpMaker
(
framework
::
OpProto
*
proto
,
ReduceOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
AddInput
(
"X"
,
"X"
,
"(Tensor) The input tensor. Tensors with rank at most 6 are supported"
);
"(Tensor) The input tensor. Tensors with rank at most 6 are supported"
);
AddOutput
(
"Out"
,
"(Tensor) The result tensor."
);
AddOutput
(
"Out"
,
"(Tensor) The result tensor."
);
AddComment
(
R"DOC(
ReduceMean operator computes the sum of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC"
);
AddAttr
<
int
>
(
"dim"
,
AddAttr
<
int
>
(
"dim"
,
"(int, default 0) The dimension to reduce. "
"(int, default 0) The dimension to reduce. "
"Must be in the range [-rank(input), rank(input))"
)
"Must be in the range [-rank(input), rank(input))"
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
int
>
(
AddAttr
<
bool
>
(
"keep_dim"
,
"keep_dim"
,
"(bool, default false) "
"(int, default 0) "
"If true, retain the reduced dimension with length 1."
)
"Must be 0 or 1. If 1, retain the reduced dimension with length 1."
)
.
SetDefault
(
false
);
.
SetDefault
(
0
);
comment_
=
R"DOC(
{ReduceOP} operator computes the {reduce} of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC"
;
AddComment
(
comment_
);
}
protected:
std
::
string
comment_
;
void
Replace
(
std
::
string
&
src
,
std
::
string
from
,
std
::
string
to
)
{
std
::
size_t
len_from
=
std
::
strlen
(
from
.
c_str
());
std
::
size_t
len_to
=
std
::
strlen
(
to
.
c_str
());
for
(
std
::
size_t
pos
=
src
.
find
(
from
);
pos
!=
std
::
string
::
npos
;
pos
=
src
.
find
(
from
,
pos
+
len_to
))
{
src
.
replace
(
pos
,
len_from
,
to
);
}
}
void
SetComment
(
std
::
string
name
,
std
::
string
op
)
{
Replace
(
comment_
,
"{ReduceOP}"
,
name
);
Replace
(
comment_
,
"{reduce}"
,
op
);
}
}
};
};
class
ReduceMeanOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
ReduceSumOpMaker
:
public
ReduceOpMaker
{
public:
ReduceSumOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
ReduceOpMaker
(
proto
,
op_checker
)
{
SetComment
(
"ReduceSum"
,
"sum"
);
AddComment
(
comment_
);
}
};
class
ReduceMeanOpMaker
:
public
ReduceOpMaker
{
public:
public:
ReduceMeanOpMaker
(
framework
::
OpProto
*
proto
,
ReduceMeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
ReduceOpMaker
(
proto
,
op_checker
)
{
AddInput
(
SetComment
(
"ReduceMean"
,
"mean"
);
"X"
,
AddComment
(
comment_
);
"(Tensor) The input tensor. Tensors with rank at most 6 are supported"
);
AddOutput
(
"Out"
,
"(Tensor) The result tensor."
);
AddComment
(
R"DOC(
ReduceMean operator computes the mean of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC"
);
AddAttr
<
int
>
(
"dim"
,
"(int, default 0) The dimension to reduce. "
"Must be in the range [-rank(input), rank(input))"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"keep_dim"
,
"(int, default 0) "
"Must be 0 or 1. If 1, retain the reduced dimension with length 1."
)
.
SetDefault
(
0
);
}
}
};
};
class
ReduceMaxOpMaker
:
public
framework
::
OpProtoAndChecker
Maker
{
class
ReduceMaxOpMaker
:
public
ReduceOp
Maker
{
public:
public:
ReduceMaxOpMaker
(
framework
::
OpProto
*
proto
,
ReduceMaxOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
ReduceOpMaker
(
proto
,
op_checker
)
{
AddInput
(
SetComment
(
"ReduceMax"
,
"max"
);
"X"
,
AddComment
(
comment_
);
"(Tensor) The input tensor. Tensors with rank at most 6 are supported"
);
AddOutput
(
"Out"
,
"(Tensor) The result tensor."
);
AddComment
(
R"DOC(
ReduceMax operator computes the maximum of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC"
);
AddAttr
<
int
>
(
"dim"
,
"(int, default 0) The dimension to reduce. "
"Must be in the range [-rank(input), rank(input))"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"keep_dim"
,
"(int, default 0) "
"Must be 0 or 1. If 1, retain the reduced dimension with length 1."
)
.
SetDefault
(
0
);
}
}
};
};
class
ReduceMinOpMaker
:
public
framework
::
OpProtoAndChecker
Maker
{
class
ReduceMinOpMaker
:
public
ReduceOp
Maker
{
public:
public:
ReduceMinOpMaker
(
framework
::
OpProto
*
proto
,
ReduceMinOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
ReduceOpMaker
(
proto
,
op_checker
)
{
AddInput
(
SetComment
(
"ReduceMin"
,
"min"
);
"X"
,
AddComment
(
comment_
);
"(Tensor) The input tensor. Tensors with rank at most 6 are supported"
);
AddOutput
(
"Out"
,
"(Tensor) The result tensor."
);
AddComment
(
R"DOC(
ReduceMin operator computes the minimum of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless `keep_dim` is true.
)DOC"
);
AddAttr
<
int
>
(
"dim"
,
"(int, default 0) The dimension to reduce. "
"Must be in the range [-rank(input), rank(input))"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"keep_dim"
,
"(int, default 0) "
"Must be 0 or 1. If 1, retain the reduced dimension with length 1."
)
.
SetDefault
(
0
);
}
}
};
};
...
...
paddle/operators/reduce_op.h
浏览文件 @
8b3bf28c
...
@@ -27,61 +27,60 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
...
@@ -27,61 +27,60 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
struct
SumFunctor
{
struct
SumFunctor
{
template
<
typename
Place
,
typename
In
,
typename
Out
,
typename
Dim
>
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In
&
in
,
Out
&
out
,
const
Dim
&
dim
)
{
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
const
Dim
&
dim
)
{
out
.
device
(
place
)
=
in
.
sum
(
dim
);
y
.
device
(
place
)
=
x
.
sum
(
dim
);
}
}
};
};
struct
SumGradFunctor
{
struct
SumGradFunctor
{
template
<
typename
Place
,
typename
In
,
typename
In_Const
,
typename
Out
,
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In_Const
&
in
,
In
&
in_grad
,
Out
&
out
,
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
DX
&
dx
,
DY
&
dy
,
Out
&
out_grad
,
const
Dim
&
dim
,
int
size
)
{
const
Dim
&
dim
,
int
size
)
{
in_grad
.
device
(
place
)
=
out_grad
.
broadcast
(
dim
);
dx
.
device
(
place
)
=
dy
.
broadcast
(
dim
);
}
}
};
};
struct
MeanFunctor
{
struct
MeanFunctor
{
template
<
typename
Place
,
typename
In
,
typename
Out
,
typename
Dim
>
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In
&
in
,
Out
&
out
,
const
Dim
&
dim
)
{
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
const
Dim
&
dim
)
{
out
.
device
(
place
)
=
in
.
mean
(
dim
);
y
.
device
(
place
)
=
x
.
mean
(
dim
);
}
}
};
};
struct
MeanGradFunctor
{
struct
MeanGradFunctor
{
template
<
typename
Place
,
typename
In
,
typename
In_Const
,
typename
Out
,
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In_Const
&
in
,
In
&
in_grad
,
Out
&
out
,
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
DX
&
dx
,
DY
&
dy
,
Out
&
out_grad
,
const
Dim
&
dim
,
int
size
)
{
const
Dim
&
dim
,
int
size
)
{
in_grad
.
device
(
place
)
=
out_grad
.
broadcast
(
dim
)
/
in_grad
.
constant
(
size
);
dx
.
device
(
place
)
=
dy
.
broadcast
(
dim
)
/
dx
.
constant
(
size
);
}
}
};
};
struct
MaxFunctor
{
struct
MaxFunctor
{
template
<
typename
Place
,
typename
In
,
typename
Out
,
typename
Dim
>
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In
&
in
,
Out
&
out
,
const
Dim
&
dim
)
{
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
const
Dim
&
dim
)
{
out
.
device
(
place
)
=
in
.
maximum
(
dim
);
y
.
device
(
place
)
=
x
.
maximum
(
dim
);
}
}
};
};
struct
MinFunctor
{
struct
MinFunctor
{
template
<
typename
Place
,
typename
In
,
typename
Out
,
typename
Dim
>
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In
&
in
,
Out
&
out
,
const
Dim
&
dim
)
{
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
const
Dim
&
dim
)
{
out
.
device
(
place
)
=
in
.
minimum
(
dim
);
y
.
device
(
place
)
=
x
.
minimum
(
dim
);
}
}
};
};
struct
MaxOrMinGradFunctor
{
struct
MaxOrMinGradFunctor
{
template
<
typename
Place
,
typename
In
,
typename
In_Const
,
typename
Out
,
template
<
typename
Place
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
typename
Dim
>
void
operator
()(
const
Place
&
place
,
In_Const
&
in
,
In
&
in_grad
,
Out
&
out
,
void
operator
()(
const
Place
&
place
,
X
&
x
,
Y
&
y
,
DX
&
dx
,
DY
&
dy
,
Out
&
out_grad
,
const
Dim
&
dim
,
int
size
)
{
const
Dim
&
dim
,
int
size
)
{
auto
equals
=
in
==
out
.
broadcast
(
dim
);
auto
equals
=
x
==
y
.
broadcast
(
dim
);
auto
ones
=
in_grad
.
constant
(
1
);
auto
ones
=
dx
.
constant
(
1
);
auto
zeros
=
in_grad
.
constant
(
0
);
auto
zeros
=
dx
.
constant
(
0
);
in_grad
.
device
(
place
)
=
dx
.
device
(
place
)
=
dy
.
broadcast
(
dim
)
*
equals
.
select
(
ones
,
zeros
);
out_grad
.
broadcast
(
dim
)
*
equals
.
select
(
ones
,
zeros
);
}
}
};
};
...
@@ -125,7 +124,7 @@ class ReduceKernel : public framework::OpKernel {
...
@@ -125,7 +124,7 @@ class ReduceKernel : public framework::OpKernel {
if
(
dim
<
0
)
dim
=
x_rank
+
dim
;
if
(
dim
<
0
)
dim
=
x_rank
+
dim
;
auto
reduce_dim
=
Eigen
::
array
<
int
,
1
>
({{
dim
}});
auto
reduce_dim
=
Eigen
::
array
<
int
,
1
>
({{
dim
}});
// construct the squeezed output tensor
// construct the squeezed output tensor
bool
keep_dim
=
context
.
Attr
<
int
>
(
"keep_dim"
)
==
1
;
bool
keep_dim
=
context
.
Attr
<
bool
>
(
"keep_dim"
)
;
DDim
dims
=
output
->
dims
();
DDim
dims
=
output
->
dims
();
auto
dims_vector
=
vectorize
(
dims
);
auto
dims_vector
=
vectorize
(
dims
);
if
(
keep_dim
&&
x_rank
>
1
)
{
if
(
keep_dim
&&
x_rank
>
1
)
{
...
@@ -191,7 +190,7 @@ class ReduceGradKernel : public framework::OpKernel {
...
@@ -191,7 +190,7 @@ class ReduceGradKernel : public framework::OpKernel {
braodcast_dim
[
dim
]
=
input0
->
dims
()[
dim
];
braodcast_dim
[
dim
]
=
input0
->
dims
()[
dim
];
auto
&
place
=
context
.
GetEigenDevice
<
Place
>
();
auto
&
place
=
context
.
GetEigenDevice
<
Place
>
();
Functor
functor
;
Functor
functor
;
functor
(
place
,
x
,
x_
grad
,
x_reduce
,
x_reduce_grad
,
braodcast_dim
,
functor
(
place
,
x
,
x_
reduce
,
x_grad
,
x_reduce_grad
,
braodcast_dim
,
braodcast_dim
[
dim
]);
braodcast_dim
[
dim
]);
}
}
}
}
...
@@ -235,8 +234,8 @@ class ReduceGradEigenFreeKernel : public framework::OpKernel {
...
@@ -235,8 +234,8 @@ class ReduceGradEigenFreeKernel : public framework::OpKernel {
out_offset
=
inner_count
*
i
+
j
;
out_offset
=
inner_count
*
i
+
j
;
for
(
int
k
=
0
;
k
<
mid_count
;
++
k
)
{
for
(
int
k
=
0
;
k
<
mid_count
;
++
k
)
{
x_offset
=
(
inner_count
*
mid_count
)
*
i
+
inner_count
*
k
+
j
;
x_offset
=
(
inner_count
*
mid_count
)
*
i
+
inner_count
*
k
+
j
;
functor
(
x_data
+
x_offset
,
x_grad_data
+
x
_offset
,
functor
(
x_data
+
x_offset
,
out_data
+
out
_offset
,
out_data
+
out
_offset
,
out_grad_data
+
out_offset
,
x_grad_data
+
x
_offset
,
out_grad_data
+
out_offset
,
mid_count
);
mid_count
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录