Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
610801b5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
610801b5
编写于
8月 07, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"remove a lot alias"
上级
6b23b91c
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
89 addition
and
57 deletion
+89
-57
paddle/operators/fc_op.cc
paddle/operators/fc_op.cc
+8
-4
paddle/operators/rowwise_add_op.cc
paddle/operators/rowwise_add_op.cc
+9
-6
paddle/operators/rowwise_add_op.cu
paddle/operators/rowwise_add_op.cu
+3
-2
paddle/operators/rowwise_add_op.h
paddle/operators/rowwise_add_op.h
+12
-3
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+11
-7
paddle/operators/sigmoid_op.cu
paddle/operators/sigmoid_op.cu
+3
-1
paddle/operators/sigmoid_op.h
paddle/operators/sigmoid_op.h
+9
-4
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+10
-7
paddle/operators/softmax_op.cu
paddle/operators/softmax_op.cu
+6
-5
paddle/operators/softmax_op.h
paddle/operators/softmax_op.h
+18
-18
未找到文件。
paddle/operators/fc_op.cc
浏览文件 @
610801b5
...
@@ -12,12 +12,14 @@
...
@@ -12,12 +12,14 @@
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "type_alias.h"
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
FullyConnectedOp
:
public
NetOp
{
class
FullyConnectedOp
:
public
framework
::
NetOp
{
public:
public:
void
Init
()
override
{
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
...
@@ -39,9 +41,10 @@ class FullyConnectedOp : public NetOp {
...
@@ -39,9 +41,10 @@ class FullyConnectedOp : public NetOp {
}
}
};
};
class
FullyConnectedOpMaker
:
public
OpProtoAndCheckerMaker
{
class
FullyConnectedOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
FullyConnectedOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
FullyConnectedOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"the input of fc operator"
);
AddInput
(
"X"
,
"the input of fc operator"
);
AddInput
(
"W"
,
"the weight of fc operator"
);
AddInput
(
"W"
,
"the weight of fc operator"
);
...
@@ -66,4 +69,5 @@ USE_OP(rowwise_add);
...
@@ -66,4 +69,5 @@ USE_OP(rowwise_add);
USE_OP
(
sigmoid
);
USE_OP
(
sigmoid
);
USE_OP
(
softmax
);
USE_OP
(
softmax
);
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
fc
,
ops
::
FullyConnectedOp
,
ops
::
FullyConnectedOpMaker
);
REGISTER_OP
(
fc
,
ops
::
FullyConnectedOp
,
ops
::
FullyConnectedOpMaker
);
paddle/operators/rowwise_add_op.cc
浏览文件 @
610801b5
...
@@ -13,12 +13,13 @@
...
@@ -13,12 +13,13 @@
limitations under the License. */
limitations under the License. */
#include "paddle/operators/rowwise_add_op.h"
#include "paddle/operators/rowwise_add_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
RowWiseAddOp
:
public
OperatorWithKernel
{
class
RowWiseAddOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2UL
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2UL
,
"Two inputs is needed by rowwise add"
);
"Two inputs is needed by rowwise add"
);
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
...
@@ -32,9 +33,10 @@ class RowWiseAddOp : public OperatorWithKernel {
...
@@ -32,9 +33,10 @@ class RowWiseAddOp : public OperatorWithKernel {
}
}
};
};
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
class
RowWiseAddOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
RowWiseAddOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The left input of row-wise add op, must be matrix"
);
AddInput
(
"X"
,
"The left input of row-wise add op, must be matrix"
);
AddInput
(
"b"
,
"The right input of row-wise add op, must be vector"
);
AddInput
(
"b"
,
"The right input of row-wise add op, must be vector"
);
...
@@ -50,6 +52,7 @@ for i in xrange(X.shape[0]):
...
@@ -50,6 +52,7 @@ for i in xrange(X.shape[0]):
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
rowwise_add
,
ops
::
RowWiseAddOp
,
ops
::
RowWiseAddOpMaker
);
REGISTER_OP
(
rowwise_add
,
ops
::
RowWiseAddOp
,
ops
::
RowWiseAddOpMaker
);
REGISTER_OP_CPU_KERNEL
(
rowwise_add
,
REGISTER_OP_CPU_KERNEL
(
ops
::
RowWiseAddKernel
<
ops
::
CPUPlace
,
float
>
);
rowwise_add
,
ops
::
RowWiseAddKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/rowwise_add_op.cu
浏览文件 @
610801b5
...
@@ -15,5 +15,6 @@
...
@@ -15,5 +15,6 @@
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h"
#include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL
(
rowwise_add
,
namespace
ops
=
paddle
::
operators
;
ops
::
RowWiseAddKernel
<
ops
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
rowwise_add
,
ops
::
RowWiseAddKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/rowwise_add_op.h
浏览文件 @
610801b5
...
@@ -13,15 +13,24 @@
...
@@ -13,15 +13,24 @@
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
RowWiseAddKernel
:
public
OpKernel
{
class
RowWiseAddKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/sigmoid_op.cc
浏览文件 @
610801b5
...
@@ -13,21 +13,23 @@
...
@@ -13,21 +13,23 @@
limitations under the License. */
limitations under the License. */
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/sigmoid_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
SigmoidOp
:
public
OperatorWithKernel
{
class
SigmoidOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Sigmoid Op only have one input"
);
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Sigmoid Op only have one input"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Sigmoid Op only have one output"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Sigmoid Op only have one output"
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
class
SigmoidOpMaker
:
public
OpProtoAndCheckerMaker
{
class
SigmoidOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SigmoidOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"sigmoid input"
);
AddInput
(
"X"
,
"sigmoid input"
);
AddOutput
(
"Y"
,
"sigmoid output"
);
AddOutput
(
"Y"
,
"sigmoid output"
);
...
@@ -35,9 +37,9 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
...
@@ -35,9 +37,9 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
}
}
};
};
class
SigmoidOpGrad
:
public
OperatorWithKernel
{
class
SigmoidOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
std
::
string
DebugString
()
const
override
{
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SigmoidGrad"
;
LOG
(
INFO
)
<<
"SigmoidGrad"
;
return
""
;
return
""
;
...
@@ -47,7 +49,9 @@ class SigmoidOpGrad : public OperatorWithKernel {
...
@@ -47,7 +49,9 @@ class SigmoidOpGrad : public OperatorWithKernel {
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
sigmoid
,
ops
::
SigmoidOp
,
ops
::
SigmoidOpMaker
);
REGISTER_OP
(
sigmoid
,
ops
::
SigmoidOp
,
ops
::
SigmoidOpMaker
);
REGISTER_GRADIENT_OP
(
sigmoid
,
sigmoid_grad
,
ops
::
SigmoidOpGrad
);
REGISTER_GRADIENT_OP
(
sigmoid
,
sigmoid_grad
,
ops
::
SigmoidOpGrad
);
REGISTER_OP_CPU_KERNEL
(
sigmoid
,
ops
::
SigmoidKernel
<
ops
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
sigmoid
,
ops
::
SigmoidKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/sigmoid_op.cu
浏览文件 @
610801b5
...
@@ -15,4 +15,6 @@
...
@@ -15,4 +15,6 @@
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL
(
sigmoid
,
ops
::
SigmoidKernel
<
ops
::
GPUPlace
,
float
>
);
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
sigmoid
,
ops
::
SigmoidKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/sigmoid_op.h
浏览文件 @
610801b5
...
@@ -13,16 +13,21 @@
...
@@ -13,16 +13,21 @@
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/
operators/type_alias
.h"
#include "paddle/
framework/op_registry
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SigmoidKernel
:
public
OpKernel
{
class
SigmoidKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/softmax_op.cc
浏览文件 @
610801b5
...
@@ -17,9 +17,9 @@ limitations under the License. */
...
@@ -17,9 +17,9 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
SoftmaxOp
:
public
OperatorWithKernel
{
class
SoftmaxOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
"Only one input is need for softmax"
);
"Only one input is need for softmax"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
...
@@ -30,9 +30,10 @@ class SoftmaxOp : public OperatorWithKernel {
...
@@ -30,9 +30,10 @@ class SoftmaxOp : public OperatorWithKernel {
}
}
};
};
class
SoftmaxOpMaker
:
public
OpProtoAndCheckerMaker
{
class
SoftmaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
SoftmaxOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SoftmaxOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"input of softmax"
);
AddInput
(
"X"
,
"input of softmax"
);
AddOutput
(
"Y"
,
"output of softmax"
);
AddOutput
(
"Y"
,
"output of softmax"
);
...
@@ -61,8 +62,10 @@ class SoftmaxOpGrad : public OperatorWithKernel {
...
@@ -61,8 +62,10 @@ class SoftmaxOpGrad : public OperatorWithKernel {
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
softmax
,
ops
::
SoftmaxOp
,
ops
::
SoftmaxOpMaker
);
REGISTER_OP
(
softmax
,
ops
::
SoftmaxOp
,
ops
::
SoftmaxOpMaker
);
REGISTER_OP_CPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
ops
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_GRADIENT_OP
(
softmax
,
softmax_grad
,
ops
::
SoftmaxOpGrad
);
REGISTER_GRADIENT_OP
(
softmax
,
softmax_grad
,
ops
::
SoftmaxOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax_grad
,
REGISTER_OP_CPU_KERNEL
(
ops
::
SoftmaxGradKernel
<
ops
::
CPUPlace
,
float
>
);
softmax_grad
,
ops
::
SoftmaxGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/softmax_op.cu
浏览文件 @
610801b5
/* Copyright (c) 2016 PaddlePaddle Authors
.
All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -13,9 +13,10 @@
...
@@ -13,9 +13,10 @@
limitations under the License. */
limitations under the License. */
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h"
#include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
ops
::
GPUPlace
,
float
>
);
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
softmax_grad
,
REGISTER_OP_GPU_KERNEL
(
softmax
,
ops
::
SoftmaxGradKernel
<
ops
::
GPUPlace
,
float
>
);
ops
::
SoftmaxKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
softmax_grad
,
ops
::
SoftmaxGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/softmax_op.h
浏览文件 @
610801b5
...
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
...
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/type_alias.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SoftmaxKernel
:
public
OpKernel
{
class
SoftmaxKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
output
=
context
.
Output
<
Tensor
>
(
"Y"
);
auto
output
=
context
.
Output
<
Tensor
>
(
"Y"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
@@ -43,8 +45,7 @@ class SoftmaxKernel : public OpKernel {
...
@@ -43,8 +45,7 @@ class SoftmaxKernel : public OpKernel {
Eigen
::
DSizes
<
int
,
2
>
batch_by_one
(
batch_size
,
1
);
Eigen
::
DSizes
<
int
,
2
>
batch_by_one
(
batch_size
,
1
);
Eigen
::
DSizes
<
int
,
2
>
one_by_class
(
1
,
num_classes
);
Eigen
::
DSizes
<
int
,
2
>
one_by_class
(
1
,
num_classes
);
auto
shifted_logits
=
(
logits
-
auto
shifted_logits
=
(
logits
-
logits
.
maximum
(
along_class
)
logits
.
maximum
(
along_class
)
.
eval
()
.
eval
()
.
reshape
(
batch_by_one
)
.
reshape
(
batch_by_one
)
.
broadcast
(
one_by_class
));
.
broadcast
(
one_by_class
));
...
@@ -52,8 +53,7 @@ class SoftmaxKernel : public OpKernel {
...
@@ -52,8 +53,7 @@ class SoftmaxKernel : public OpKernel {
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
shifted_logits
.
exp
();
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
shifted_logits
.
exp
();
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
(
softmax
*
(
softmax
*
softmax
.
sum
(
along_class
)
softmax
.
sum
(
along_class
)
.
inverse
()
.
inverse
()
.
eval
()
.
eval
()
.
reshape
(
batch_by_one
)
.
reshape
(
batch_by_one
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录