Unverified commit 582c3d9f authored by tangnana925, committed by GitHub

add norm, vector_norm, matrix_norm from Python to C++ and add TripletMarginLoss (#5965)

* add test file at first

* add tripletMarginLoss py code

* module ok

* add  forward test

* amend test code

* delete import torch

* add autotest ok

* delete numpy test code

* amend docstring

* amend loss.py, delete None

* API transfer to C++

* modify module

* delete cout

* delete cout

* Submit some modified code first

* submit vector_norm functor

* matrix norm

* Refine max/min functor (#6359)

merge to dev_tripletMarginLoss

* replace reducemax and reducemin

* amend code error

* modify code

* delete norm2

* delete print

* delete norm2

* delete print

* modify review code

* add assert to c++

* modify review code

* add else

* modify review problem

* add code

* add test code

* modify code, delete dim_check

* delete norm.py code

* delete print

* delete print

* delete pu norm

* delete error code

* modify docstring

* auto format by CI

* delete no use num_dims

* delete import torch lib

* delete CI bug code

* modify clip_grad_norm_ to resolve autotest bug

* auto format by CI

* modify loss docstring

* modify norm docstring
Co-authored-by: Zhenhua <1209435+hengzi@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Parent 3bcd09da
@@ -30,7 +30,8 @@ Functional operations for neural networks
.. autofunction:: tanh
.. autofunction:: silu
.. autofunction:: mish
.. autofunction:: one_hot
.. autofunction:: one_hot
.. autofunction:: triplet_margin_loss
.. autofunction:: dropout
.. autofunction:: upsample
.. autofunction:: affine_grid
@@ -54,6 +54,7 @@ Operators for neural networks
LogSoftmax,
MSELoss,
MarginRankingLoss,
TripletMarginLoss,
MaxPool1d,
MaxPool2d,
MaxPool3d,
@@ -744,6 +744,11 @@
signature: "Tensor (Tensor dy, Tensor label, Tensor theta, Float m1, Float m2, Float m3, Int64 depth) => CombinedMarginLossGrad"
bind_python: False
- name: "triplet_margin_loss"
signature: "Tensor (Tensor anchor, Tensor positive, Tensor negative, *, Float margin, Float p, Float eps, Bool swap, String reduction) => TripletMarginLoss"
bind_python: True
- name: "margin_ranking_loss"
signature: "Tensor (Tensor input_1, Tensor input_2, Tensor target, Float margin, String reduction) => MarginRankingLoss"
bind_python: True
@@ -1128,6 +1133,32 @@
signature: "Tensor (Tensor dy, Tensor x, Scalar min=None, Scalar max=None) => ClampGrad"
bind_python: False
- name: "vector_norm"
signature:
[
"Tensor (Tensor input, Scalar ord=2, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => VectorNorm",
"Tensor (Tensor input, Scalar ord=2, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => VectorNorm",
]
bind_python: True
- name: "matrix_norm"
signature:
[
"Tensor (Tensor input, Scalar ord, Int32List dim, Bool keepdim=False, *, DataType dtype=None) => MatrixNorm",
"Tensor (Tensor input, String ord, Int32List dim, Bool keepdim=False, *, DataType dtype=None) => MatrixNorm",
]
bind_python: True
- name: "norm"
signature:
[
"Tensor (Tensor input, Scalar ord=None, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => Norm",
"Tensor (Tensor input, String ord, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => Norm",
"Tensor (Tensor input, Scalar ord=None, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm",
"Tensor (Tensor input, String ord, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm",
]
bind_python: True
- name: "dropout"
signature: "Tensor (Tensor x, Float p=0.5, Bool training=True, Generator generator=None) => Dropout"
bind_python: True
@@ -735,6 +735,376 @@ class ClampFunctor {
std::shared_ptr<OpExpr> clip_max_op_;
};
class VectorNormFunctor {
public:
VectorNormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& ord,
const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
std::shared_ptr<one::Tensor> res;
Symbol<DType> dtype_val;
if (dtype) {
dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports floating point and "
"complex dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports floating point and "
"complex dtypes, but got: Int.";
}
dtype_val = x->dtype();
}
std::vector<int32_t> dim;
if (!input_dim.has_value()) {
std::vector<int32_t> reduce_axis(x->shape()->NumAxes());
std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
dim = reduce_axis;
} else {
std::vector<int32_t> dim_check;
dim_check = *JUST(input_dim);
for (int i = 0; i < dim_check.size(); ++i) {
if (dim_check[i] >= 0) {
dim.push_back(dim_check[i]);
} else {
dim.push_back(dim_check[i] + x->shape()->NumAxes());
}
}
}
if (ord.IsIntegral() || ord.IsFloatingPoint()) {
double ord_val = JUST(ord.As<double>());
if (ord_val == 0) {
std::vector<int32_t> dim_column(1, 0);
res = JUST(ReduceSum(JUST(ScalarLogicalNotEqual(x, 0)), dim_column, keepdim));
} else if (ord_val == INFINITY) {
res = JUST(ReduceMax(JUST(Abs(x)), dim, keepdim));
} else if (ord_val == -INFINITY) {
res = JUST(ReduceMin(JUST(Abs(x)), dim, keepdim));
} else {
res =
JUST(ScalarPow(JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord, false)), dim, keepdim)),
Scalar(1.0) / ord, false));
}
res = JUST(Cast(res, dtype_val));
return res;
} else {
UNIMPLEMENTED_THEN_RETURN()
<< "linalg_vector_norm(): argument 'ord' must be Number, not str.";
}
}
};
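For intuition, the ord branches above correspond to the following NumPy reference (an illustrative sketch only, not part of the diff; note that the functor's ord == 0 branch currently reduces over dim 0 regardless of the requested dim):

import numpy as np

def vector_norm_ref(x, ord=2.0, axis=None, keepdims=False):
    # ord == 0 counts the non-zero entries
    if ord == 0:
        return np.sum(x != 0, axis=axis, keepdims=keepdims).astype(np.float32)
    # +inf / -inf take the largest / smallest absolute value
    if ord == np.inf:
        return np.max(np.abs(x), axis=axis, keepdims=keepdims)
    if ord == -np.inf:
        return np.min(np.abs(x), axis=axis, keepdims=keepdims)
    # general p-norm: (sum |x|^p)^(1/p)
    return np.sum(np.abs(x) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)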
class ScalarVectorNormFunctor {
public:
ScalarVectorNormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& ord,
const Scalar& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
if (dtype) {
Symbol<DType> dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports the float, double, "
"cfloat and cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports the float, double, "
"cfloat and cdouble dtypes, but got: Int.";
}
}
if (input_dim.IsIntegral()) {
std::vector<int32_t> dim(1, JUST(input_dim.As<int>()));
return functional::VectorNorm(x, ord, dim, keepdim, dtype);
} else {
UNIMPLEMENTED_THEN_RETURN() << "linalg.vector_norm(): only supports int dim.";
}
}
};
class ScalarMatrixNormFunctor {
public:
ScalarMatrixNormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& ord,
const std::vector<int32_t>& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
std::shared_ptr<one::Tensor> res;
auto num_dims = x->shape()->NumAxes();
auto axis = input_dim.size();
CHECK_OR_RETURN(num_dims >= 2)
<< "linalg.matrix_norm(): input tensor must be a matrix or batch of matrices";
CHECK_OR_RETURN(axis == 2 && input_dim[0] != input_dim[1])
<< "linalg.matrix_norm(): input_dim must be a 2-tuple of ints with different elements";
Symbol<DType> dtype_val;
if (dtype) {
dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, "
"cfloat and cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, "
"cfloat and cdouble dtypes, but got: Int.";
}
dtype_val = x->dtype();
}
std::vector<int32_t> dim_tmp;
for (int i = 0; i < axis; ++i) {
if (input_dim[i] >= 0) {
dim_tmp.push_back(input_dim[i]);
} else {
dim_tmp.push_back(input_dim[i] + num_dims);
}
}
std::vector<int32_t> dim(2);
double ord_tmp = JUST(ord.As<double>());
if (ord_tmp == INFINITY || ord_tmp == -INFINITY) {
dim = dim_tmp;
dim[0] = dim_tmp[1];
dim[1] = dim_tmp[0];
} else if (ord_tmp == 1 || ord_tmp == -1) {
dim = dim_tmp;
} else {
UNIMPLEMENTED_THEN_RETURN()
<< "linalg.matrix_norm(): Only support INFINITY,-INFINITY,1 or -1 data type.";
}
if (dim[1] > dim[0] && keepdim == false) { dim[1] -= 1; }
std::vector<int32_t> dim_tmp0_vec(1, dim[0]);
std::vector<int32_t> dim_tmp1_vec(1, dim[1]);
res = JUST(ReduceSum(JUST(Abs(x)), dim_tmp0_vec, keepdim));
if (ord_tmp == INFINITY || ord_tmp == 1) {
res = JUST(ReduceMax(res, dim_tmp1_vec, keepdim));
} else if (ord_tmp == -INFINITY || ord_tmp == -1) {
res = JUST(ReduceMin(res, dim_tmp1_vec, keepdim));
}
res = JUST(Cast(res, dtype_val));
return res;
}
};
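The supported ord values reduce to an absolute column/row sum followed by a max or min. A NumPy sketch of the same computation for a single matrix (illustrative only, not part of the diff):

import numpy as np

def matrix_norm_ref(a, ord):
    # ord 1 / -1: max / min absolute column sum; ord inf / -inf: max / min absolute row sum
    if ord in (1, -1):
        sums = np.abs(a).sum(axis=0)
    elif ord in (np.inf, -np.inf):
        sums = np.abs(a).sum(axis=1)
    else:
        raise NotImplementedError("only 1, -1, inf and -inf are handled here")
    return sums.max() if ord > 0 else sums.min()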
class MatrixNormFunctor {
public:
MatrixNormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::string& ord,
const std::vector<int32_t>& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
std::shared_ptr<one::Tensor> res;
Symbol<DType> dtype_val;
if (dtype) {
dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, "
"cfloat and cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): only supports the float, double, "
"cfloat and cdouble dtypes, but got: Int.";
}
dtype_val = x->dtype();
}
auto num_dims = x->shape()->NumAxes();
auto axis = input_dim.size();
std::vector<int32_t> dim_tmp;
for (int i = 0; i < axis; ++i) {
if (input_dim[i] >= 0) {
dim_tmp.push_back(input_dim[i]);
} else {
dim_tmp.push_back(input_dim[i] + num_dims);
}
}
if (ord == "nuc") {
UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): Not support ord is nuc.";
} else if (ord == "fro") {
res = JUST(Sqrt(JUST(ReduceSum(JUST(Square(x)), dim_tmp, keepdim))));
} else {
UNIMPLEMENTED_THEN_RETURN() << "linalg.matrix_norm(): could not convert string to float:"
<< ord;
}
res = JUST(Cast(res, dtype_val));
return res;
}
};
class NormFunctor {
public:
NormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& ord,
const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
std::shared_ptr<one::Tensor> res;
if (dtype) {
Symbol<DType> dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
}
Scalar ord_sca;
if (ord.has_value()) {
auto ord_type = (*JUST(ord)).IsIntegral();
if (ord_type) {
ord_sca = Scalar(JUST((*JUST(ord)).As<double>()));
} else {
ord_sca = *JUST(ord);
}
}
if (input_dim.has_value()) {
auto axis = (*JUST(input_dim)).size();
if (axis == 1) {
Scalar ord_val;
if (!ord.has_value()) {
ord_val = Scalar(2.0);
} else {
ord_val = ord_sca;
}
res = JUST(VectorNorm(x, ord_val, input_dim, keepdim, dtype));
} else if (axis > 2) {
res = JUST(MatrixNorm(x, ord_sca, *JUST(input_dim), keepdim, dtype));
} else if (axis == 2) {
if (!ord.has_value()) {
res = JUST(MatrixNorm(x, "fro", *JUST(input_dim), keepdim, dtype));
} else {
res = JUST(MatrixNorm(x, ord_sca, *JUST(input_dim), keepdim, dtype));
}
}
} else {
if (ord.has_value()) {
CHECK_OR_RETURN(x->shape()->NumAxes() <= 2)
<< "linalg.norm(): input must be 1-D or 2-D when dim is None and ord is not None";
if (x->shape()->NumAxes() == 1) {
res = JUST(VectorNorm(x, ord_sca, input_dim, keepdim, dtype));
} else {
std::vector<int32_t> dim{0, 1};
res = JUST(MatrixNorm(x, ord_sca, dim, keepdim, dtype));
}
} else {
res = JUST(VectorNorm(JUST(Flatten(x, 0, -1)), Scalar(2.0), input_dim, keepdim, dtype));
}
}
return res;
}
};
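The branching above amounts to the following dispatch rules, written as a Python sketch against the public linalg entry points added later in this PR (assumed to forward to the functors above; for illustration only):

import oneflow as flow

def norm_dispatch(x, ord=None, dim=None, keepdim=False):
    if isinstance(dim, int):
        # single axis: vector norm, defaulting to the 2-norm
        return flow.linalg.vector_norm(x, 2.0 if ord is None else ord, dim, keepdim)
    if isinstance(dim, (tuple, list)):
        # a pair of axes: matrix norm, Frobenius by default
        return flow.linalg.matrix_norm(x, "fro" if ord is None else ord, dim, keepdim)
    if ord is not None:
        # dim is None but ord is given: the input must be 1-D or 2-D
        if len(x.shape) == 1:
            return flow.linalg.vector_norm(x, ord, None, keepdim)
        return flow.linalg.matrix_norm(x, ord, (0, 1), keepdim)
    # neither dim nor ord: 2-norm of the flattened input
    return flow.linalg.vector_norm(x.flatten(), 2.0, None, keepdim)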
class Norm2Functor {
public:
Norm2Functor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::string& ord,
const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
std::shared_ptr<one::Tensor> res;
std::vector<int32_t> dim(x->shape()->NumAxes());
std::iota(dim.begin(), dim.end(), 0);
if (dtype) {
Symbol<DType> dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
}
if (input_dim.has_value()) {
res = JUST(MatrixNorm(x, ord, *JUST(input_dim), keepdim, dtype));
} else {
res = JUST(MatrixNorm(x, ord, dim, keepdim, dtype));
}
return res;
}
};
class ScalarNormFunctor {
public:
ScalarNormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& ord,
const Scalar& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
if (dtype) {
Symbol<DType> dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
}
if (input_dim.IsIntegral()) {
std::vector<int32_t> dim(1, JUST(input_dim.As<int>()));
return functional::Norm(x, ord, dim, keepdim, dtype);
} else {
UNIMPLEMENTED_THEN_RETURN() << "linalg_norm(): only supports int dim.";
}
}
};
class ScalarNorm2Functor {
public:
ScalarNorm2Functor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::string& ord,
const Scalar& input_dim, const bool& keepdim,
const Optional<Symbol<DType>>& dtype) const {
if (dtype) {
Symbol<DType> dtype_val = JUST(dtype);
if (!(dtype_val->data_type() == DataType::kFloat
|| dtype_val->data_type() == DataType::kDouble
|| dtype_val->data_type() == DataType::kFloat16
|| dtype_val->data_type() == DataType::kBFloat16)) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
} else {
if (!IsFloatingDataType(x->dtype()->data_type())) {
UNIMPLEMENTED_THEN_RETURN() << "linalg.norm(): only supports the float, double, cfloat and "
"cdouble dtypes, but got: Int.";
}
}
if (input_dim.IsIntegral()) {
std::vector<int32_t> dim(1, JUST(input_dim.As<int>()));
return functional::Norm(x, ord, dim, keepdim, dtype);
} else {
UNIMPLEMENTED_THEN_RETURN() << "linalg_norm(): only supports int dim.";
}
}
};
class ClampGradFunctor {
public:
ClampGradFunctor() {
@@ -1155,6 +1525,10 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<ConsistentArangeFunctor, ConsistentArange2Functor>("ConsistentArange");
m.add_functor<CastFunctor>("Cast");
m.add_functor<ClampFunctor>("Clamp");
m.add_functor<VectorNormFunctor, ScalarVectorNormFunctor>("VectorNorm");
m.add_functor<ScalarMatrixNormFunctor, MatrixNormFunctor>("MatrixNorm");
m.add_functor<NormFunctor, Norm2Functor>("Norm");
m.add_functor<ScalarNormFunctor, ScalarNorm2Functor>("ScalarNorm");
m.add_functor<ClampGradFunctor>("ClampGrad");
m.add_functor<SelectTopNFunctor>("SelectTopN");
m.add_functor<MinimumFunctor>("Minimum");
@@ -1083,6 +1083,44 @@ class CtcLossFunctor {
std::shared_ptr<OpExpr> op_xdivy_;
};
class TripletMarginLossFunctor {
public:
TripletMarginLossFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& anchor,
const std::shared_ptr<one::Tensor>& positive,
const std::shared_ptr<one::Tensor>& negative, const float& margin,
const float& p, const float& eps, const bool& swap,
const std::string& reduction) const {
int32_t dim_norm = anchor->ndim() - 1;
std::vector<int32_t> dim(1, dim_norm);
CHECK_OR_RETURN(reduction == "none" || reduction == "sum" || reduction == "mean")
<< "reduction should be 'none', 'sum' or 'mean', but got " << reduction;
auto da_p = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(anchor, positive)))), p, dim, false,
anchor->dtype()));
auto da_n = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(anchor, negative)))), p, dim, false,
anchor->dtype()));
if (swap) {
auto distance_swap = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(positive, negative)))), p,
dim, false, positive->dtype()));
da_n = JUST(Minimum(distance_swap, da_n));
}
auto triplet_loss =
JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n)), margin, false)), 0.0, NullOpt));
int32_t ndim = triplet_loss->ndim() - 1;
std::vector<int32_t> axis(1, ndim);
if (reduction == "mean") {
triplet_loss = JUST(ReduceMean(triplet_loss, axis, false));
} else if (reduction == "sum") {
triplet_loss = JUST(ReduceSum(triplet_loss, axis, false));
}
return triplet_loss;
}
};
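A NumPy reference of the loss computed above (an illustrative sketch, not part of the diff; as in the functor, eps is added to the element-wise difference before taking the p-norm):

import numpy as np

def triplet_margin_loss_ref(anchor, positive, negative,
                            margin=1.0, p=2.0, eps=1e-6, swap=False, reduction="mean"):
    def dist(a, b):
        # p-norm of (a - b + eps) along the last dimension
        return np.sum(np.abs(a - b + eps) ** p, axis=-1) ** (1.0 / p)
    d_ap = dist(anchor, positive)
    d_an = dist(anchor, negative)
    if swap:
        d_an = np.minimum(dist(positive, negative), d_an)
    loss = np.maximum(d_ap - d_an + margin, 0.0)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss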
class AffineGridFunctor {
public:
AffineGridFunctor() {
@@ -1835,6 +1873,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SoftmaxCrossEntropyGradFunctor>("SoftmaxCrossEntropyGrad");
m.add_functor<impl::SmoothL1LossFunctor>("SmoothL1Loss");
m.add_functor<impl::CombinedMarginLossFunctor>("CombinedMarginLoss");
m.add_functor<impl::TripletMarginLossFunctor>("TripletMarginLoss");
m.add_functor<impl::MarginRankingLossFunctor>("MarginRankingLoss");
m.add_functor<impl::CtcLossFunctor>("CtcLoss");
m.add_functor<impl::AffineGridFunctor>("AffineGrid");
@@ -245,6 +245,7 @@ from oneflow.framework.generator import default_generator, manual_seed
from oneflow.framework.scope_util import api_current_scope as current_scope
from oneflow.framework.tensor import Tensor
from oneflow.framework.tensor import is_nonzero
from oneflow.nn.modules.pooling import (
adaptive_avg_pool1d,
adaptive_avg_pool2d,
@@ -20,6 +20,8 @@ from .pooling import *
from .activation import *
from .dropout import *
from .vision import *
from .norm import *
from .loss import *
from .onehot import *
from .comparison import *
from .cast import *
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from oneflow.framework.docstr.utils import add_docstr
add_docstr(
oneflow._C.triplet_margin_loss,
r"""
The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.nn.functional.triplet_margin_loss.html?highlight=triplet_margin_loss
Creates a criterion that measures the triplet loss given input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed of `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, D)`.
The distance swap is described in detail in the paper `Learning shallow
convolutional feature descriptors with triplet losses <http://www.bmva.org/bmvc/2016/papers/paper119/index.html>`__ by
V. Balntas, E. Riba et al.
The loss function for each sample in the mini-batch is:
.. math::
L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
where
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
Args:
margin (float, optional): The margin value of the triplet loss. Default: :math:`1.0`.
p (float, optional): The norm degree for pairwise distance. Default: :math:`2.0`.
eps (float, optional): Small value added to the pairwise difference for numerical stability. Default: :math:`1e-6`.
swap (bool, optional): The distance swap is described in detail in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. Default: ``False``.
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> triplet_loss = flow.nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]])
>>> positive = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> negative = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]])
>>> output = triplet_loss(flow.Tensor(anchor), flow.Tensor(positive), flow.Tensor(negative))
>>> output
tensor(6.2971, dtype=oneflow.float32)
""",
)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from oneflow.framework.docstr.utils import add_docstr
add_docstr(
oneflow.linalg.vector_norm,
"""linalg.vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a vector norm.
Supports input of float, double dtypes.
This function does not necessarily treat multidimensional :attr:`input` as a batch of
vectors, instead:
- If :attr:`dim`\\ `= None`, :attr:`input` will be flattened before the norm is computed.
- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions and the other dimensions will be treated as batch dimensions.
This behavior is for consistency with :func:`flow.linalg.norm`.
:attr:`ord` defines the vector norm that is computed. The following norms are supported:
====================== ========================================================
:attr:`ord` vector norm
====================== ========================================================
`2` (default) `2`-norm (see below)
`inf` `max(abs(x))`
`-inf` `min(abs(x))`
`0` `sum(x != 0)`
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ========================================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
input (Tensor): tensor, flattened by default, but this behavior can be
controlled using :attr:`dim`.
ord (int, float, inf, -inf, optional): order of norm. Default: `2`
dim (int, Tuple[int], optional): dimensions over which to compute
the norm. See above for the behavior when :attr:`dim`\\ `= None`.
Default: `None`
keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor.
Examples:
.. code-block:: python
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape(3, 3)
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
[ 2., 3., 4.]], dtype=oneflow.float32)
>>> LA.vector_norm(a, ord=3.5)
tensor(5.4345, dtype=oneflow.float32)
>>> LA.vector_norm(b, ord=3.5)
tensor(5.4345, dtype=oneflow.float32)
""",
)
add_docstr(
oneflow.linalg.matrix_norm,
"""linalg.matrix_norm(input, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a matrix norm.
Support input of float, double, cfloat and cdouble dtypes.
Also supports batches of matrices: the norm will be computed over the
dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will
be treated as batch dimensions. The output will have the same batch dimensions.
:attr:`ord` defines the matrix norm that is computed. The following norms are supported:
====================== ========================================================
:attr:`ord` matrix norm
====================== ========================================================
`'fro'` (default) Frobenius norm
`'nuc'` -- not supported yet --
`inf` `max(sum(abs(x), dim=1))`
`-inf` `min(sum(abs(x), dim=1))`
`1` `max(sum(abs(x), dim=0))`
`-1` `min(sum(abs(x), dim=0))`
`2` -- not supported yet --
`-2` -- not supported yet --
====================== ========================================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
input (Tensor): tensor with two or more dimensions. By default its
shape is interpreted as `(*, m, n)` where `*` is zero or more
batch dimensions, but this behavior can be controlled using :attr:`dim`.
ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'`
dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)`
keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor.
Examples:
.. code-block:: python
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape(3,3)
>>> a
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=oneflow.float32)
>>> LA.matrix_norm(a)
tensor(14.2829, dtype=oneflow.float32)
>>> LA.matrix_norm(a, ord=-1)
tensor(9., dtype=oneflow.float32)
>>> b = a.expand(2, -1, -1)
>>> b
tensor([[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]],
<BLANKLINE>
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]]], dtype=oneflow.float32)
>>> LA.matrix_norm(b, dim=(0, 2))
tensor([ 3.1623, 10.0000, 17.2627], dtype=oneflow.float32)
""",
)
add_docstr(
oneflow.linalg.norm,
"""linalg.norm(input, ord=None, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
Returns the matrix norm or vector norm of a given tensor.
This function can calculate one of eight different types of matrix norms, or one
of an infinite number of vector norms, depending on both the number of reduction
dimensions and the value of the `ord` parameter.
Args:
input (Tensor): The input tensor. If dim is None, input must be 1-D or 2-D, unless :attr:`ord`
is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D
will be returned. Its data type must be either a floating point or complex type. For complex
inputs, the norm is calculated using the absolute values of each element. If the input is
complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will
be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat).
ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `None`
The following norms can be calculated:
============== ============================ =================================
:attr:`ord` norm for matrices norm for vectors
============== ============================ =================================
None Frobenius norm `2`-norm
`'fro'` Frobenius norm -- not supported --
`'nuc'` -- not supported yet -- -- not supported --
`inf` `max(sum(abs(x), dim=1))` `max(abs(x))`
`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))`
`0` -- not supported -- `sum(x != 0)`
`1` `max(sum(abs(x), dim=0))` as below
`-1` `min(sum(abs(x), dim=0))` as below
`2` -- not supported yet -- as below
`-2` -- not supported yet -- as below
other -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}`
============== ============================ =================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
dim (int, 2-tuple of ints, 2-list of ints, optional): If :attr:`dim` is an int,
vector norm will be calculated over the specified dimension. If :attr:`dim`
is a 2-tuple of ints, matrix norm will be calculated over the specified
dimensions. If :attr:`dim` is None, matrix norm will be calculated
when the input tensor has two dimensions, and vector norm will be
calculated when the input tensor has one dimension. Default: ``None``
keepdim (bool, optional): If set to True, the reduced dimensions are retained
in the result as dimensions with size one. Default: ``False``
out (Tensor, optional): The output tensor.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape(3, 3)
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
[ 2., 3., 4.]], dtype=oneflow.float32)
>>> LA.norm(a)
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(b)
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(b, 'fro')
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(a, float('inf'))
tensor(4., dtype=oneflow.float32)
>>> LA.norm(b, float('inf'))
tensor(9., dtype=oneflow.float32)
>>> LA.norm(a, -float('inf'))
tensor(0., dtype=oneflow.float32)
>>> LA.norm(b, -float('inf'))
tensor(2., dtype=oneflow.float32)
>>> LA.norm(a, 1)
tensor(20., dtype=oneflow.float32)
>>> LA.norm(b, 1)
tensor(7., dtype=oneflow.float32)
>>> LA.norm(a, -1)
tensor(0., dtype=oneflow.float32)
>>> LA.norm(b, -1)
tensor(6., dtype=oneflow.float32)
>>> LA.norm(a, 2)
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(a, -2)
tensor(0., dtype=oneflow.float32)
>>> LA.norm(a, 3)
tensor(5.8480, dtype=oneflow.float32)
>>> LA.norm(a, -3)
tensor(0., dtype=oneflow.float32)
>>> c = flow.tensor([[1., 2., 3.],
... [-1, 1, 4]])
>>> LA.norm(c, dim=0)
tensor([1.4142, 2.2361, 5.0000], dtype=oneflow.float32)
>>> LA.norm(c, dim=1, keepdim = True)
tensor([[3.7417],
[4.2426]], dtype=oneflow.float32)
>>> LA.norm(c, ord=1, dim=1)
tensor([6., 6.], dtype=oneflow.float32)
""",
)
@@ -174,6 +174,18 @@ def _contiguous(self):
return self
def _norm(self, ord=None, dim=None, keepdim=False, dtype=None):
return flow._C.norm(self, ord, dim, keepdim, dtype=dtype)
def _vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
return flow._C.vector_norm(self, ord, dim, keepdim, dtype=dtype)
def _matrix_norm(self, ord="fro", dim=(-2, -1), keepdim=False, dtype=None):
return flow._C.matrix_norm(self, ord, dim, keepdim, dtype=dtype)
def _transpose(self, dim0, dim1):
return flow._C.transpose(self, dim0, dim1)
@@ -762,6 +774,9 @@ def RegisterMethods():
Tensor.tril = _tril
Tensor.triu = _triu
Tensor.contiguous = _contiguous
Tensor.norm = _norm
Tensor.vector_norm = _vector_norm
Tensor.matrix_norm = _matrix_norm
Tensor.transpose = _transpose
Tensor.relu = _relu
Tensor.softmax = _softmax
@@ -13,6 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.nn.modules.norm import matrix_norm_tensor_op as matrix_norm
from oneflow.nn.modules.norm import norm_op as norm
from oneflow.nn.modules.norm import vector_norm_tensor_op as vector_norm
from oneflow.framework.tensor import _norm as norm
from oneflow.framework.tensor import _vector_norm as vector_norm
from oneflow.framework.tensor import _matrix_norm as matrix_norm
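With these re-exports in place, the new C++ functors are reachable from Python roughly as follows (a usage sketch; results are not reproduced here):

import numpy as np
import oneflow as flow

x = flow.tensor(np.arange(9, dtype=np.float32) - 4)
m = x.reshape(3, 3)

flow.linalg.vector_norm(x)                       # 2-norm of the flattened input
flow.linalg.vector_norm(m, float("inf"), dim=1)  # max(abs(x)) per row
flow.linalg.matrix_norm(m)                       # Frobenius norm over the last two dims
flow.linalg.matrix_norm(m, ord=1, dim=(0, 1))    # max absolute column sum
flow.linalg.norm(m)                              # same as the flattened 2-norm
flow.linalg.norm(m, ord="fro", dim=(0, 1))       # dispatches to the matrix norm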
@@ -103,6 +103,7 @@ from oneflow.nn.modules.loss import (
NLLLoss,
SmoothL1Loss,
CombinedMarginLoss,
TripletMarginLoss,
)
from oneflow.nn.modules.normalization import GroupNorm, LayerNorm
from oneflow.nn.modules.padding import (
@@ -54,6 +54,7 @@ from oneflow._C import dropout
from oneflow._C import smooth_l1_loss
from oneflow._C import pad
from oneflow._C import upsample
from oneflow._C import triplet_margin_loss
from oneflow._C import ctc_greedy_decoder
from oneflow.nn.modules.one_hot import one_hot
from oneflow.nn.modules.sparse import embedding
@@ -927,6 +927,94 @@ class CombinedMarginLoss(Module):
)
class TripletMarginLoss(Module):
r"""Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, D)`.
The distance swap is described in detail in the paper `Learning shallow
convolutional feature descriptors with triplet losses <http://www.bmva.org/bmvc/2016/papers/paper119/index.html>`__ by
V. Balntas, E. Riba et al.
The loss function for each sample in the mini-batch is:
.. math::
L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
where
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
Args:
margin (float, optional): The margin value of the triplet loss. Default: :math:`1.0`.
p (float, optional): The norm degree for pairwise distance. Default: :math:`2.0`.
eps (float, optional): Small value added to the pairwise difference for numerical stability. Default: :math:`1e-6`.
swap (bool, optional): The distance swap is described in detail in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. Default: ``False``.
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> triplet_loss = flow.nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]])
>>> positive = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> negative = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]])
>>> output = triplet_loss(flow.Tensor(anchor), flow.Tensor(positive), flow.Tensor(negative))
>>> output
tensor(6.2971, dtype=oneflow.float32)
"""
def __init__(
self,
margin: float = 1.0,
p: float = 2.0,
eps: float = 1e-6,
swap: bool = False,
size_average=None,
reduce=None,
reduction: str = "mean",
) -> None:
super().__init__()
self.margin = margin
self.p = p
self.eps = eps
self.swap = swap
self.reduction = reduction
def forward(self, anchor, positive, negative):
triplet_loss = flow._C.triplet_margin_loss(
anchor,
positive,
negative,
margin=self.margin,
p=self.p,
eps=self.eps,
swap=self.swap,
reduction=self.reduction,
)
return triplet_loss
if __name__ == "__main__":
import doctest
@@ -18,437 +18,6 @@ from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module
def check_dim(num_dims, input_dim):
if input_dim == None:
dim = input_dim
elif isinstance(input_dim, (int, tuple)):
if isinstance(input_dim, int):
dim = input_dim if input_dim >= 0 else input_dim + num_dims
if dim >= num_dims or dim < 0:
raise IndexError("Dimension out of range")
else:
temp = list(input_dim)
for i in range(len(temp)):
temp[i] = temp[i] if temp[i] >= 0 else temp[i] + num_dims
if temp[i] >= num_dims or temp[i] < 0:
raise IndexError("Dimension out of range")
dim = temp
else:
raise TypeError(
"linalg_vector_norm(): argument 'dim' must be tuple of ints, not {}".format(
type(input_dim)
)
)
return dim
def _norm_min_max(input, ord, dim, keepdim):
if ord > 0:
temp = flow.max(input, dim=dim, keepdim=keepdim)
return temp if dim == None else temp[0]
else:
temp = flow.min(input, dim=dim, keepdim=keepdim)
return temp if dim == None else temp[0]
class Vector_Norm(Module):
def __init__(self, ord=2, dim=None, keepdim=False) -> None:
super().__init__()
if ord == None:
self.ord = 2.0
elif isinstance(ord, (int, float)):
self.ord = float(ord)
else:
raise TypeError(
"linalg_vector_norm(): argument 'ord' must be Number, not {}".format(
type(ord)
)
)
self.dim = dim
self.keepdim = keepdim
def _vector_norm(self, x, ord, dim, keepdim=False):
if ord == 0:
return flow.cast(flow.tensor([flow.argwhere(x).shape[0]]), flow.float32)
elif ord == float("inf"):
temp = flow.max(flow.abs(x), dim=dim, keepdim=keepdim)
return temp if dim == None else temp[0]
elif ord == float("-inf"):
temp = flow.min(flow.abs(x), dim=dim, keepdim=keepdim)
return temp if dim == None else temp[0]
else:
return flow.pow(
flow.sum(flow.pow(flow.abs(x), ord), dim=dim, keepdim=keepdim),
1.0 / ord,
)
def forward(self, x):
num_dims = len(x.shape)
dim = check_dim(num_dims, self.dim)
if dim == None:
return self._vector_norm(
x.flatten(), ord=self.ord, dim=self.dim, keepdim=self.keepdim
)
else:
return self._vector_norm(x, ord=self.ord, dim=dim, keepdim=self.keepdim)
class Matrix_Norm(Module):
def __init__(self, ord="fro", dim=(-2, -1), keepdim=False) -> None:
super().__init__()
if isinstance(ord, str):
assert ord in ["fro", "nuc"], "{} are not supported in matrix norm".format(
ord
)
self.ord = ord
elif isinstance(ord, float):
assert ord in [
float("inf"),
float("-inf"),
], "{} are not supported in matrix norm".format(ord)
self.ord = ord
elif isinstance(ord, int):
assert ord in [1, -1, 2, -2], "{} are not supported in matrix norm".format(
ord
)
self.ord = ord
elif ord == None:
self.ord = "fro"
else:
raise TypeError(
"linalg_matrix_norm(): argument 'ord' must be Number, not {}".format(
type(ord)
)
)
if isinstance(dim, tuple) and len(dim) == 2 and (dim[0] != dim[1]):
self.dim = dim
else:
raise TypeError(
"linalg.matrix_norm(): dim must be a 2-tuple of ints with different elements"
)
self.keepdim = keepdim
def _matrix_norm(self, x, ord, dim, keepdim):
if ord == "nuc":
raise NotImplementedError
elif ord == "fro":
return flow.sqrt(flow.sum(flow.square(x), dim=dim, keepdim=keepdim))
elif ord in [float("inf"), float("-inf")]:
(dim_0, dim_1) = (dim[0], dim[1])
(dim_0, dim_1) = (dim_1, dim_0)
if dim_1 > dim_0 and (not keepdim):
dim_1 -= 1
res = flow.sum(flow.abs(x), dim=dim_0, keepdim=keepdim)
return _norm_min_max(res, ord, dim_1, keepdim)
elif ord in [1, -1]:
(dim_0, dim_1) = (dim[0], dim[1])
if dim_1 > dim_0 and (not keepdim):
dim_1 -= 1
res = flow.sum(flow.abs(x), dim=dim_0, keepdim=keepdim)
return _norm_min_max(res, ord, dim_1, keepdim)
elif ord in [2, -2]:
raise NotImplementedError
else:
raise ValueError("Invalid norm order: {}".format(ord))
def forward(self, x):
num_dims = len(x.shape)
if num_dims < 2:
raise RuntimeError(
"linalg.matrix_norm(): input tensor must be a matrix or batch of matrices"
)
dim = check_dim(num_dims, self.dim)
return self._matrix_norm(x, ord=self.ord, dim=dim, keepdim=self.keepdim)
class Norm(Module):
def __init__(self, ord=None, dim=None, keepdim=False) -> None:
super().__init__()
self.ord = ord
self.dim = dim
self.keepdim = keepdim
def forward(self, x):
if isinstance(self.dim, int):
res = Vector_Norm(ord=self.ord, dim=self.dim, keepdim=self.keepdim)(x)
elif isinstance(self.dim, tuple):
res = Matrix_Norm(ord=self.ord, dim=self.dim, keepdim=self.keepdim)(x)
elif self.dim == None and self.ord != None:
assert (
len(x.shape) <= 2
), "input must be 1-D or 2-D when dim is None and ord is not None"
if len(x.shape) == 1:
res = Vector_Norm(ord=self.ord, keepdim=self.keepdim)(x)
else:
res = Matrix_Norm(ord=self.ord, keepdim=self.keepdim)(x)
elif self.dim == None and self.ord == None:
res = Vector_Norm(keepdim=self.keepdim)(x)
return res
def norm_op(input, ord=None, dim=None, keepdim=False):
"""linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None) -> Tensor
Returns the matrix norm or vector norm of a given tensor.
This function can calculate one of eight different types of matrix norms, or one
of an infinite number of vector norms, depending on both the number of reduction
dimensions and the value of the `ord` parameter.
Args:
input (Tensor): The input tensor. If dim is None, input must be 1-D or 2-D, unless :attr:`ord`
is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D
will be returned. Its data type must be either a floating point or complex type. For complex
inputs, the norm is calculated on of the absolute values of each element. If the input is
complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will
be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat).
ord (int, float, inf, -inf, 'fro', 'nuc', optional): The order of norm.
inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object.
The following norms can be calculated:
===== ============================ ==========================
ord norm for matrices norm for vectors
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm -- not supported --
'nuc' -- not supported yet -- -- not supported --
inf max(sum(abs(x), dim=1)) max(abs(x))
-inf min(sum(abs(x), dim=1)) min(abs(x))
0 -- not supported -- sum(x != 0)
1 max(sum(abs(x), dim=0)) as below
-1 min(sum(abs(x), dim=0)) as below
2 -- not supported yet -- as below
-2 -- not supported yet -- as below
other -- not supported -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
Default: ``None``
dim (int, 2-tuple of ints, 2-list of ints, optional): If :attr:`dim` is an int,
vector norm will be calculated over the specified dimension. If :attr:`dim`
is a 2-tuple of ints, matrix norm will be calculated over the specified
dimensions. If :attr:`dim` is None, matrix norm will be calculated
when the input tensor has two dimensions, and vector norm will be
calculated when the input tensor has one dimension. Default: ``None``
keepdim (bool, optional): If set to True, the reduced dimensions are retained
in the result as dimensions with size one. Default: ``False``
out (Tensor, optional): The output tensor.
Examples::
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape(3, 3)
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
[ 2., 3., 4.]], dtype=oneflow.float32)
>>> LA.norm(a)
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(b)
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(b, 'fro')
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(a, float('inf'))
tensor(4., dtype=oneflow.float32)
>>> LA.norm(b, float('inf'))
tensor(9., dtype=oneflow.float32)
>>> LA.norm(a, -float('inf'))
tensor(0., dtype=oneflow.float32)
>>> LA.norm(b, -float('inf'))
tensor(2., dtype=oneflow.float32)
>>> LA.norm(a, 1)
tensor(20., dtype=oneflow.float32)
>>> LA.norm(b, 1)
tensor(7., dtype=oneflow.float32)
>>> LA.norm(a, -1)
tensor(0., dtype=oneflow.float32)
>>> LA.norm(b, -1)
tensor(6., dtype=oneflow.float32)
>>> LA.norm(a, 2)
tensor(7.7460, dtype=oneflow.float32)
>>> LA.norm(a, -2)
tensor(0., dtype=oneflow.float32)
>>> LA.norm(a, 3)
tensor(5.8480, dtype=oneflow.float32)
>>> LA.norm(a, -3)
tensor(0., dtype=oneflow.float32)
Using the :attr:`dim` argument to compute vector norms::
>>> c = flow.tensor([[1., 2., 3.],
... [-1, 1, 4]])
>>> LA.norm(c, dim=0)
tensor([1.4142, 2.2361, 5.0000], dtype=oneflow.float32)
>>> LA.norm(c, dim=1, keepdim = True)
tensor([[3.7417],
[4.2426]], dtype=oneflow.float32)
>>> LA.norm(c, ord=1, dim=1)
tensor([6., 6.], dtype=oneflow.float32)
Using the :attr:`dim` argument to compute matrix norms::
>>> m = flow.tensor(np.arange(8, dtype=np.float32)).reshape(2, 2, 2)
>>> LA.norm(m, dim=(1,2))
tensor([ 3.7417, 11.2250], dtype=oneflow.float32)
"""
return Norm(ord, dim, keepdim)(input)
@register_tensor_op("norm")
def norm_tensor_op(input, ord=None, dim=None, keepdim=False):
"""
See :func:`oneflow.linalg.norm`
"""
return Norm(ord, dim, keepdim)(input)
def vector_norm_tensor_op(input, ord=2, dim=None, keepdim=False):
"""
linalg.vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a vector norm.
Supports input of float, double dtypes.
This function does not necessarily treat multidimensonal attr:`input` as a batch of
vectors, instead:
- If :attr:`dim`\\ `= None`, :attr:`input` will be flattened before the norm is computed.
- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions and the other dimensions will be treated as batch dimensions.
This behavior is for consistency with :func:`flow.linalg.norm`.
:attr:`ord` defines the vector norm that is computed. The following norms are supported:
====================== ========================================================
:attr:`ord` vector norm
====================== ========================================================
`2` (default) `2`-norm (see below)
`inf` `max(abs(x))`
`-inf` `min(abs(x))`
`0` `sum(x != 0)`
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ========================================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
input (Tensor): tensor, flattened by default, but this behavior can be
controlled using :attr:`dim`.
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`
dim (int, Tuple[int], optional): dimensions over which to compute
the norm. See above for the behavior when :attr:`dim`\\ `= None`.
Default: `None`
keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor.
Examples::
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape(3, 3)
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
[ 2., 3., 4.]], dtype=oneflow.float32)
>>> LA.vector_norm(a, ord=3.5)
tensor(5.4345, dtype=oneflow.float32)
>>> LA.vector_norm(b, ord=3.5)
tensor(5.4345, dtype=oneflow.float32)
"""
return Vector_Norm(ord, dim, keepdim)(input)
def matrix_norm_tensor_op(input, ord="fro", dim=(-2, -1), keepdim=False):
"""
linalg.matrix_norm(input, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a matrix norm.
Support input of float, double, cfloat and cdouble dtypes.
Also supports batches of matrices: the norm will be computed over the
dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will
be treated as batch dimensions. The output will have the same batch dimensions.
:attr:`ord` defines the matrix norm that is computed. The following norms are supported:
====================== ========================================================
:attr:`ord` matrix norm
====================== ========================================================
`'fro'` (default) Frobenius norm
`'nuc'` -- not supported yet --
`inf` `max(sum(abs(x), dim=1))`
`-inf` `min(sum(abs(x), dim=1))`
`1` `max(sum(abs(x), dim=0))`
`-1` `min(sum(abs(x), dim=0))`
`2` -- not supported yet --
`-2` -- not supported yet --
====================== ========================================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
input (Tensor): tensor with two or more dimensions. By default its
shape is interpreted as `(*, m, n)` where `*` is zero or more
batch dimensions, but this behavior can be controlled using :attr:`dim`.
ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'`
dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)`
keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor.
Examples::
>>> import oneflow as flow
>>> from oneflow import linalg as LA
>>> import numpy as np
>>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape(3,3)
>>> a
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=oneflow.float32)
>>> LA.matrix_norm(a)
tensor(14.2829, dtype=oneflow.float32)
>>> LA.matrix_norm(a, ord=-1)
tensor(9., dtype=oneflow.float32)
>>> b = a.expand(2, -1, -1)
>>> b
tensor([[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]],
<BLANKLINE>
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]]], dtype=oneflow.float32)
>>> LA.matrix_norm(b)
tensor([14.2829, 14.2829], dtype=oneflow.float32)
>>> LA.matrix_norm(b, dim=(0, 2))
tensor([ 3.1623, 10.0000, 17.2627], dtype=oneflow.float32)
"""
return Matrix_Norm(ord, dim, keepdim)(input)
def l2_normalize(input, dim=0, epsilon=1e-12):
"""Use L2 norm to normalizes along dimension `dim`
@@ -107,7 +107,7 @@ def clip_grad_norm_(
),
norm_type,
)
if np.isnan(total_norm.numpy()) or np.isinf(total_norm.numpy()):
if np.isnan(total_norm.numpy()).all() or np.isinf(total_norm.numpy()).all():
if error_if_nonfinite:
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
@@ -124,11 +124,12 @@
FutureWarning,
stacklevel=2,
)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef.numpy().item() < 1:
for p in parameters:
# TODO: Switch to inplace multiply in future
p.grad[:] = p.grad.detach().mul(clip_coef.to(p.grad.device))
clip_coef_clamped = clip_coef.clamp(max=1.0)
for p in parameters:
# TODO: Switch to inplace multiply in future
p.grad[:] = p.grad.detach().mul(clip_coef_clamped.to(p.grad.device))
return total_norm
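For reference, the clamp-based scaling introduced here is equivalent to the previous conditional scaling while avoiding the host-side `.numpy().item()` check: multiplying by `clip_coef.clamp(max=1.0)` is a no-op whenever `clip_coef >= 1`. A small standalone sketch of the equivalence (not part of the diff):

import numpy as np

def scale_old(grad, clip_coef):
    # previous behaviour: scale only when the coefficient is below 1
    return grad * clip_coef if clip_coef < 1 else grad

def scale_new(grad, clip_coef):
    # new behaviour: always scale by the coefficient clamped to at most 1
    return grad * min(clip_coef, 1.0)

g = np.array([3.0, -4.0])
for c in (0.5, 1.0, 2.0):
    assert np.allclose(scale_old(g, c), scale_new(g, c))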
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
@flow.unittest.skip_unless_1n1d()
class TestTripletMarginLoss(flow.unittest.TestCase):
@autotest(n=10)
def test_triplet_marginloss_with_random_data(test_case):
margin = random().to(float)
p = random().to(float)
swap = random_bool()
reduction = oneof("none", "sum", "mean", nothing())
m = torch.nn.TripletMarginLoss(
margin=margin, p=p, swap=swap, reduction=reduction
)
m.train(random())
device = random_device()
m.to(device)
shape = random_tensor(ndim=2, dim0=random(1, 8)).value().shape
anchor = random_pytorch_tensor(len(shape), *shape).to(device)
pos = random_pytorch_tensor(len(shape), *shape).to(device)
neg = random_pytorch_tensor(len(shape), *shape).to(device)
y = m(anchor, pos, neg)
return y
if __name__ == "__main__":
unittest.main()
@@ -19,7 +19,7 @@ from collections import OrderedDict
import numpy as np
from test_util import GenArgList
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
import oneflow.unittest
@@ -255,6 +255,55 @@ class TestNormModule(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest()
def test_no_dim_no_ord_norm_with_random_data(test_case):
device = random_device()
input = random_pytorch_tensor().to(device)
keepdim = random_bool()
m = torch.linalg.norm(input, keepdim=keepdim)
return m
@autotest()
def test_one_dim_norm_with_random_data(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=4).to(device)
dim = random(low=0, high=4).to(int)
k = random().to(float)
ord = oneof(float("inf"), float("-inf"), k, None)
keepdim = random_bool()
m = torch.linalg.norm(input, ord, dim, keepdim)
return m
@autotest()
def test_no_dim_one_shape_norm_with_random_data(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=1).to(device)
k = random().to(float)
ord = oneof(float("inf"), float("-inf"), k)
keepdim = random_bool()
m = torch.linalg.norm(input, ord=ord, keepdim=keepdim)
return m
@autotest()
def test_no_dim_two_shape_norm_with_random_data(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2).to(device)
ord = oneof(float("inf"), float("-inf"), "fro", 1, -1)
keepdim = random().to(bool)
m = torch.linalg.norm(input, ord=ord, keepdim=keepdim)
return m
@autotest()
def test_tuple_dim_norm_with_random_data(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2).to(device)
dim = oneof((-2, -1), (0, 1), (-1, 0))
ord = oneof(float("inf"), float("-inf"), "fro", 1, -1, None)
keepdim = random().to(bool)
m = torch.linalg.norm(input, ord=ord, dim=dim, keepdim=keepdim)
return m
if __name__ == "__main__":
unittest.main()