Unverified commit 6d5744b4, authored by duanboqiang, committed by GitHub

[phi] migrate margin infer shape and yaml (#44940)

* add margin infer

* migrate yaml

* modify unittests script
Parent 7b29c89b
@@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -22,55 +25,6 @@ class MarginCrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Logits"), "Input", "Logits", "MarginCrossEntropyOp");
OP_INOUT_CHECK(
ctx->HasInput("Label"), "Input", "Label", "MarginCrossEntropyOp");
OP_INOUT_CHECK(
ctx->HasOutput("Softmax"), "Output", "Softmax", "MarginCrossEntropyOp");
OP_INOUT_CHECK(
ctx->HasOutput("Loss"), "Output", "Loss", "MarginCrossEntropyOp");
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
auto logits_rank = logits_dims.size();
auto axis = logits_rank - 1;
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[i],
labels_dims[i],
platform::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."));
}
}
}
if (labels_dims.size() > 1) {
PADDLE_ENFORCE_EQ(
labels_dims[logits_rank - 1],
1UL,
platform::errors::InvalidArgument(
"the last dimension of Input(Label) should be 1."
"But received: the last dimension of Input(Label) is [%d],"
"the last dimension is [%d]",
labels_dims[logits_rank - 1],
logits_rank - 1));
}
ctx->SetOutputDim("Softmax", logits_dims);
logits_dims[axis] = 1;
ctx->SetOutputDim("Loss", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -141,29 +95,6 @@ class MarginCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")),
true,
platform::errors::InvalidArgument(
"Input(Loss@Grad) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"),
true,
platform::errors::InvalidArgument(
"Input(Softmax) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"),
true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")),
true,
platform::errors::InvalidArgument(
"Output(Logits@Grad) should be not null."));
ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Softmax"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -195,13 +126,21 @@ class MarginCrossEntropyOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(margin_cross_entropy,
MarginCrossEntropyInferShapeFunctor,
PD_INFER_META(phi::MarginCrossEntropyInferMeta));
REGISTER_OPERATOR(
margin_cross_entropy,
ops::MarginCrossEntropyOp,
ops::MarginCrossEntropyOpMaker,
ops::MarginCrossEntropyOpGradMaker<paddle::framework::OpDesc>,
ops::MarginCrossEntropyOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(margin_cross_entropy_grad, ops::MarginCrossEntropyOpGrad);
ops::MarginCrossEntropyOpGradMaker<paddle::imperative::OpBase>,
MarginCrossEntropyInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(
margin_cross_entropy_grad,
MarginCrossEntropyGradInferShapeFunctor,
PD_INFER_META(phi::MarginCrossEntropyGradInferMeta));
REGISTER_OPERATOR(margin_cross_entropy_grad,
ops::MarginCrossEntropyOpGrad,
MarginCrossEntropyGradInferShapeFunctor);
@@ -1564,6 +1564,16 @@
data_type : x
backward : lu_unpack_grad
- api : margin_cross_entropy
args : (Tensor logits, Tensor label, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale)
output : Tensor(softmax), Tensor(loss)
infer_meta :
func : MarginCrossEntropyInferMeta
kernel :
func : margin_cross_entropy
data_type : logits
backward : margin_cross_entropy_grad
# masked_select
- api : masked_select
args : (Tensor x, Tensor mask)
......
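For reference, below is a minimal usage sketch of the public Python API that this yaml entry backs. It is an illustration only, not part of the diff; the shapes mirror the unit tests' batch_dim=5 / feat_dim=41, and a CUDA build is assumed because the CPU kernel added further down only raises Unavailable.

```python
# Illustrative call to paddle.nn.functional.margin_cross_entropy,
# the API backed by the yaml entry above. Assumes a CUDA build.
import paddle

paddle.set_device('gpu')
logits = paddle.randn([5, 41])                      # [batch_dim, feat_dim], as in the tests
label = paddle.randint(0, 41, shape=[5], dtype='int64')

loss, softmax = paddle.nn.functional.margin_cross_entropy(
    logits, label,
    margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0,
    return_softmax=True, reduction=None)

print(loss.shape, softmax.shape)                    # [5, 1] [5, 41]
```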
@@ -1336,6 +1336,17 @@
kernel :
func : lu_unpack_grad
- backward_api : margin_cross_entropy_grad
forward : margin_cross_entropy (Tensor logits, Tensor label, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale) -> Tensor(softmax), Tensor(loss)
args : (Tensor logits, Tensor label, Tensor softmax, Tensor loss_grad, bool return_softmax, int ring_id, int rank, int nranks, float margin1, float margin2, float margin3, float scale)
output : Tensor(logits_grad)
infer_meta :
func : MarginCrossEntropyGradInferMeta
kernel :
func : margin_cross_entropy_grad
data_type : softmax
inplace : (softmax -> logits_grad)
- backward_api : masked_select_grad
forward : masked_select (Tensor x, Tensor mask) -> Tensor(out)
args : (Tensor x, Tensor mask, Tensor out_grad)
......
@@ -560,6 +560,30 @@ void LUUnpackGradInferMeta(const MetaTensor& x,
}
}
void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
bool return_softmax,
int ring_id,
int rank,
int nranks,
float margin1,
float margin2,
float margin3,
float scale,
MetaTensor* logits_grad) {
PADDLE_ENFORCE_NE(
logits_grad,
nullptr,
phi::errors::InvalidArgument(
"The Logits@GRAD in MarginCrossEntropy can't be nullptr."));
auto softmax_dims = softmax.dims();
logits_grad->set_dims(softmax_dims);
logits_grad->set_dtype(softmax.dtype());
}
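The grad InferMeta above only propagates the saved Softmax's dims and dtype to Logits@GRAD. A minimal end-to-end sketch of that backward path is shown below (CUDA build assumed, example shapes only; not taken from the diff).

```python
# Sketch of the backward path whose output shape the grad InferMeta above
# describes: logits.grad comes out with the same shape as softmax (and logits).
import paddle

logits = paddle.randn([5, 41])
logits.stop_gradient = False
label = paddle.randint(0, 41, shape=[5], dtype='int64')

loss = paddle.nn.functional.margin_cross_entropy(logits, label, reduction='mean')
loss.backward()
print(logits.grad.shape)    # [5, 41], matching the softmax dims set above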
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
@@ -245,6 +245,20 @@ void LUUnpackGradInferMeta(const MetaTensor& x,
bool unpack_pivots,
MetaTensor* x_grad);
void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
bool return_softmax,
int ring_id,
int rank,
int nranks,
float margin1,
float margin2,
float margin3,
float scale,
MetaTensor* logits_grad);
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
@@ -1545,6 +1545,65 @@ void LUUnpackInferMeta(const MetaTensor& x,
}
}
void MarginCrossEntropyInferMeta(const MetaTensor& logits,
const MetaTensor& label,
bool return_softmax,
int ring_id,
int rank,
int nranks,
float margin1,
float margin2,
float margin3,
float scale,
MetaTensor* softmax,
MetaTensor* loss,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
logits,
phi::errors::InvalidArgument("Input of logits should not be null."));
PADDLE_ENFORCE_NOT_NULL(
label,
phi::errors::InvalidArgument("Input of label should not be null."));
auto logits_dims = logits.dims();
auto labels_dims = label.dims();
auto logits_rank = logits_dims.size();
auto axis = logits_rank - 1;
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[i],
labels_dims[i],
phi::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."));
}
}
}
if (labels_dims.size() > 1) {
PADDLE_ENFORCE_EQ(
labels_dims[logits_rank - 1],
1UL,
phi::errors::InvalidArgument(
"the last dimension of Input(Label) should be 1."
"But received: the last dimension of Input(Label) is [%d],"
"the last dimension is [%d]",
labels_dims[logits_rank - 1],
logits_rank - 1));
}
softmax->set_dims(logits_dims);
softmax->set_dtype(logits.dtype());
logits_dims[axis] = 1;
loss->set_dims(logits_dims);
loss->set_dtype(logits.dtype());
softmax->share_lod(logits);
loss->share_lod(logits);
}
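A plain-Python restatement of the shape rule this InferMeta implements may help when reading the checks above. It is an illustration only (not Paddle code) and assumes label already has the same rank as logits with a trailing dimension of 1.

```python
# Plain-Python restatement of the rule above: softmax keeps the logits shape,
# loss collapses the class axis to 1; all other dimensions must agree.
def margin_cross_entropy_out_shapes(logits_shape, label_shape):
    axis = len(logits_shape) - 1
    for i in range(len(logits_shape)):
        # Every dimension except the class axis must agree between the inputs.
        if i != axis and logits_shape[i] > 0 and label_shape[i] > 0:
            assert logits_shape[i] == label_shape[i]
    if len(label_shape) > 1:
        assert label_shape[-1] == 1        # last dimension of label must be 1
    softmax_shape = list(logits_shape)
    loss_shape = list(logits_shape)
    loss_shape[axis] = 1
    return softmax_shape, loss_shape

print(margin_cross_entropy_out_shapes([5, 41], [5, 1]))   # ([5, 41], [5, 1])
```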
void MaskedSelectInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out) {
......
@@ -240,6 +240,20 @@ void LUUnpackInferMeta(const MetaTensor& x,
MetaTensor* l,
MetaTensor* u);
void MarginCrossEntropyInferMeta(const MetaTensor& logits,
const MetaTensor& label,
bool return_softmax,
int ring_id,
int rank,
int nranks,
float margin1,
float margin2,
float margin3,
float scale,
MetaTensor* softmax,
MetaTensor* loss,
MetaConfig config = MetaConfig());
void MaskedSelectInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/margin_cross_entropy_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MarginCrossEntropyKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& labels,
bool return_softmax,
int ring_id,
int rank,
int nranks,
float margin1,
float margin2,
float margin3,
float scale,
DenseTensor* softmax,
DenseTensor* loss) {
PADDLE_THROW(
errors::Unavailable("Do not support margin_cross_entropy for cpu kernel "
"now."));
}
} // namespace phi
PD_REGISTER_KERNEL(margin_cross_entropy,
CPU,
ALL_LAYOUT,
phi::MarginCrossEntropyKernel,
float,
double,
phi::dtype::float16) {}
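Because this CPU registration only throws errors::Unavailable, a CPU-only caller sees the failure as a Python exception, which is why the CPU test case further down wraps its checks in try/except. A hedged illustration follows (assuming the error surfaces as RuntimeError, the type that test expects).

```python
# On a CPU-only build, the kernel registered above immediately throws
# errors::Unavailable; the unit tests below expect it to surface as RuntimeError.
import paddle

paddle.set_device('cpu')
try:
    paddle.nn.functional.margin_cross_entropy(
        paddle.randn([4, 8]), paddle.zeros([4], dtype='int64'))
except RuntimeError as err:
    print('margin_cross_entropy has no CPU kernel yet:', err)
```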
@@ -378,7 +378,6 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
DenseTensor sum_exp_logits;
sum_exp_logits.Resize({N, 1});
dev_ctx.template Alloc<T>(&sum_exp_logits);
// T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
T* sum_exp_logits_buff = dev_ctx.template Alloc<T>(&sum_exp_logits);
phi::funcs::ReduceKernel<T, T, phi::kps::AddFunctor, phi::kps::ExpFunctor<T>>(
static_cast<const phi::GPUContext&>(dev_ctx),
......
@@ -66,12 +66,36 @@ def margin_cross_entropy(logits,
return loss, softmax
def python_api(logits,
label,
return_softmax=False,
ring_id=0,
rank=0,
nrank=0,
margin1=1.0,
margin2=0.5,
margin3=0.0,
scale=64.0):
return paddle.nn.functional.margin_cross_entropy(
logits,
label,
return_softmax=return_softmax,
margin1=margin1,
margin2=margin2,
margin3=margin3,
scale=scale,
group=None,
reduction=None)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestMarginCrossEntropyOp(OpTest):
def initParams(self):
self.python_api = python_api
self.op_type = "margin_cross_entropy"
self.python_out_sig = ["Loss"]
self.axis = -1
self.batch_dim = 5
self.feat_dim = 41
@@ -121,10 +145,14 @@ class TestMarginCrossEntropyOp(OpTest):
}
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0), atol=1e-5)
self.check_output_with_place(core.CUDAPlace(0),
atol=1e-5,
check_eager=True)
def test_check_grad(self):
self.check_grad_with_place(core.CUDAPlace(0), ["Logits"], "Loss")
self.check_grad_with_place(core.CUDAPlace(0), ["Logits"],
"Loss",
check_eager=True)
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -138,7 +166,8 @@ class TestMarginCrossEntropyOpFP32(TestMarginCrossEntropyOp):
self.check_grad_with_place(core.CUDAPlace(0), ["Logits"],
"Loss",
numeric_grad_delta=5e-2,
max_relative_error=5e-2)
max_relative_error=5e-2,
check_eager=True)
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -149,13 +178,16 @@ class TestMarginCrossEntropyOpFP16(TestMarginCrossEntropyOp):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0), atol=5e-2)
self.check_output_with_place(core.CUDAPlace(0),
atol=5e-2,
check_eager=True)
def test_check_grad(self):
self.check_grad_with_place(core.CUDAPlace(0), ["Logits"],
"Loss",
numeric_grad_delta=6e-1,
max_relative_error=6e-1)
max_relative_error=6e-1,
check_eager=True)
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -184,13 +216,17 @@ class TestMarginCrossEntropyOpCPU(TestMarginCrossEntropyOp):
def test_check_output(self):
try:
self.check_output_with_place(core.CPUPlace(), atol=1e-5)
self.check_output_with_place(core.CPUPlace(),
atol=1e-5,
check_eager=True)
except RuntimeError:
pass
def test_check_grad(self):
try:
self.check_grad_with_place(core.CPUPlace(), ["Logits"], "Loss")
self.check_grad_with_place(core.CPUPlace(), ["Logits"],
"Loss",
check_eager=True)
except RuntimeError:
pass
@@ -208,6 +244,7 @@ class TestMarginCrossEntropyOpV2(unittest.TestCase):
self.places.append(paddle.fluid.CUDAPlace(0))
def initParams(self):
self.python_out_sig = ["Loss"]
self.seed = 2021
self.axis = -1
self.batch_dim = 5
@@ -356,6 +393,8 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
self.places.append(paddle.fluid.CUDAPlace(0))
def initParams(self):
self.python_api = python_api
self.python_out_sig = ["Loss"]
self.seed = 2021
self.axis = -1
self.batch_dim = 10
......
@@ -1926,7 +1926,19 @@ def margin_cross_entropy(logits,
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=-1)
if in_dynamic_mode():
if in_dygraph_mode():
softmax, loss = _C_ops.final_state_margin_cross_entropy(
logits, label, return_softmax, ring_id, rank, nranks, margin1,
margin2, margin3, scale)
if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
if not return_softmax:
return loss
else:
return loss, softmax
elif paddle.in_dynamic_mode():
softmax, loss = _C_ops.margin_cross_entropy(
logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks,
'margin1', margin1, 'margin2', margin2, 'margin3', margin3, 'scale',
......