Unverified commit 8e032db8, authored by zyfncg, committed by GitHub

Add nll_loss yaml (#41126)

* add nll_loss yaml

* fix nll loss

* fix nll loss bug

* fix bug

* fix bug

* fix infrt problem
Co-authored-by: xiongkun <xiongkun03@baidu.com>
Parent 934cbcd8
...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
...@@ -94,68 +95,6 @@ class NLLLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "NLLLoss");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "NLLLoss");
auto reduction = ctx->Attrs().Get<std::string>("reduction");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
bool contain_unknown_dim =
phi::contain_unknown_dim(x_dims) || phi::contain_unknown_dim(dout_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
auto batch_size = x_dims[0];
if (x_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dout_dims.size(), 1,
platform::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 1"));
if (reduction == "none") {
PADDLE_ENFORCE_EQ(
dout_dims[0], batch_size,
platform::errors::InvalidArgument(
"The unreduced size ofInput(Out@Grad) must be the "
"same as batch_size."));
} else {
PADDLE_ENFORCE_EQ(
dout_dims[0], 1,
platform::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
}
} else if (x_dims.size() == 4) {
if (reduction == "none") {
PADDLE_ENFORCE_EQ(
dout_dims.size(), 3,
platform::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 3,But got [%s].",
dout_dims.size()));
PADDLE_ENFORCE_EQ(
dout_dims[0] == label_dims[0] && dout_dims[1] == label_dims[1] &&
dout_dims[2] == label_dims[2],
true, platform::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be match "
"to Input(Label) dimensions."));
} else {
PADDLE_ENFORCE_EQ(
dout_dims[0], 1,
platform::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
}
}
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
...@@ -192,9 +131,12 @@ class NLLLossGradMaker : public framework::SingleGradOpMaker<T> {
DECLARE_INFER_SHAPE_FUNCTOR(nll_loss, NllLossRawInferShapeFunctor,
                            PD_INFER_META(phi::NllLossRawInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(nll_loss_grad, NllLossGradInferShapeFunctor,
+                            PD_INFER_META(phi::NllLossGradInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker,
                  ops::NLLLossGradMaker<paddle::framework::OpDesc>,
                  ops::NLLLossGradMaker<paddle::imperative::OpBase>,
                  NllLossRawInferShapeFunctor);
-REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp);
+REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp,
+                  NllLossGradInferShapeFunctor);
...@@ -180,20 +180,23 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& tensor_in = input.impl();
-  phi::DenseTensor& dense_tensor =
-      *static_cast<phi::DenseTensor*>(tensor_in.get());
-  if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
-      (!NeedTransformPlace(
-           dense_tensor.place(), target_args_def.backend, transform_flag) &&
-       !NeedTransformDataType(
-           dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
-       !NeedTransformLayout(
-           dense_tensor.layout(), target_args_def.layout, transform_flag))) {
-    return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
-  }
-  phi::DenseTensor out =
-      TransformData(dense_tensor, target_args_def, transform_flag);
-  return std::make_shared<phi::DenseTensor>(std::move(out));
+  if (tensor_in) {
+    phi::DenseTensor& dense_tensor =
+        *static_cast<phi::DenseTensor*>(tensor_in.get());
+    if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
+        (!NeedTransformPlace(
+             dense_tensor.place(), target_args_def.backend, transform_flag) &&
+         !NeedTransformDataType(
+             dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
+         !NeedTransformLayout(
+             dense_tensor.layout(), target_args_def.layout, transform_flag))) {
+      return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
+    }
+    phi::DenseTensor out =
+        TransformData(dense_tensor, target_args_def, transform_flag);
+    return std::make_shared<phi::DenseTensor>(std::move(out));
+  }
+  return nullptr;
}

std::shared_ptr<phi::DenseTensor> PrepareData(
......
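For reference, a minimal Python sketch (assuming a Paddle build carrying this change) of the case the new null check guards: `weight` is an optional input of nll_loss, so omitting it sends an empty tensor impl into PrepareData, which now passes it through as nullptr instead of dereferencing it.

import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 3])                    # [batch_size, num_classes]
log_probs = F.log_softmax(logits, axis=1)        # nll_loss expects log-probabilities
labels = paddle.to_tensor([0, 2, 1, 0], dtype='int64')

# weight defaults to None, so the optional input reaches the C++ API empty.
loss = F.nll_loss(log_probs, labels)
print(loss.shape)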
...@@ -180,6 +180,72 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
  dx->share_meta(x);
}
void NllLossGradInferMeta(const MetaTensor& x,
const MetaTensor& label,
paddle::optional<const MetaTensor&> weight,
const MetaTensor& total_weight,
const MetaTensor& out_grad,
int64_t ignore_index,
const std::string& reduction,
MetaTensor* dx,
MetaConfig config) {
const auto& x_dims = x.dims();
const auto& label_dims = label.dims();
const auto& dout_dims = out_grad.dims();
bool contain_unknown_dim =
phi::contain_unknown_dim(x_dims) || phi::contain_unknown_dim(dout_dims);
bool check = config.is_runtime || !contain_unknown_dim;
if (check) {
auto batch_size = x_dims[0];
if (x_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dout_dims.size(),
1,
phi::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 1"));
if (reduction == "none") {
PADDLE_ENFORCE_EQ(
dout_dims[0],
batch_size,
phi::errors::InvalidArgument(
"The unreduced size ofInput(Out@Grad) must be the "
"same as batch_size."));
} else {
PADDLE_ENFORCE_EQ(dout_dims[0],
1,
phi::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
}
} else if (x_dims.size() == 4) {
if (reduction == "none") {
PADDLE_ENFORCE_EQ(
dout_dims.size(),
3,
phi::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 3,But got [%s].",
dout_dims.size()));
PADDLE_ENFORCE_EQ(dout_dims[0] == label_dims[0] &&
dout_dims[1] == label_dims[1] &&
dout_dims[2] == label_dims[2],
true,
phi::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be match "
"to Input(Label) dimensions."));
} else {
PADDLE_ENFORCE_EQ(dout_dims[0],
1,
phi::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
}
}
}
if (dx) {
dx->set_dims(x_dims);
dx->set_dtype(x.dtype());
}
}
void PoolGradInferMeta(const MetaTensor& x,
                       const MetaTensor& out,
                       const MetaTensor& dout,
......
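A small sketch of what NllLossGradInferMeta above promises (shapes are illustrative, assuming a build with this commit): the gradient w.r.t. the input keeps the input's shape and dtype, while the reduced loss, and hence Out@Grad, has a single element.

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 3, 5, 5])                          # 4-D input: N=2, C=3, H=W=5
x.stop_gradient = False
log_probs = F.log_softmax(x, axis=1)
label = paddle.randint(0, 3, [2, 5, 5], dtype='int64')

loss = F.nll_loss(log_probs, label, reduction='mean')   # reduced: a single-element loss
loss.backward()                                          # runs nll_loss_grad
print(x.grad.shape)                                      # [2, 3, 5, 5], matching dx->set_dims(x_dims)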
...@@ -104,6 +104,16 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                   bool adaptive,
                                   MetaTensor* dx);
void NllLossGradInferMeta(const MetaTensor& input,
const MetaTensor& label,
paddle::optional<const MetaTensor&> weight,
const MetaTensor& total_weight,
const MetaTensor& out_grad,
int64_t ignore_index,
const std::string& reduction,
MetaTensor* input_grad,
MetaConfig config = MetaConfig());
void PsroiPoolGradInferMeta(const MetaTensor& x,
                            const MetaTensor& rois,
                            paddle::optional<const MetaTensor&> rois_num,
......
...@@ -710,6 +710,8 @@ class OpTest(unittest.TestCase):
        def prepare_python_api_arguments(api, op_proto_ins, op_proto_attrs,
                                         kernel_sig):
            """ map from `op proto inputs and attrs` to `api input list and api attrs dict`
+               NOTE: the op_proto_attrs and op_proto_ins is a default dict. default value is []
            """

        class Empty:
...@@ -770,7 +772,9 @@ class OpTest(unittest.TestCase):
                api_params), "Error happens. contack xiongkun03 to solve."
            inputs_sig, attrs_sig, outputs_sig = kernel_sig
            inputs_and_attrs = inputs_sig + attrs_sig
-            input_arguments = [op_proto_ins[name] for name in inputs_sig] + [
+            input_arguments = [
+                op_proto_ins.get(name, Empty()) for name in inputs_sig
+            ] + [
                parse_attri_value(name, op_proto_ins, op_proto_attrs)
                for name in attrs_sig
            ]
...@@ -814,16 +818,19 @@ class OpTest(unittest.TestCase):
            transform inputs by the following rules:
                1. [Tensor] -> Tensor
                2. [Tensor, Tensor, ...] -> list of Tensors
+               3. None -> None
+               4. Others: raise Error
            only support "X" is list of Tensor, currently don't support other structure like dict.
            """
-            for inp in args[:inp_num]:
+            inp_args = [[inp] if inp is None else inp
+                        for inp in args[:inp_num]]  # convert None -> [None]
+            for inp in inp_args:
                assert isinstance(
                    inp, list
                ), "currently only support `X` is [Tensor], don't support other structure."
-            args = [
-                inp[0] if len(inp) == 1 else inp for inp in args[:inp_num]
-            ] + args[inp_num:]
+            args = [inp[0] if len(inp) == 1 else inp
+                    for inp in inp_args] + args[inp_num:]
            return args

        def _get_kernel_signature(eager_tensor_inputs, eager_tensor_outputs,
......
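A minimal, self-contained sketch of the input-assembly rule added above (assemble_inputs is an illustrative stand-in, not the real OpTest helper): a missing optional input arrives as None and has to survive the `[Tensor] -> Tensor` unwrap.

def assemble_inputs(args, inp_num):
    # Rule 3: wrap a bare None as [None] so the list assertion still holds.
    inp_args = [[inp] if inp is None else inp for inp in args[:inp_num]]
    for inp in inp_args:
        assert isinstance(inp, list), "only [Tensor]-style inputs are supported"
    # Rule 1: unwrap a single-element list back to the element itself.
    return [inp[0] if len(inp) == 1 else inp for inp in inp_args] + list(args[inp_num:])

print(assemble_inputs((["x_tensor"], None, "some_attr"), 2))
# -> ['x_tensor', None, 'some_attr']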
...@@ -763,6 +763,8 @@ class TestNLLLossOp1DWithReduce(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "nll_loss"
+        self.python_api = paddle.nn.functional.nll_loss
+        self.python_out_sig = ["Out"]
        self.with_weight = False
        self.python_api = paddle.nn.functional.nll_loss
        self.python_out_sig = ["Out"]
...@@ -786,19 +788,19 @@ class TestNLLLossOp1DWithReduce(OpTest):
        self.attrs = {'reduction': 'mean', 'ignore_index': -100}

    def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=True)

    def test_check_output_with_weight(self):
        self.with_weight = True
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.with_weight = True
        place = fluid.CPUPlace()
-        self.check_grad_with_place(place, ['X'], 'Out', check_eager=False)
+        self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X'], 'Out')
+            self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)

    def init_test_case(self):
        self.input_shape = [10, 10]
...@@ -809,6 +811,8 @@ class TestNLLLossOp1DNoReduce(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "nll_loss"
+        self.python_api = paddle.nn.functional.nll_loss
+        self.python_out_sig = ["Out"]
        self.with_weight = False
        np.random.seed(200)
        input_np = np.random.uniform(0.1, 0.8,
...@@ -831,19 +835,19 @@ class TestNLLLossOp1DNoReduce(OpTest):
        self.attrs = {'reduction': 'none', 'ignore_index': -100}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_output_with_weight(self):
        self.with_weight = True
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.with_weight = True
        place = fluid.CPUPlace()
-        self.check_grad_with_place(place, ['X'], 'Out')
+        self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X'], 'Out')
+            self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)

    def init_test_case(self):
        self.input_shape = [10, 10]
...@@ -854,6 +858,8 @@ class TestNLLLossOp2DWithReduce(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "nll_loss"
+        self.python_api = paddle.nn.functional.nll_loss
+        self.python_out_sig = ["Out"]
        self.with_weight = False
        np.random.seed(200)
        input_np = np.random.uniform(0.1, 0.8,
...@@ -875,19 +881,19 @@ class TestNLLLossOp2DWithReduce(OpTest):
        self.attrs = {'reduction': 'mean', 'ignore_index': -100}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_output_with_weight(self):
        self.with_weight = True
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.with_weight = True
        place = fluid.CPUPlace()
-        self.check_grad_with_place(place, ['X'], 'Out')
+        self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X'], 'Out')
+            self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)

    def init_test_case(self):
        self.input_shape = [2, 3, 5, 5]
...@@ -898,6 +904,8 @@ class TestNLLLossOp2DNoReduce(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "nll_loss"
+        self.python_api = paddle.nn.functional.nll_loss
+        self.python_out_sig = ["Out"]
        self.with_weight = False
        np.random.seed(200)
        input_np = np.random.uniform(0.1, 0.8,
...@@ -920,19 +928,19 @@ class TestNLLLossOp2DNoReduce(OpTest):
        self.attrs = {'reduction': 'none', 'ignore_index': -100}

    def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_output_with_weight(self):
        self.with_weight = True
-        self.check_output()
+        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.with_weight = True
        place = fluid.CPUPlace()
-        self.check_grad_with_place(place, ['X'], 'Out')
+        self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X'], 'Out')
+            self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)

    def init_test_case(self):
        self.input_shape = [5, 3, 5, 5]
......
...@@ -784,7 +784,17 @@ def nll_loss(input,
                input_dims))
    n = input_shape[0]
    c = input_shape[1]
-    if _non_static_mode():
+    if in_dygraph_mode():
+        if input_dims != 2 and input_dims != 4:
+            input, _ = _C_ops.reshape2(input, None, 'shape', [n, c, 1, -1])
+            label, _ = _C_ops.reshape2(label, None, 'shape', [n, 1, -1])
+            out_shape = [n] + input_shape[2:]
+        out, total_weight = _C_ops.final_state_nll_loss(input, label, weight,
+                                                        ignore_index, reduction)
+        if input_dims != 2 and input_dims != 4 and reduction == 'none':
+            out, _ = _C_ops.reshape2(out, None, 'shape', out_shape)
+        return out
+    if _in_legacy_dygraph():
        if input_dims != 2 and input_dims != 4:
            input, _ = _C_ops.reshape2(input, None, 'shape', [n, c, 1, -1])
            label, _ = _C_ops.reshape2(label, None, 'shape', [n, 1, -1])
......
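The reshape branch above only fires for inputs that are neither 2-D nor 4-D; a usage sketch (shapes assumed for illustration) of a 5-D input that is flattened to [n, c, 1, -1] for the kernel and reshaped back when reduction='none':

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 3, 4, 5, 6])                  # input_dims == 5 takes the reshape path
log_probs = F.log_softmax(x, axis=1)
label = paddle.randint(0, 3, [2, 4, 5, 6], dtype='int64')

out = F.nll_loss(log_probs, label, reduction='none')
print(out.shape)                                   # [2, 4, 5, 6], i.e. [n] + input_shape[2:]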
...@@ -806,6 +806,17 @@
  func : mv
  backward : mv_grad

- api : nll_loss
  args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction)
  output : Tensor(out), Tensor(total_weight)
  infer_meta :
    func : NllLossRawInferMeta
  kernel :
    func : nll_loss
    data_type : input
  optional : weight
  backward : nll_loss_grad

- api : not_equal
  args : (Tensor x, Tensor y, int axis = -1)
  output : Tensor
......
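The api entry above is what generates the final-state binding used by the dygraph branch in loss.py; a hedged sketch of calling it directly (assuming an eager-mode build where _C_ops.final_state_nll_loss has been generated — paddle.nn.functional.nll_loss remains the public entry point):

import paddle
from paddle import _C_ops
import paddle.nn.functional as F

x = paddle.randn([4, 3])
log_probs = F.log_softmax(x, axis=1)
label = paddle.to_tensor([0, 2, 1, 0], dtype='int64')

# weight is declared optional in the yaml, so None is a legal argument, and the two
# return values mirror `output : Tensor(out), Tensor(total_weight)`.
out, total_weight = _C_ops.final_state_nll_loss(log_probs, label, None, -100, 'mean')
print(out.shape, total_weight.shape)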
...@@ -460,15 +460,14 @@
    func : mv_grad

- backward_api : nll_loss_grad
-  forward : nll_loss (Tensor x, Tensor label, Tensor weight, int64_t ignore_index, str reduction) -> Tensor(out), Tensor(total_weight)
-  args : (Tensor x, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction)
-  output : Tensor (x_grad)
+  forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction) -> Tensor(out), Tensor(total_weight)
+  args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction)
+  output : Tensor(input_grad)
  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
+    func : NllLossGradInferMeta
  kernel :
    func : nll_loss_grad
-    data_type : out_grad
+    data_type : input
  optional : weight

- backward_api : psroi_pool_grad
......
...@@ -9,5 +9,5 @@
  forward : sparse_relu(Tensor x) -> Tensor(out@SparseCooTensor)
  args : (Tensor x, Tensor out_grad)
  output : Tensor(x_grad@SparseCooTensor)
  kernel :
    func : sparse_relu_grad
{
-  "phi_apis":["conj"],
+  "phi_apis":["conj", "nll_loss"],
  "phi_kernels":["equal_all"]
}