Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c6366c81
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c6366c81
编写于
9月 12, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
softmax as functor.
上级
2507bcaa
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
192 addition
and
134 deletion
+192
-134
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-1
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+5
-23
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+5
-2
paddle/operators/math/softmax_function.cc
paddle/operators/math/softmax_function.cc
+10
-48
paddle/operators/math/softmax_function.cu
paddle/operators/math/softmax_function.cu
+27
-0
paddle/operators/math/softmax_function.h
paddle/operators/math/softmax_function.h
+46
-11
paddle/operators/softmax_op.h
paddle/operators/softmax_op.h
+1
-1
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+24
-20
paddle/operators/softmax_with_cross_entropy_op.h
paddle/operators/softmax_with_cross_entropy_op.h
+26
-1
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+8
-5
python/paddle/v2/framework/tests/test_softmax_with_cost_op.py
...on/paddle/v2/framework/tests/test_softmax_with_cost_op.py
+0
-22
python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
.../v2/framework/tests/test_softmax_with_cross_entropy_op.py
+39
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
c6366c81
...
...
@@ -60,7 +60,7 @@ set(DEPS_OPS
op_library
(
identity_op DEPS scale_op
)
op_library
(
minus_op DEPS scale_op
)
op_library
(
mul_op DEPS math_function
)
op_library
(
softmax_op DEPS
math
_function
)
op_library
(
softmax_op DEPS
softmax
_function
)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor operator net_op
)
op_library
(
scale_op DEPS net_op
)
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
c6366c81
...
...
@@ -14,31 +14,13 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/utils.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
inline
T
tolerable_value
(
const
T
x
)
{
static_assert
(
std
::
is_floating_point
<
T
>::
value
,
"tolerable_value works only on float, "
"double and double double."
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
x
;
}
template
<
typename
T
>
class
OnehotCrossEntropyOpKernel
:
public
framework
::
OpKernel
{
public:
...
...
@@ -55,12 +37,12 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
T
*
Ydata
=
Y
->
data
<
T
>
();
int
batch_size
=
X
->
dims
()[
0
];
int
class_num
=
X
->
dims
()[
1
];
const
int
batch_size
=
X
->
dims
()[
0
];
const
int
class_num
=
X
->
dims
()[
1
];
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
Ydata
[
i
]
=
-
tolerable_value
(
std
::
log
(
Xdata
[
index
]));
Ydata
[
i
]
=
-
math
::
tolerable_value
(
std
::
log
(
Xdata
[
index
]));
}
}
};
...
...
@@ -89,7 +71,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
memset
(
dXdata
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
dXdata
[
index
]
=
-
tolerable_value
(
dYdata
[
i
]
/
Xdata
[
index
]);
dXdata
[
index
]
=
-
math
::
tolerable_value
(
dYdata
[
i
]
/
Xdata
[
index
]);
}
}
};
...
...
paddle/operators/math/CMakeLists.txt
浏览文件 @
c6366c81
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu softmax_function.cc DEPS cblas device_context operator
)
im2col.cu DEPS cblas device_context operator
)
nv_library
(
softmax_function SRCS softmax_function.cc softmax_function.cu
DEPS operator
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc
softmax_function.cc DEPS cblas device_context operator
)
DEPS cblas device_context operator
)
cc_library
(
softmax_function SRCS softmax_function.cc DEPS operator
)
endif
()
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/softmax_function.cc
浏览文件 @
c6366c81
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#define EIGEN_USE_GPU
#endif
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/softmax_function.h"
...
...
@@ -22,41 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
void
softmax
(
const
framework
::
Tensor
*
X
,
framework
::
Tensor
*
Y
,
const
framework
::
ExecutionContext
&
context
)
{
auto
logits
=
EigenMatrix
<
T
>::
From
(
*
X
);
auto
softmax
=
EigenMatrix
<
T
>::
From
(
*
Y
);
const
int
kBatchDim
=
0
;
const
int
kClassDim
=
1
;
const
int
batch_size
=
logits
.
dimension
(
kBatchDim
);
const
int
num_classes
=
logits
.
dimension
(
kClassDim
);
Eigen
::
DSizes
<
int
,
1
>
along_class
(
kClassDim
);
Eigen
::
DSizes
<
int
,
2
>
batch_by_one
(
batch_size
,
1
);
Eigen
::
DSizes
<
int
,
2
>
one_by_class
(
1
,
num_classes
);
auto
shifted_logits
=
(
logits
-
logits
.
maximum
(
along_class
)
.
eval
()
.
reshape
(
batch_by_one
)
.
broadcast
(
one_by_class
));
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
shifted_logits
.
exp
();
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
(
softmax
*
softmax
.
sum
(
along_class
)
.
inverse
()
.
eval
()
.
reshape
(
batch_by_one
)
.
broadcast
(
one_by_class
));
}
template
class
SoftmaxFunctor
<
platform
::
CPUPlace
,
float
>;
}
// namespace math
}
// namespace operators
...
...
paddle/operators/math/softmax_function.cu
0 → 100644
浏览文件 @
c6366c81
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/math/softmax_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
class
SoftmaxFunctor
<
platform
::
GPUPlace
,
float
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/softmax_function.h
浏览文件 @
c6366c81
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
...
...
@@ -21,9 +21,44 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
void
softmax
(
const
framework
::
Tensor
*
X
,
framework
::
Tensor
*
Y
,
const
framework
::
ExecutionContext
&
context
);
class
SoftmaxFunctor
{
public:
void
operator
()(
const
framework
::
Tensor
*
X
,
framework
::
Tensor
*
Y
,
const
framework
::
ExecutionContext
&
context
)
{
auto
logits
=
EigenMatrix
<
T
>::
From
(
*
X
);
auto
softmax
=
EigenMatrix
<
T
>::
From
(
*
Y
);
const
int
kBatchDim
=
0
;
const
int
kClassDim
=
1
;
const
int
batch_size
=
logits
.
dimension
(
kBatchDim
);
const
int
num_classes
=
logits
.
dimension
(
kClassDim
);
Eigen
::
DSizes
<
int
,
1
>
along_class
(
kClassDim
);
Eigen
::
DSizes
<
int
,
2
>
batch_by_one
(
batch_size
,
1
);
Eigen
::
DSizes
<
int
,
2
>
one_by_class
(
1
,
num_classes
);
auto
shifted_logits
=
(
logits
-
logits
.
maximum
(
along_class
)
.
eval
()
.
reshape
(
batch_by_one
)
.
broadcast
(
one_by_class
));
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
shifted_logits
.
exp
();
softmax
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
(
softmax
*
softmax
.
sum
(
along_class
)
.
inverse
()
.
eval
()
.
reshape
(
batch_by_one
)
.
broadcast
(
one_by_class
));
}
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/softmax_op.h
浏览文件 @
c6366c81
...
...
@@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel {
// allocate memory on device.
Y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
softmax
<
Place
,
T
>
(
X
,
Y
,
context
);
math
::
SoftmaxFunctor
<
Place
,
T
>
()
(
X
,
Y
,
context
);
}
};
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
c6366c81
...
...
@@ -23,13 +23,13 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
logits
=
ctx
.
Input
<
Tensor
>
(
"
l
ogits"
);
auto
logits
=
ctx
.
Input
<
Tensor
>
(
"
L
ogits"
);
PADDLE_ENFORCE
(
logits
->
dims
().
size
()
==
2UL
,
"The input of softmax_with_cross_entropy should be a 2-d tensor."
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"
lables
"
)
->
dims
().
size
()
==
1UL
,
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"
Label
"
)
->
dims
().
size
()
==
1UL
,
"The label should be a 1-d tensor."
);
ctx
.
Output
<
Tensor
>
(
"
Y
"
)
->
Resize
({
logits
->
dims
()[
0
]});
ctx
.
Output
<
Tensor
>
(
"
Label
"
)
->
Resize
({
logits
->
dims
()[
0
]});
}
};
...
...
@@ -39,11 +39,15 @@ class SoftmaxWithCrossEntropyOpMaker
SoftmaxWithCrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"
l
ogits"
,
AddInput
(
"
L
ogits"
,
"The unscaled log probabilities which is a 2-D tensor<float> with"
"shape [N x K]. N is the batch_size, and K is the class number."
);
AddInput
(
"label"
,
"The ground truth. A 1-D tensor<int> with shape N."
);
AddOutput
(
"Y"
,
"A 1-D tensor<float> with shape N."
);
AddInput
(
"Label"
,
"The ground truth. A 1-D tensor<int> with shape N."
);
AddOutput
(
"Softmax"
,
"Store the outputs of softmax function, "
"which will be used in backward calculation."
)
.
AsIntermediate
();
AddOutput
(
"Loss"
,
"A 1-D tensor<float> with shape N."
);
AddComment
(
R"DOC(
Cross entropy loss with softmax are used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
...
...
@@ -67,21 +71,21 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) should be not null."
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
(),
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
))
->
dims
(),
"Input(Y) and its gradients should have a same shape."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"labels"
),
"Input(lables) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"logits"
)),
"Input(logits@GRAD) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Loss"
),
"Input(Loss) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@GRAD) should be not null."
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
"logits"
)
->
dims
(),
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"logits"
))
->
dims
(),
"Input(logits) and its gradients should have a same shape."
);
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
dims
(),
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Logits"
))
->
dims
(),
"Input(Logits) and its gradients should have a same shape."
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
dims
(),
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Logits"
))
->
dims
(),
"Input(Logits) and its gradients should have a same shape."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Lable) should be not null."
);
}
};
...
...
paddle/operators/softmax_with_cross_entropy_op.h
浏览文件 @
c6366c81
...
...
@@ -15,6 +15,8 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax_function.h"
#include "paddle/operators/math/utils.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -27,7 +29,30 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template
<
typename
Place
,
typename
T
>
class
SoftmaxWithCrossEntropyKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
// Calculate ths softmax outputs.
const
Tensor
*
logits
=
context
.
Input
<
Tensor
>
(
"Logits"
);
Tensor
*
softmax
=
context
.
Output
<
Tensor
>
(
"Softmax"
);
// allocate memory on device.
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SoftmaxFunctor
<
Place
,
T
>
()(
logits
,
softmax
,
context
);
// Calculate the cross entropy loss based on hard labels.
T
*
softmax_out
=
softmax
->
data
<
T
>
();
const
int
*
label_data
=
context
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"Loss"
);
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
loss_data
=
loss
->
data
<
T
>
();
const
int
batch_size
=
logits
->
dims
()[
0
];
const
int
class_num
=
logits
->
dims
()[
1
];
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
loss_data
[
i
]
=
-
math
::
tolerable_value
(
std
::
log
(
softmax_out
[
index
]));
}
}
};
template
<
typename
Place
,
typename
T
>
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
c6366c81
import
unittest
import
numpy
from
op_test
import
OpTest
import
pdb
class
TestCrossEntropy
(
OpTest
):
...
...
@@ -10,18 +11,20 @@ class TestCrossEntropy(OpTest):
class_num
=
10
X
=
numpy
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
(
class_num
/
2
)
*
numpy
.
ones
(
batch_size
).
astype
(
"int32"
)
self
.
inputs
=
{
'X'
:
X
,
'label'
:
label
}
labels
=
numpy
.
random
.
randint
(
0
,
class_num
,
batch_size
,
dtype
=
"int32"
)
self
.
inputs
=
{
"X"
:
X
,
"label"
:
labels
}
Y
=
[]
for
i
in
range
(
0
,
batch_size
):
Y
.
append
(
-
numpy
.
log
(
X
[
i
][
label
[
i
]]))
self
.
outputs
=
{
'Y'
:
numpy
.
array
(
Y
).
astype
(
"float32"
)}
Y
.
append
(
-
numpy
.
log
(
X
[
i
][
label
s
[
i
]]))
self
.
outputs
=
{
"Y"
:
numpy
.
array
(
Y
).
astype
(
"float32"
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
self
.
check_grad
([
"X"
],
"Y"
)
if
__name__
==
"__main__"
:
...
...
python/paddle/v2/framework/tests/test_softmax_with_cost_op.py
已删除
100644 → 0
浏览文件 @
2507bcaa
import
unittest
import
numpy
as
np
from
gradient_checker
import
GradientChecker
,
create_op
from
op_test_util
import
OpTestMeta
class
TestSoftmaxWithLossOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
pass
class
SoftmaxWithLossGradOpTest
(
GradientChecker
):
def
test_softmax
(
self
):
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
0 → 100644
浏览文件 @
c6366c81
import
unittest
import
numpy
as
np
import
pdb
from
op_test
import
OpTest
from
test_softmax_op
import
stable_softmax
class
TestSoftmaxWithCrossEntropyOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"softmax_with_cross_entropy"
MAX_BATCH_SIZE
=
23
MAX_CLASS_NUM
=
255
batch_size
=
np
.
random
.
randint
(
1
,
MAX_BATCH_SIZE
,
1
)[
0
]
class_num
=
np
.
random
.
randint
(
2
,
MAX_CLASS_NUM
,
1
)[
0
]
logits
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
1
,
logits
)
labels
=
np
.
random
.
randint
(
0
,
class_num
,
batch_size
,
dtype
=
"int32"
)
cross_entropy
=
[
-
np
.
log
(
softmax
[
i
][
labels
[
i
]])
for
i
in
range
(
softmax
.
shape
[
0
])
]
self
.
inputs
=
{
"Logits"
:
logits
,
"Label"
:
labels
}
self
.
outputs
=
{
"Loss"
:
cross_entropy
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
pass
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录