Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6d5744b4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6d5744b4
编写于
8月 09, 2022
作者:
D
duanboqiang
提交者:
GitHub
8月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[phi] migrate margin infer shape and yaml (#44940)
* add margin infer * migrate yaml * modify unittests script
上级
7b29c89b
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
257 addition
and
86 deletion
+257
-86
paddle/fluid/operators/margin_cross_entropy_op.cc
paddle/fluid/operators/margin_cross_entropy_op.cc
+16
-77
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+10
-0
paddle/phi/api/yaml/legacy_backward.yaml
paddle/phi/api/yaml/legacy_backward.yaml
+11
-0
paddle/phi/infermeta/backward.cc
paddle/phi/infermeta/backward.cc
+24
-0
paddle/phi/infermeta/backward.h
paddle/phi/infermeta/backward.h
+14
-0
paddle/phi/infermeta/binary.cc
paddle/phi/infermeta/binary.cc
+59
-0
paddle/phi/infermeta/binary.h
paddle/phi/infermeta/binary.h
+14
-0
paddle/phi/kernels/cpu/margin_cross_entropy_kernel.cc
paddle/phi/kernels/cpu/margin_cross_entropy_kernel.cc
+50
-0
paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
+0
-1
python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
...dle/fluid/tests/unittests/test_margin_cross_entropy_op.py
+46
-7
python/paddle/nn/functional/loss.py
python/paddle/nn/functional/loss.py
+13
-1
未找到文件。
paddle/fluid/operators/margin_cross_entropy_op.cc
浏览文件 @
6d5744b4
...
...
@@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -22,55 +25,6 @@ class MarginCrossEntropyOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Logits"
),
"Input"
,
"Logits"
,
"MarginCrossEntropyOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Label"
),
"Input"
,
"Label"
,
"MarginCrossEntropyOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Softmax"
),
"Output"
,
"Softmax"
,
"MarginCrossEntropyOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Loss"
),
"Output"
,
"Loss"
,
"MarginCrossEntropyOp"
);
auto
logits_dims
=
ctx
->
GetInputDim
(
"Logits"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
logits_rank
=
logits_dims
.
size
();
auto
axis
=
logits_rank
-
1
;
for
(
int
i
=
0
;
i
<
logits_rank
;
i
++
)
{
if
(
i
!=
axis
)
{
if
(
ctx
->
IsRuntime
()
||
(
logits_dims
[
i
]
>
0
&&
labels_dims
[
i
]
>
0
))
{
PADDLE_ENFORCE_EQ
(
logits_dims
[
i
],
labels_dims
[
i
],
platform
::
errors
::
InvalidArgument
(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."
));
}
}
}
if
(
labels_dims
.
size
()
>
1
)
{
PADDLE_ENFORCE_EQ
(
labels_dims
[
logits_rank
-
1
],
1UL
,
platform
::
errors
::
InvalidArgument
(
"the last dimension of Input(Label) should be 1."
"But received: the last dimension of Input(Label) is [%d],"
"the last dimension is [%d]"
,
labels_dims
[
logits_rank
-
1
],
logits_rank
-
1
));
}
ctx
->
SetOutputDim
(
"Softmax"
,
logits_dims
);
logits_dims
[
axis
]
=
1
;
ctx
->
SetOutputDim
(
"Loss"
,
logits_dims
);
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Softmax"
);
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Loss"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -141,29 +95,6 @@ class MarginCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Loss@Grad) should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Softmax"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Softmax) should be not null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Label"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Label) should be not null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Logits"
)),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Logits@Grad) should be not null."
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Logits"
),
ctx
->
GetInputDim
(
"Softmax"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -195,13 +126,21 @@ class MarginCrossEntropyOpGradMaker : public framework::SingleGradOpMaker<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
DECLARE_INFER_SHAPE_FUNCTOR
(
margin_cross_entropy
,
MarginCrossEntropyInferShapeFunctor
,
PD_INFER_META
(
phi
::
MarginCrossEntropyInferMeta
));
REGISTER_OPERATOR
(
margin_cross_entropy
,
ops
::
MarginCrossEntropyOp
,
ops
::
MarginCrossEntropyOpMaker
,
ops
::
MarginCrossEntropyOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
MarginCrossEntropyOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
margin_cross_entropy_grad
,
ops
::
MarginCrossEntropyOpGrad
);
ops
::
MarginCrossEntropyOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
MarginCrossEntropyInferShapeFunctor
);
DECLARE_INFER_SHAPE_FUNCTOR
(
margin_cross_entropy_grad
,
MarginCrossEntropyGradInferShapeFunctor
,
PD_INFER_META
(
phi
::
MarginCrossEntropyGradInferMeta
));
REGISTER_OPERATOR
(
margin_cross_entropy_grad
,
ops
::
MarginCrossEntropyOpGrad
,
MarginCrossEntropyGradInferShapeFunctor
);
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
6d5744b4
...
...
@@ -1564,6 +1564,16 @@
data_type
:
x
backward
:
lu_unpack_grad
-
api
:
margin_cross_entropy
args
:
(Tensor logits, Tensor label, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale)
output
:
Tensor(softmax), Tensor(loss)
infer_meta
:
func
:
MarginCrossEntropyInferMeta
kernel
:
func
:
margin_cross_entropy
data_type
:
logits
backward
:
margin_cross_entropy_grad
# masked_select
-
api
:
masked_select
args
:
(Tensor x, Tensor mask)
...
...
paddle/phi/api/yaml/legacy_backward.yaml
浏览文件 @
6d5744b4
...
...
@@ -1336,6 +1336,17 @@
kernel
:
func
:
lu_unpack_grad
-
backward_api
:
margin_cross_entropy_grad
forward
:
margin_cross_entropy (Tensor logits, Tensor label, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale) -> Tensor(softmax), Tensor(loss)
args
:
(Tensor logits, Tensor label, Tensor softmax, Tensor loss_grad, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale)
output
:
Tensor(logits_grad)
infer_meta
:
func
:
MarginCrossEntropyGradInferMeta
kernel
:
func
:
margin_cross_entropy_grad
data_type
:
softmax
inplace
:
(softmax -> logits_grad)
-
backward_api
:
masked_select_grad
forward
:
masked_select (Tensor x, Tensor mask) -> Tensor(out)
args
:
(Tensor x, Tensor mask, Tensor out_grad)
...
...
paddle/phi/infermeta/backward.cc
浏览文件 @
6d5744b4
...
...
@@ -560,6 +560,30 @@ void LUUnpackGradInferMeta(const MetaTensor& x,
}
}
void
MarginCrossEntropyGradInferMeta
(
const
MetaTensor
&
logits
,
const
MetaTensor
&
label
,
const
MetaTensor
&
softmax
,
const
MetaTensor
&
loss_grad
,
bool
return_softmax
,
int
ring_id
,
int
rank
,
int
nranks
,
float
margin1
,
float
margin2
,
float
margin3
,
float
scale
,
MetaTensor
*
logits_grad
)
{
PADDLE_ENFORCE_NE
(
logits_grad
,
nullptr
,
phi
::
errors
::
InvalidArgument
(
"The Logits@GRAD in MarginCrossEntropy can't be nullptr."
));
auto
softmax_dims
=
softmax
.
dims
();
logits_grad
->
set_dims
(
softmax_dims
);
logits_grad
->
set_dtype
(
softmax
.
dtype
());
}
void
MaxPoolWithIndexGradInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
mask
,
const
MetaTensor
&
dout
,
...
...
paddle/phi/infermeta/backward.h
浏览文件 @
6d5744b4
...
...
@@ -245,6 +245,20 @@ void LUUnpackGradInferMeta(const MetaTensor& x,
bool
unpack_pivots
,
MetaTensor
*
x_grad
);
void
MarginCrossEntropyGradInferMeta
(
const
MetaTensor
&
logits
,
const
MetaTensor
&
label
,
const
MetaTensor
&
softmax
,
const
MetaTensor
&
loss_grad
,
bool
return_softmax
,
int
ring_id
,
int
rank
,
int
nranks
,
float
margin1
,
float
margin2
,
float
margin3
,
float
scale
,
MetaTensor
*
logits_grad
);
void
MaxPoolWithIndexGradInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
mask
,
const
MetaTensor
&
dout
,
...
...
paddle/phi/infermeta/binary.cc
浏览文件 @
6d5744b4
...
...
@@ -1545,6 +1545,65 @@ void LUUnpackInferMeta(const MetaTensor& x,
}
}
void
MarginCrossEntropyInferMeta
(
const
MetaTensor
&
logits
,
const
MetaTensor
&
label
,
bool
return_softmax
,
int
ring_id
,
int
rank
,
int
nranks
,
float
margin1
,
float
margin2
,
float
margin3
,
float
scale
,
MetaTensor
*
softmax
,
MetaTensor
*
loss
,
MetaConfig
config
)
{
PADDLE_ENFORCE_NOT_NULL
(
logits
,
phi
::
errors
::
InvalidArgument
(
"Input of logits should not be null."
));
PADDLE_ENFORCE_NOT_NULL
(
label
,
phi
::
errors
::
InvalidArgument
(
"Input of label should not be null."
));
auto
logits_dims
=
logits
.
dims
();
auto
labels_dims
=
label
.
dims
();
auto
logits_rank
=
logits_dims
.
size
();
auto
axis
=
logits_rank
-
1
;
for
(
int
i
=
0
;
i
<
logits_rank
;
i
++
)
{
if
(
i
!=
axis
)
{
if
(
config
.
is_runtime
||
(
logits_dims
[
i
]
>
0
&&
labels_dims
[
i
]
>
0
))
{
PADDLE_ENFORCE_EQ
(
logits_dims
[
i
],
labels_dims
[
i
],
phi
::
errors
::
InvalidArgument
(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."
));
}
}
}
if
(
labels_dims
.
size
()
>
1
)
{
PADDLE_ENFORCE_EQ
(
labels_dims
[
logits_rank
-
1
],
1UL
,
phi
::
errors
::
InvalidArgument
(
"the last dimension of Input(Label) should be 1."
"But received: the last dimension of Input(Label) is [%d],"
"the last dimension is [%d]"
,
labels_dims
[
logits_rank
-
1
],
logits_rank
-
1
));
}
softmax
->
set_dims
(
logits_dims
);
softmax
->
set_dtype
(
logits
.
dtype
());
logits_dims
[
axis
]
=
1
;
loss
->
set_dims
(
logits_dims
);
loss
->
set_dtype
(
logits
.
dtype
());
softmax
->
share_lod
(
logits
);
loss
->
share_lod
(
logits
);
}
void
MaskedSelectInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
mask
,
MetaTensor
*
out
)
{
...
...
paddle/phi/infermeta/binary.h
浏览文件 @
6d5744b4
...
...
@@ -240,6 +240,20 @@ void LUUnpackInferMeta(const MetaTensor& x,
MetaTensor
*
l
,
MetaTensor
*
u
);
void
MarginCrossEntropyInferMeta
(
const
MetaTensor
&
logits
,
const
MetaTensor
&
label
,
bool
return_softmax
,
int
ring_id
,
int
rank
,
int
nranks
,
float
margin1
,
float
margin2
,
float
margin3
,
float
scale
,
MetaTensor
*
softmax
,
MetaTensor
*
loss
,
MetaConfig
config
=
MetaConfig
());
void
MaskedSelectInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
mask
,
MetaTensor
*
out
);
...
...
paddle/phi/kernels/cpu/margin_cross_entropy_kernel.cc
0 → 100644
浏览文件 @
6d5744b4
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/margin_cross_entropy_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
MarginCrossEntropyKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
logits
,
const
DenseTensor
&
labels
,
bool
return_softmax
,
int
ring_id
,
int
rank
,
int
nranks
,
float
margin1
,
float
margin2
,
float
margin3
,
float
scale
,
DenseTensor
*
softmax
,
DenseTensor
*
loss
)
{
PADDLE_THROW
(
errors
::
Unavailable
(
"Do not support margin_cross_entropy for cpu kernel "
"now."
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
margin_cross_entropy
,
CPU
,
ALL_LAYOUT
,
phi
::
MarginCrossEntropyKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
浏览文件 @
6d5744b4
...
...
@@ -378,7 +378,6 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
DenseTensor
sum_exp_logits
;
sum_exp_logits
.
Resize
({
N
,
1
});
dev_ctx
.
template
Alloc
<
T
>(
&
sum_exp_logits
);
// T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
T
*
sum_exp_logits_buff
=
dev_ctx
.
template
Alloc
<
T
>(
&
sum_exp_logits
);
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
phi
::
kps
::
AddFunctor
,
phi
::
kps
::
ExpFunctor
<
T
>>
(
static_cast
<
const
phi
::
GPUContext
&>
(
dev_ctx
),
...
...
python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
浏览文件 @
6d5744b4
...
...
@@ -66,12 +66,36 @@ def margin_cross_entropy(logits,
return
loss
,
softmax
def
python_api
(
logits
,
label
,
return_softmax
=
False
,
ring_id
=
0
,
rank
=
0
,
nrank
=
0
,
margin1
=
1.0
,
margin2
=
0.5
,
margin3
=
0.0
,
scale
=
64.0
):
return
paddle
.
nn
.
functional
.
margin_cross_entropy
(
logits
,
label
,
return_softmax
=
return_softmax
,
margin1
=
margin1
,
margin2
=
margin2
,
margin3
=
margin3
,
scale
=
scale
,
group
=
None
,
reduction
=
None
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestMarginCrossEntropyOp
(
OpTest
):
def
initParams
(
self
):
self
.
python_api
=
python_api
self
.
op_type
=
"margin_cross_entropy"
self
.
python_out_sig
=
[
"Loss"
]
self
.
axis
=
-
1
self
.
batch_dim
=
5
self
.
feat_dim
=
41
...
...
@@ -121,10 +145,14 @@ class TestMarginCrossEntropyOp(OpTest):
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
),
atol
=
1e-5
)
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
),
atol
=
1e-5
,
check_eager
=
True
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"Logits"
],
"Loss"
)
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"Logits"
],
"Loss"
,
check_eager
=
True
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -138,7 +166,8 @@ class TestMarginCrossEntropyOpFP32(TestMarginCrossEntropyOp):
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"Logits"
],
"Loss"
,
numeric_grad_delta
=
5e-2
,
max_relative_error
=
5e-2
)
max_relative_error
=
5e-2
,
check_eager
=
True
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -149,13 +178,16 @@ class TestMarginCrossEntropyOpFP16(TestMarginCrossEntropyOp):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
),
atol
=
5e-2
)
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
),
atol
=
5e-2
,
check_eager
=
True
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"Logits"
],
"Loss"
,
numeric_grad_delta
=
6e-1
,
max_relative_error
=
6e-1
)
max_relative_error
=
6e-1
,
check_eager
=
True
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -184,13 +216,17 @@ class TestMarginCrossEntropyOpCPU(TestMarginCrossEntropyOp):
def
test_check_output
(
self
):
try
:
self
.
check_output_with_place
(
core
.
CPUPlace
(),
atol
=
1e-5
)
self
.
check_output_with_place
(
core
.
CPUPlace
(),
atol
=
1e-5
,
check_eager
=
True
)
except
RuntimeError
:
pass
def
test_check_grad
(
self
):
try
:
self
.
check_grad_with_place
(
core
.
CPUPlace
(),
[
"Logits"
],
"Loss"
)
self
.
check_grad_with_place
(
core
.
CPUPlace
(),
[
"Logits"
],
"Loss"
,
check_eager
=
True
)
except
RuntimeError
:
pass
...
...
@@ -208,6 +244,7 @@ class TestMarginCrossEntropyOpV2(unittest.TestCase):
self
.
places
.
append
(
paddle
.
fluid
.
CUDAPlace
(
0
))
def
initParams
(
self
):
self
.
python_out_sig
=
[
"Loss"
]
self
.
seed
=
2021
self
.
axis
=
-
1
self
.
batch_dim
=
5
...
...
@@ -356,6 +393,8 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
self
.
places
.
append
(
paddle
.
fluid
.
CUDAPlace
(
0
))
def
initParams
(
self
):
self
.
python_api
=
python_api
self
.
python_out_sig
=
[
"Loss"
]
self
.
seed
=
2021
self
.
axis
=
-
1
self
.
batch_dim
=
10
...
...
python/paddle/nn/functional/loss.py
浏览文件 @
6d5744b4
...
...
@@ -1926,7 +1926,19 @@ def margin_cross_entropy(logits,
if
input_dims
-
1
==
label_dims
:
label
=
paddle
.
unsqueeze
(
label
,
axis
=-
1
)
if
in_dynamic_mode
():
if
in_dygraph_mode
():
softmax
,
loss
=
_C_ops
.
final_state_margin_cross_entropy
(
logits
,
label
,
return_softmax
,
ring_id
,
rank
,
nranks
,
margin1
,
margin2
,
margin3
,
scale
)
if
reduction
==
'mean'
:
loss
=
paddle
.
mean
(
loss
)
elif
reduction
==
'sum'
:
loss
=
paddle
.
sum
(
loss
)
if
not
return_softmax
:
return
loss
else
:
return
loss
,
softmax
elif
paddle
.
in_dynamic_mode
():
softmax
,
loss
=
_C_ops
.
margin_cross_entropy
(
logits
,
label
,
'ring_id'
,
ring_id
,
'rank'
,
rank
,
'nranks'
,
nranks
,
'margin1'
,
margin1
,
'margin2'
,
margin2
,
'margin3'
,
margin3
,
'scale'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录