diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 4673dc258d062b219fb90f644265cbaa4cfb82ef..d70df5cd73847e5f63ce0b44b57dbb840d98b522 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <algorithm>
+#include <set>
 #include <string>
 #include <vector>
@@ -98,6 +99,18 @@ class ReduceKernel : public framework::OpKernel<T> {
     int out_dtype = context.Attr<int>("out_dtype");
     framework::proto::VarType::Type cast_out_dtype;
 
+    // If dims covers every axis of the input, set reduce_all to true
+    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
+    std::set<int> dims_set(dims.begin(), dims.end());
+    bool full_dim = true;
+    for (auto i = 0; i < input_dim_size; i++) {
+      if (dims_set.find(i) == dims_set.end()) {
+        full_dim = false;
+        break;
+      }
+    }
+    reduce_all = (reduce_all || full_dim);
+
     if (out_dtype < 0) {
       auto* cast_input = context.Input<Tensor>("X");
       cast_out_dtype =
@@ -137,6 +150,18 @@ class BoolReduceKernel : public framework::OpKernel<T> {
     auto dims = context.Attr<std::vector<int>>("dim");
     bool keep_dim = context.Attr<bool>("keep_dim");
 
+    // If dims covers every axis of the input, set reduce_all to true
+    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
+    std::set<int> dims_set(dims.begin(), dims.end());
+    bool full_dim = true;
+    for (auto i = 0; i < input_dim_size; i++) {
+      if (dims_set.find(i) == dims_set.end()) {
+        full_dim = false;
+        break;
+      }
+    }
+    reduce_all = (reduce_all || full_dim);
+
     if (reduce_all) {
       // Flatten and reduce 1-D tensor
       auto x = EigenVector<OutT>::Flatten(*input);
@@ -183,6 +208,17 @@ class ReduceGradKernel : public framework::OpKernel<T> {
     auto* output = context.Output<Tensor>(framework::GradVarName("X"));
     output->mutable_data<T>(context.GetPlace());
 
+    // If dims covers every axis of the input, set reduce_all to true
+    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
+    std::set<int> dims_set(dims.begin(), dims.end());
+    bool full_dim = true;
+    for (auto i = 0; i < input_dim_size; i++) {
+      if (dims_set.find(i) == dims_set.end()) {
+        full_dim = false;
+        break;
+      }
+    }
+    reduce_all = (reduce_all || full_dim);
     // NOTE: EigenTensor::From() uses tensor->data<T>()
     // if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
     // kNoNeedBufferY should set true
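For reference, the axis check that all three kernels above now share boils down to a set-membership test. A minimal Python sketch of the same logic (illustrative names only, not Paddle code):

.. code-block:: python

    def covers_all_axes(dims, input_rank):
        # True when `dims` names every axis of an `input_rank`-D tensor,
        # i.e. the reduction is a full (reduce_all) reduction.
        dims_set = set(dims)
        return all(i in dims_set for i in range(input_rank))

    assert covers_all_axes([0, 1, 2], 3)       # full reduction
    assert not covers_all_axes([0, 2], 3)      # axis 1 is kept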
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index b125b36c6b6b459ec2974107faa4385bf1368636..14824407284571619babe058393cfa5956f7d0cd 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -210,8 +210,6 @@ from .framework import BackwardStrategy  #DEFINE_ALIAS
 from .framework import to_variable  #DEFINE_ALIAS
 from .framework import grad  #DEFINE_ALIAS
 from .framework import no_grad  #DEFINE_ALIAS
-from .framework import save_dygraph  #DEFINE_ALIAS
-from .framework import load_dygraph  #DEFINE_ALIAS
 from .framework import save  #DEFINE_ALIAS
 from .framework import load  #DEFINE_ALIAS
 from .framework import prepare_context  #DEFINE_ALIAS
@@ -238,8 +236,6 @@ from .fluid.data import data
 from . import incubate
 from .incubate import hapi
-from .fluid.dygraph.base import enable_dygraph  #DEFINE_ALIAS
-from .fluid.dygraph.base import disable_dygraph  #DEFINE_ALIAS
 from .fluid.dygraph.base import enable_dygraph as disable_static  #DEFINE_ALIAS
 from .fluid.dygraph.base import disable_dygraph as enable_static  #DEFINE_ALIAS
 from .fluid.framework import in_dygraph_mode as in_dynamic_mode  #DEFINE_ALIAS
diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py
index b179d00626249849f64f0fc571cb2e85cf08ea05..2002b8a95decfd6d6c55538e2dff0a793828dd9b 100755
--- a/python/paddle/fluid/layers/control_flow.py
+++ b/python/paddle/fluid/layers/control_flow.py
@@ -1511,7 +1511,7 @@ def array_write(x, i, array=None):
         assert i.shape == [
             1
         ], "The shape of index 'i' should be [1] in dygraph mode"
-        i = i.numpy()[0]
+        i = i.numpy().item(0)
         if array is None:
             array = create_array(x.dtype)
         assert isinstance(
@@ -1976,7 +1976,7 @@ def array_read(array, i):
         assert i.shape == [
             1
         ], "The shape of index 'i' should be [1] in dygraph mode"
-        i = i.numpy()[0]
+        i = i.numpy().item(0)
         return array[i]
 
     check_variable_and_dtype(i, 'i', ['int64'], 'array_read')
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index dc6c37925efaf3544e012f341505cc56d25eacb0..2fb518221e855d2242ace9844f461463ca38931e 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -4841,7 +4841,7 @@ def split(input, num_or_sections, dim=-1, name=None):
         if isinstance(dim, Variable):
             dim = dim.numpy()
-            dim = dim[0]
+            dim = dim.item(0)
         dim = (len(input.shape) + dim) if dim < 0 else dim
         attrs += ('axis', dim)
@@ -5885,7 +5885,7 @@ def one_hot(input, depth, allow_out_of_range=False):
             depth = depth.numpy()
             assert depth.shape == (
                 1, ), "depth of type Variable should have shape [1]"
-            depth = depth[0]
+            depth = depth.item(0)
         out = core.ops.one_hot(input, 'depth', depth, 'allow_out_of_range',
                                allow_out_of_range)
         out.stop_gradient = True
@@ -6067,7 +6067,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
             )
         if isinstance(shape, (list, tuple)):
             shape = [
-                item.numpy()[0] if isinstance(item, Variable) else item
+                item.numpy().item(0) if isinstance(item, Variable) else item
                 for item in shape
             ]
             out, _ = core.ops.reshape2(x, 'shape', shape)
@@ -10195,7 +10195,7 @@ def expand(x, expand_times, name=None):
     if in_dygraph_mode():
         if isinstance(expand_times, (list, tuple)):
             expand_times = [
-                item.numpy()[0] if isinstance(item, Variable) else item
+                item.numpy().item(0) if isinstance(item, Variable) else item
                 for item in expand_times
             ]
@@ -10806,11 +10806,11 @@ def slice(input, axes, starts, ends):
         if isinstance(starts, (list, tuple)) and isinstance(ends, (list, tuple)):
             starts = [
-                item.numpy()[0] if isinstance(item, Variable) else item
+                item.numpy().item(0) if isinstance(item, Variable) else item
                 for item in starts
             ]
             ends = [
-                item.numpy()[0] if isinstance(item, Variable) else item
+                item.numpy().item(0) if isinstance(item, Variable) else item
                 for item in ends
             ]
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 2d874b4806c9e1449a170017440c4b5038ff93bf..7ac67b1bc817964ca65d5b7009b446458d2cc7ab 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -317,7 +317,7 @@ def concat(input, axis=0, name=None):
     if in_dygraph_mode():
         if isinstance(axis, Variable):
             axis = axis.numpy()
-            axis = axis[0]
+            axis = axis.item(0)
         return core.ops.concat(input, 'axis', axis)
 
     check_type(input, 'input', (list, tuple, Variable), 'concat')
@@ -699,9 +699,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
 
         if isinstance(value, Variable):
             if dtype in ['int64', 'int32']:
-                attrs['str_value'] = str(int(value.numpy()))
+                attrs['str_value'] = str(int(value.numpy().item(0)))
             else:
-                attrs['str_value'] = str(float(value.numpy()))
+                attrs['str_value'] = str(float(value.numpy().item(0)))
 
         core.ops.fill_constant(out, 'value',
                                float(value), 'force_cpu', force_cpu, 'dtype',
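The recurring `.numpy()[0]` -> `.numpy().item(0)` change in these files swaps a NumPy scalar for a native Python scalar. A standalone illustration (plain NumPy, not Paddle code):

.. code-block:: python

    import numpy as np

    arr = np.array([3], dtype=np.int64)
    idx = arr[0]       # numpy.int64 -- still a NumPy scalar type
    val = arr.item(0)  # plain Python int
    print(type(idx), type(val))
    # <class 'numpy.int64'> <class 'int'>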
diff --git a/python/paddle/fluid/tests/unittests/test_cholesky_op.py b/python/paddle/fluid/tests/unittests/test_cholesky_op.py
index 4e2280c0118a11ebfc21f6179b8a7a795c6f53da..f3e6c079eedc8effc948a44e08a5dcdcae8d3081 100644
--- a/python/paddle/fluid/tests/unittests/test_cholesky_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cholesky_op.py
@@ -90,5 +90,15 @@ class TestCholeskyOp2D(TestCholeskyOp):
         self._input_shape = (64, 64)
 
 
+class TestDygraph(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static()
+        a = np.random.rand(3, 3)
+        a_t = np.transpose(a, [1, 0])
+        x_data = np.matmul(a, a_t) + 1e-03
+        x = paddle.to_variable(x_data)
+        out = paddle.cholesky(x, upper=False)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
index aefc809bd5cb852d3fde95dff4550e506c5f1c12..3475320eeebc55a14dd569410610b70ae35e65a3 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -305,14 +305,18 @@ class TestFillConstantImperative(unittest.TestCase):
         with fluid.dygraph.guard():
             data1 = np.array([1, 2]).astype('int32')
             data2 = np.array([1.1]).astype('float32')
+            data3 = np.array([88]).astype('int32')
             shape = fluid.dygraph.to_variable(data1)
             val = fluid.dygraph.to_variable(data2)
+            value = fluid.dygraph.to_variable(data3)
             res1 = fluid.layers.fill_constant(
                 shape=[1, 2], dtype='float32', value=1.1)
             res2 = fluid.layers.fill_constant(
                 shape=shape, dtype='float32', value=1.1)
             res3 = fluid.layers.fill_constant(
                 shape=shape, dtype='float32', value=val)
+            res4 = fluid.layers.fill_constant(
+                shape=shape, dtype='int32', value=value)
             assert np.array_equal(
                 res1.numpy(), np.full(
                     [1, 2], 1.1, dtype="float32"))
@@ -322,6 +326,9 @@ class TestFillConstantImperative(unittest.TestCase):
             assert np.array_equal(
                 res3.numpy(), np.full(
                     [1, 2], 1.1, dtype="float32"))
+            assert np.array_equal(
+                res4.numpy(), np.full(
+                    [1, 2], 88, dtype="int32"))
 
 
 class TestFillConstantOpError(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_inverse_op.py b/python/paddle/fluid/tests/unittests/test_inverse_op.py
index 5349654ac27800d2e70c4b77f6531853178fd3ed..fd540dcd741eef4c007eae19a982bc186c09d7d7 100644
--- a/python/paddle/fluid/tests/unittests/test_inverse_op.py
+++ b/python/paddle/fluid/tests/unittests/test_inverse_op.py
@@ -89,8 +89,7 @@ class TestInverseAPI(unittest.TestCase):
     def check_static_result(self, place):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             input = fluid.data(name="input", shape=[4, 4], dtype="float64")
-            result = paddle.inverse(input=input)
-
+            result = paddle.inverse(x=input)
             input_np = np.random.random([4, 4]).astype("float64")
             result_np = np.linalg.inv(input_np)
 
@@ -145,7 +144,7 @@ class TestInverseSingularAPI(unittest.TestCase):
     def check_static_result(self, place):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             input = fluid.data(name="input", shape=[4, 4], dtype="float64")
-            result = paddle.inverse(input=input)
+            result = paddle.inverse(x=input)
 
             input_np = np.zeros([4, 4]).astype("float64")
diff --git a/python/paddle/fluid/tests/unittests/test_max_op.py b/python/paddle/fluid/tests/unittests/test_max_op.py
index 75ccaacc3c3035a3c0cdd081fd93737a90ab435b..e2bdaba91a68ff17d8d17724f8cbd5d8ad684d08 100644
--- a/python/paddle/fluid/tests/unittests/test_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_max_op.py
@@ -48,6 +48,15 @@ class ApiMaxTest(unittest.TestCase):
             res, = exe.run(feed={"data": input_data}, fetch_list=[result_max])
             self.assertEqual((res == np.max(input_data, axis=0)).all(), True)
 
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
+            result_max = paddle.max(x=data, axis=(0, 1))
+            exe = paddle.static.Executor(self.place)
+            input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
+            res, = exe.run(feed={"data": input_data}, fetch_list=[result_max])
+            self.assertEqual((res == np.max(input_data, axis=(0, 1))).all(), True)
+
     def test_errors(self):
         paddle.enable_static()
 
@@ -59,6 +68,15 @@ class ApiMaxTest(unittest.TestCase):
 
         self.assertRaises(TypeError, test_input_type)
 
+        def test_axis_type():
+            with paddle.static.program_guard(paddle.static.Program(),
+                                             paddle.static.Program()):
+                data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
+                axis = paddle.nn.data("axis", shape=[10, 10], dtype="int64")
+                result_max = paddle.max(data, axis)
+
+        self.assertRaises(TypeError, test_axis_type)
+
     def test_imperative_api(self):
         paddle.disable_static()
         np_x = np.array([10, 10]).astype('float64')
diff --git a/python/paddle/fluid/tests/unittests/test_min_op.py b/python/paddle/fluid/tests/unittests/test_min_op.py
index 3dbda66e2a2cf825ae12752484cd314086dab3c5..e8bfe55f32a122ac9259b68d6a888f93757a76be 100644
--- a/python/paddle/fluid/tests/unittests/test_min_op.py
+++ b/python/paddle/fluid/tests/unittests/test_min_op.py
@@ -48,6 +48,15 @@ class ApiMinTest(unittest.TestCase):
             res, = exe.run(feed={"data": input_data}, fetch_list=[result_min])
             self.assertEqual((res == np.min(input_data, axis=0)).all(), True)
 
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
+            result_min = paddle.min(x=data, axis=(0, 1))
+            exe = paddle.static.Executor(self.place)
+            input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
+            res, = exe.run(feed={"data": input_data}, fetch_list=[result_min])
+            self.assertEqual((res == np.min(input_data, axis=(0, 1))).all(), True)
+
     def test_errors(self):
         paddle.enable_static()
 
@@ -59,6 +68,15 @@ class ApiMinTest(unittest.TestCase):
 
         self.assertRaises(TypeError, test_input_type)
 
+        def test_axis_type():
+            with paddle.static.program_guard(paddle.static.Program(),
+                                             paddle.static.Program()):
+                data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
+                axis = paddle.nn.data("axis", shape=[10, 10], dtype="int64")
+                result_min = paddle.min(data, axis)
+
+        self.assertRaises(TypeError, test_axis_type)
+
     def test_imperative_api(self):
         paddle.disable_static()
         np_x = np.array([10, 10]).astype('float64')
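Both new static-graph tests pass a tuple for `axis`; reducing over every axis of the input is equivalent to a full reduction, which is exactly the `reduce_all` fast path the C++ kernels at the top of this patch now detect. In plain NumPy terms:

.. code-block:: python

    import numpy as np

    data = np.random.randint(10, size=(10, 10)).astype(np.int64)
    assert np.max(data, axis=(0, 1)) == data.max()  # tuple over all axes == full reduce
    assert np.min(data, axis=(0, 1)) == data.min()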
diff --git a/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py b/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f60f3e39a57365163cb3f5f3e061e53f8fd654b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.static import Program, program_guard
+
+
+def calc_margin_rank_loss(x, y, label, margin=0.0, reduction='none'):
+    result = (-1 * label) * (x - y) + margin
+    result = np.maximum(result, 0)
+    if reduction == 'none':
+        return result
+    elif reduction == 'sum':
+        return np.sum(result)
+    elif reduction == 'mean':
+        return np.mean(result)
+
+
+def create_test_case(margin, reduction):
+    class MarginRankingLossCls(unittest.TestCase):
+        def setUp(self):
+            self.x_data = np.random.rand(10, 10).astype("float64")
+            self.y_data = np.random.rand(10, 10).astype("float64")
+            self.label_data = np.random.choice(
+                [-1, 1], size=[10, 10]).astype("float64")
+            self.places = []
+            self.places.append(fluid.CPUPlace())
+            if core.is_compiled_with_cuda():
+                self.places.append(paddle.CUDAPlace(0))
+
+        def run_static_functional_api(self, place):
+            paddle.enable_static()
+            expected = calc_margin_rank_loss(
+                self.x_data,
+                self.y_data,
+                self.label_data,
+                margin=margin,
+                reduction=reduction)
+            with program_guard(Program(), Program()):
+                x = paddle.nn.data(name="x", shape=[10, 10], dtype="float64")
+                y = paddle.nn.data(name="y", shape=[10, 10], dtype="float64")
+                label = paddle.nn.data(
+                    name="label", shape=[10, 10], dtype="float64")
+                result = paddle.nn.functional.margin_ranking_loss(
+                    x, y, label, margin, reduction)
+                exe = paddle.static.Executor(place)
+                result_numpy, = exe.run(feed={
+                    "x": self.x_data,
+                    "y": self.y_data,
+                    "label": self.label_data
+                },
+                                        fetch_list=[result])
+                self.assertTrue(np.allclose(result_numpy, expected))
+
+        def run_static_api(self, place):
+            paddle.enable_static()
+            expected = calc_margin_rank_loss(
+                self.x_data,
+                self.y_data,
+                self.label_data,
+                margin=margin,
+                reduction=reduction)
+            with program_guard(Program(), Program()):
+                x = paddle.nn.data(name="x", shape=[10, 10], dtype="float64")
+                y = paddle.nn.data(name="y", shape=[10, 10], dtype="float64")
+                label = paddle.nn.data(
+                    name="label", shape=[10, 10], dtype="float64")
+                margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
+                    margin=margin, reduction=reduction)
+                result = margin_rank_loss(x, y, label)
+                exe = paddle.static.Executor(place)
+                result_numpy, = exe.run(feed={
+                    "x": self.x_data,
+                    "y": self.y_data,
+                    "label": self.label_data
+                },
+                                        fetch_list=[result])
+                self.assertTrue(np.allclose(result_numpy, expected))
+                self.assertTrue('loss' in result.name)
+
+        def run_dynamic_functional_api(self, place):
+            paddle.disable_static(place)
+            x = paddle.to_variable(self.x_data)
+            y = paddle.to_variable(self.y_data)
+            label = paddle.to_variable(self.label_data)
+
+            result = paddle.nn.functional.margin_ranking_loss(x, y, label,
+                                                              margin, reduction)
+            expected = calc_margin_rank_loss(
+                self.x_data,
+                self.y_data,
+                self.label_data,
+                margin=margin,
+                reduction=reduction)
+            self.assertTrue(np.allclose(result.numpy(), expected))
+
+        def run_dynamic_api(self, place):
+            paddle.disable_static(place)
+            x = paddle.to_variable(self.x_data)
+            y = paddle.to_variable(self.y_data)
+            label = paddle.to_variable(self.label_data)
+            margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
+                margin=margin, reduction=reduction)
+            result = margin_rank_loss(x, y, label)
+            expected = calc_margin_rank_loss(
+                self.x_data,
+                self.y_data,
+                self.label_data,
+                margin=margin,
+                reduction=reduction)
+            self.assertTrue(np.allclose(result.numpy(), expected))
+
+        def run_dynamic_broadcast_api(self, place):
+            paddle.disable_static(place)
+            label_data = np.random.choice([-1, 1], size=[10]).astype("float64")
+            x = paddle.to_variable(self.x_data)
+            y = paddle.to_variable(self.y_data)
+            label = paddle.to_variable(label_data)
+            margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
+                margin=margin, reduction=reduction)
+            result = margin_rank_loss(x, y, label)
+            expected = calc_margin_rank_loss(
+                self.x_data,
+                self.y_data,
+                label_data,
+                margin=margin,
+                reduction=reduction)
+            self.assertTrue(np.allclose(result.numpy(), expected))
+
+        def test_case(self):
+            for place in self.places:
+                self.run_static_api(place)
+                self.run_static_functional_api(place)
+                self.run_dynamic_api(place)
+                self.run_dynamic_functional_api(place)
+                self.run_dynamic_broadcast_api(place)
+
+    cls_name = "TestMarginRankLossCase_{}_{}".format(margin, reduction)
+    MarginRankingLossCls.__name__ = cls_name
+    globals()[cls_name] = MarginRankingLossCls
+
+
+for margin in [0.0, 0.2]:
+    for reduction in ['none', 'mean', 'sum']:
+        create_test_case(margin, reduction)
+
+
+# test cases for the raised error message
+class MarginRankingLossError(unittest.TestCase):
+    paddle.enable_static()
+
+    def test_errors(self):
+        def test_margin_value_error():
+            margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
+                margin=0.1, reduction="reduce_mean")
+
+        self.assertRaises(ValueError, test_margin_value_error)
+
+
+if __name__ == "__main__":
+    unittest.main()
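`calc_margin_rank_loss` above is the NumPy oracle for all of these cases. A hand-worked instance with `margin=0.0` and `reduction='mean'`, using the same numbers the docstring examples later in this patch print as `[0.75]`:

.. code-block:: python

    import numpy as np

    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
    y = np.array([[2, 1], [2, 4]], dtype=np.float32)
    label = np.array([[1, -1], [-1, -1]], dtype=np.float32)

    loss = np.maximum(-label * (x - y) + 0.0, 0)  # [[1., 1.], [1., 0.]]
    print(loss.mean())                            # 0.75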
diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py
index 0c26e4c5178883883dc9e364dff90e67c4667ce1..20f1b453a0cd37aaf0888991a3f20c9e68c438d0 100644
--- a/python/paddle/framework/__init__.py
+++ b/python/paddle/framework/__init__.py
@@ -43,8 +43,6 @@ from paddle.fluid import core  #DEFINE_ALIAS
 from ..fluid.dygraph.base import no_grad  #DEFINE_ALIAS
 from ..fluid.dygraph.base import to_variable  #DEFINE_ALIAS
 from ..fluid.dygraph.base import grad  #DEFINE_ALIAS
-from ..fluid.dygraph.checkpoint import load_dygraph  #DEFINE_ALIAS
-from ..fluid.dygraph.checkpoint import save_dygraph  #DEFINE_ALIAS
 from ..fluid.dygraph.checkpoint import load_dygraph as load  #DEFINE_ALIAS
 from ..fluid.dygraph.checkpoint import save_dygraph as save  #DEFINE_ALIAS
 from ..fluid.dygraph.parallel import prepare_context  #DEFINE_ALIAS
diff --git a/python/paddle/incubate/complex/tensor/linalg.py b/python/paddle/incubate/complex/tensor/linalg.py
index 3badf36280e27c9d7962a2b7b3fff596fd0e8cb3..946a0fd5534d13166706523675c93ef1d01cfa54 100644
--- a/python/paddle/incubate/complex/tensor/linalg.py
+++ b/python/paddle/incubate/complex/tensor/linalg.py
@@ -56,20 +56,20 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
             # [1.+5.j 5.+9.j]
     """
     # x = a + bi, y = c + di
-    # mm(x, y) = mm(a, c) - mm(b, d) + (mm(a, d) + mm(b, c))i
+    # P1 = ac; P2 = (a + b)(c + d); P3 = bd; then mm(x, y) = (P1 - P3) + (P2 - P1 - P3)j
     complex_variable_exists([x, y], "matmul")
     a, b = (x.real, x.imag) if is_complex(x) else (x, None)
     c, d = (y.real, y.imag) if is_complex(y) else (y, None)
-    ac = layers.matmul(a, c, transpose_x, transpose_y, alpha, name)
+    P1 = layers.matmul(a, c, transpose_x, transpose_y, alpha, name)
     if is_real(b) and is_real(d):
-        bd = layers.matmul(b, d, transpose_x, transpose_y, alpha, name)
-        real = ac - bd
-        imag = layers.matmul(a, d, transpose_x, transpose_y, alpha, name) + \
-            layers.matmul(b, c, transpose_x, transpose_y, alpha, name)
+        P2 = layers.matmul(a + b, c + d, transpose_x, transpose_y, alpha, name)
+        P3 = layers.matmul(b, d, transpose_x, transpose_y, alpha, name)
+        real = P1 - P3
+        imag = P2 - P1 - P3
     elif is_real(b):
-        real = ac
+        real = P1
         imag = layers.matmul(b, c, transpose_x, transpose_y, alpha, name)
     else:
-        real = ac
+        real = P1
         imag = layers.matmul(a, d, transpose_x, transpose_y, alpha, name)
 
     return ComplexVariable(real, imag)
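The rewritten comment describes a Karatsuba-style trick: three real matmuls in place of four, at the cost of two extra additions. A quick NumPy check of the identity, independent of Paddle:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    a, b = rng.random((2, 2)), rng.random((2, 2))  # real and imag parts of x
    c, d = rng.random((2, 2)), rng.random((2, 2))  # real and imag parts of y

    P1, P2, P3 = a @ c, (a + b) @ (c + d), b @ d
    naive = (a + 1j * b) @ (c + 1j * d)     # four-multiply form
    fast = (P1 - P3) + 1j * (P2 - P1 - P3)  # three-multiply form
    assert np.allclose(naive, fast)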
diff --git a/python/paddle/jit/__init__.py b/python/paddle/jit/__init__.py
index f098dc591cc3e58811d2db6bd170b7eef8c92366..47369e3ff9cd87539f9e96708ff981dc67d06420 100644
--- a/python/paddle/jit/__init__.py
+++ b/python/paddle/jit/__init__.py
@@ -16,7 +16,6 @@ from ..fluid.dygraph.jit import save  #DEFINE_ALIAS
 from ..fluid.dygraph.jit import load  #DEFINE_ALIAS
 from ..fluid.dygraph.jit import SaveLoadConfig  #DEFINE_ALIAS
 from ..fluid.dygraph.jit import TracedLayer  #DEFINE_ALIAS
-from ..fluid.dygraph.jit import declarative as __impl__  #DEFINE_ALIAS
 from ..fluid.dygraph.jit import declarative as to_static  #DEFINE_ALIAS
 from ..fluid.dygraph import ProgramTranslator  #DEFINE_ALIAS
 from ..fluid.dygraph.io import TranslatedLayer  #DEFINE_ALIAS
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index aac6b40168533a588637b9b2ec418c1fc0fb7f6d..9583d9a0a39b362ce4bda2c11cb976fbe705cbe3 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -85,6 +85,7 @@ from .layer.loss import MSELoss  #DEFINE_ALIAS
 from .layer.loss import L1Loss  #DEFINE_ALIAS
 from .layer.loss import NLLLoss  #DEFINE_ALIAS
 from .layer.loss import BCELoss  #DEFINE_ALIAS
+from .layer.loss import MarginRankingLoss  #DEFINE_ALIAS
 from .layer.norm import BatchNorm  #DEFINE_ALIAS
 from .layer.norm import GroupNorm  #DEFINE_ALIAS
 from .layer.norm import LayerNorm  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index d6b88e741c6a88d215397d3383afd203b50fbee5..e3426b22484e4cea764f92cc44cc641386b7f6e4 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -129,7 +129,7 @@ from .loss import iou_similarity  #DEFINE_ALIAS
 from .loss import kldiv_loss  #DEFINE_ALIAS
 from .loss import l1_loss  #DEFINE_ALIAS
 from .loss import log_loss  #DEFINE_ALIAS
-from .loss import margin_rank_loss  #DEFINE_ALIAS
+from .loss import margin_ranking_loss  #DEFINE_ALIAS
 from .loss import mse_loss  #DEFINE_ALIAS
 from .loss import nll_loss  #DEFINE_ALIAS
 # from .loss import nce  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 4bbfaed81ea24b4e04e0e55a4a7b15c767dd3e6a..85ca043a10cca8dfaab2a4dcf724030fd505a7c1 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # TODO: define loss functions of neural network
+import numpy as np
 import paddle
 import paddle.fluid as fluid
 from ...fluid.framework import core, in_dygraph_mode
@@ -38,10 +39,8 @@ from ...fluid.layers import teacher_student_sigmoid_loss  #DEFINE_ALIAS
 from ...fluid.layers import edit_distance  #DEFINE_ALIAS
 from ...fluid.layers import huber_loss  #DEFINE_ALIAS
-from ...fluid.layers import margin_rank_loss  #DEFINE_ALIAS
 from ...fluid.layers import sampled_softmax_with_cross_entropy  #DEFINE_ALIAS
 from ...fluid.layer_helper import LayerHelper
-from ...fluid.framework import in_dygraph_mode
 from ...fluid.framework import Variable
 
 __all__ = [
@@ -55,8 +54,8 @@ __all__ = [
     'kldiv_loss',
     'l1_loss',
     'log_loss',
-    'margin_rank_loss',
     'mse_loss',
+    'margin_ranking_loss',
     #       'nce',
     'nll_loss',
     'npair_loss',
@@ -72,6 +71,110 @@ __all__ = [
 ]
 
 
+def margin_ranking_loss(input,
+                        other,
+                        target,
+                        margin=0.0,
+                        reduction='mean',
+                        name=None):
+    """
+
+    This op computes the margin rank loss between the ``input``, ``other`` and ``target`` tensors, using the formula below.
+
+    .. math::
+        margin\_rank\_loss = max(0, -target * (input - other) + margin)
+
+    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
+
+    .. math::
+        Out = MEAN(margin\_rank\_loss)
+
+    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
+
+    .. math::
+        Out = SUM(margin\_rank\_loss)
+
+    If :attr:`reduction` set to ``'none'``, just return the origin ``margin_rank_loss``.
+
+    Parameters:
+        input(Tensor): the first input tensor, its data type should be float32, float64.
+        other(Tensor): the second input tensor, its data type should be float32, float64.
+        target(Tensor): the target value corresponding to input, its data type should be float32, float64.
+        margin (float, optional): The margin value to add, default value is 0.
+        reduction (str, optional): Indicates the reduction to apply to the loss; the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; if :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns: Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`, otherwise the shape is the same as ``input``. The data type is the same as the input tensor.
+
+    Examples:
+
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            paddle.disable_static()
+
+            x = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype('float32'))
+            y = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype('float32'))
+            target = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype('float32'))
+            loss = paddle.nn.functional.margin_ranking_loss(x, y, target)
+            print(loss.numpy()) # [0.75]
+    """
+    if fluid.framework.in_dygraph_mode():
+        out = core.ops.elementwise_sub(other, input)
+        out = core.ops.elementwise_mul(out, target)
+        if margin != 0.0:
+            margin = fluid.dygraph.base.to_variable([margin], dtype=out.dtype)
+            out = core.ops.elementwise_add(out, margin)
+        out = core.ops.relu(out)
+        if reduction == 'sum':
+            return core.ops.reduce_sum(out, 'reduce_all', True)
+        elif reduction == 'mean':
+            return core.ops.mean(out)
+        return out
+
+    helper = LayerHelper("margin_ranking_loss", **locals())
+    fluid.data_feeder.check_variable_and_dtype(
+        input, 'input', ['float32', 'float64'], 'margin_rank_loss')
+    fluid.data_feeder.check_variable_and_dtype(
+        other, 'other', ['float32', 'float64'], 'margin_rank_loss')
+    fluid.data_feeder.check_variable_and_dtype(
+        target, 'target', ['float32', 'float64'], 'margin_rank_loss')
+
+    out = paddle.elementwise_sub(other, input)
+    out = paddle.multiply(out, target)
+
+    if margin != 0.0:
+        margin_var = out.block.create_var(dtype=out.dtype)
+        paddle.fill_constant([1], out.dtype, margin, out=margin_var)
+        out = paddle.add(out, margin_var)
+
+    result_out = helper.create_variable_for_type_inference(input.dtype)
+
+    if reduction == 'none':
+        helper.append_op(
+            type="relu", inputs={"X": out}, outputs={"Out": result_out})
+        return result_out
+    elif reduction == 'sum':
+        out = paddle.nn.functional.relu(out)
+        attrs = {"dim": [0], "keep_dim": False, "reduce_all": True}
+        helper.append_op(
+            type="reduce_sum",
+            inputs={"X": out},
+            outputs={"Out": result_out},
+            attrs=attrs)
+        return result_out
+    elif reduction == 'mean':
+        out = paddle.nn.functional.relu(out)
+        helper.append_op(
+            type="mean",
+            inputs={"X": out},
+            outputs={"Out": result_out},
+            attrs={})
+        return result_out
+
+
 def l1_loss(x, label, reduction='mean', name=None):
     """
     This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows.
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 560314788a1550125cf8aa614b71bce69e90d21f..680885ac26a52eaf8599ce5f152d3615bf5af8aa 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -62,6 +62,7 @@ from .loss import MSELoss  #DEFINE_ALIAS
 from .loss import L1Loss  #DEFINE_ALIAS
 from .loss import NLLLoss  #DEFINE_ALIAS
 from .loss import BCELoss  #DEFINE_ALIAS
+from .loss import MarginRankingLoss  #DEFINE_ALIAS
 from .norm import BatchNorm  #DEFINE_ALIAS
 from .norm import GroupNorm  #DEFINE_ALIAS
 from .norm import LayerNorm  #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 006b81c9325221931ca6ece7f31bbaff7aaa6384..0cd3673288e676c465f2802ac78edeb73e860180 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 # TODO: define loss functions of neural network
+import numpy as np
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import paddle
 from .. import functional as F
@@ -23,7 +25,8 @@ __all__ = [
     'MSELoss',
     'L1Loss',
     'NLLLoss',
-    'BCELoss'
+    'BCELoss',
+    'MarginRankingLoss'
 ]
@@ -569,3 +572,72 @@ class NLLLoss(fluid.dygraph.Layer):
             ignore_index=self._ignore_index,
             reduction=self._reduction,
             name=self._name)
+
+
+class MarginRankingLoss(fluid.dygraph.Layer):
+    """
+
+    This interface is used to construct a callable object of the ``MarginRankingLoss`` class.
+    The MarginRankingLoss layer calculates the margin rank loss between the input, other and target, using the formula below.
+
+    .. math::
+        margin\_rank\_loss = max(0, -target * (input - other) + margin)
+
+    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
+
+    .. math::
+        Out = MEAN(margin\_rank\_loss)
+
+    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
+
+    .. math::
+        Out = SUM(margin\_rank\_loss)
+
+    If :attr:`reduction` set to ``'none'``, just return the origin ``margin_rank_loss``.
+
+    Parameters:
+        margin (float, optional): The margin value to add, default value is 0.
+        reduction (str, optional): Indicates the reduction to apply to the loss; the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; if :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        input: N-D Tensor, the shape is [N, *], N is batch size and `*` means any number of additional dimensions; available dtype is float32, float64.
+        other: N-D Tensor, `other` has the same shape and dtype as `input`.
+        target: N-D Tensor, `target` has the same shape and dtype as `input`.
+        out: If :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`, otherwise the shape is the same as `input`. The data type is the same as the input tensor.
+
+    Returns:
+        A callable object of MarginRankingLoss.
+
+    Examples:
+
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            paddle.disable_static()
+
+            input = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype("float32"))
+            other = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype("float32"))
+            target = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype("float32"))
+            margin_rank_loss = paddle.nn.MarginRankingLoss()
+            loss = margin_rank_loss(input, other, target)
+            print(loss.numpy()) # [0.75]
+    """
+
+    def __init__(self, margin=0.0, reduction='mean', name=None):
+        if reduction not in ['sum', 'mean', 'none']:
+            raise ValueError(
+                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
+                "received %s, which is not allowed." % reduction)
+        super(MarginRankingLoss, self).__init__()
+        self.margin = margin
+        self.reduction = reduction
+        self.name = name
+
+    def forward(self, input, other, target):
+        out = paddle.nn.functional.margin_ranking_loss(
+            input, other, target, self.margin, self.reduction, self.name)
+        return out
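The run_dynamic_broadcast_api test earlier in this patch relies on `target` broadcasting against `input` and `other`. A plain-NumPy illustration of the shapes involved (not Paddle code):

.. code-block:: python

    import numpy as np

    x = np.random.rand(10, 10)
    y = np.random.rand(10, 10)
    label = np.random.choice([-1, 1], size=[10])  # shape (10,), broadcasts across rows

    loss = np.maximum(-label * (x - y), 0)
    print(loss.shape)  # (10, 10) -- the elementwise 'none' reduction shape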
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 306e683f8ae37b5b1e55bd8d964ff00b33694278..972c9fbce4d2ab11b4a0bbc4f9818f721486c741 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -452,21 +452,18 @@ def dist(x, y, p=2):
 def dot(x, y, name=None):
     """
-    :alias_main: paddle.dot
-    :alias: paddle.dot,paddle.tensor.dot,paddle.tensor.linalg.dot
-
     This operator calculates inner product for vectors.
 
     .. note::
        Only support 1-d Tensor(vector).
 
     Parameters:
-        x(Variable): 1-D ``Tensor`` or ``LoDTensor``. Its datatype should be ``float32``, ``float64``, ``int32``, ``int64``
-        y(Variable): 1-D ``Tensor`` or ``LoDTensor``. Its datatype soulde be ``float32``, ``float64``, ``int32``, ``int64``
+        x(Tensor): 1-D ``Tensor``. Its datatype should be ``float32``, ``float64``, ``int32``, ``int64``
+        y(Tensor): 1-D ``Tensor``. Its datatype should be ``float32``, ``float64``, ``int32``, ``int64``
         name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`
 
     Returns:
-        Variable: the calculated result Tensor/LoDTensor.
+        Variable: the calculated result Tensor.
 
     Examples:
 
@@ -475,12 +472,14 @@ def dot(x, y, name=None):
         import paddle
         import paddle.fluid as fluid
         import numpy as np
-
-        with fluid.dygraph.guard():
-            x = fluid.dygraph.to_variable(np.random.uniform(0.1, 1, [10]).astype(np.float32))
-            y = fluid.dygraph.to_variable(np.random.uniform(1, 3, [10]).astype(np.float32))
-            z = paddle.dot(x, y)
-            print(z.numpy())
+
+        paddle.disable_static()
+        x_data = np.random.uniform(0.1, 1, [10]).astype(np.float32)
+        y_data = np.random.uniform(1, 3, [10]).astype(np.float32)
+        x = paddle.to_variable(x_data)
+        y = paddle.to_variable(y_data)
+        z = paddle.dot(x, y)
+        print(z.numpy())
 
     """
     op_type = 'dot'
@@ -651,11 +650,8 @@ def cross(x, y, axis=None, name=None):
     return out
 
-def cholesky(x, upper=False):
+def cholesky(x, upper=False, name=None):
     """
-    :alias_main: paddle.cholesky
-    :alias: paddle.cholesky,paddle.tensor.cholesky,paddle.tensor.linalg.cholesky
-
     Computes the Cholesky decomposition of one symmetric positive-definite
     matrix or batches of symmetric positive-definite matrice.
@@ -680,21 +676,22 @@ def cholesky(x, upper=False, name=None):
         .. code-block:: python
 
             import paddle
-            import paddle.fluid as fluid
             import numpy as np
 
-            with fluid.dygraph.guard():
-                a = np.random.rand(3, 3)
-                a_t = np.transpose(a, [1, 0])
-                x = np.matmul(a, a_t) + 1e-03
-                x = fluid.dygraph.to_variable(x)
-                out = paddle.cholesky(x, upper=False)
-                print(out.numpy())
-                # [[1.190523   0.         0.        ]
-                #  [0.9906703  0.27676893 0.        ]
-                #  [1.25450498 0.05600871 0.06400121]]
+            paddle.disable_static()
+            a = np.random.rand(3, 3)
+            a_t = np.transpose(a, [1, 0])
+            x_data = np.matmul(a, a_t) + 1e-03
+            x = paddle.to_variable(x_data)
+            out = paddle.cholesky(x, upper=False)
+            print(out.numpy())
+            # [[1.190523   0.         0.        ]
+            #  [0.9906703  0.27676893 0.        ]
+            #  [1.25450498 0.05600871 0.06400121]]
 
     """
+    if in_dygraph_mode():
+        return core.ops.cholesky(x, "upper", upper)
     check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky')
     check_type(upper, 'upper', bool, 'cholesky')
     helper = LayerHelper('cholesky', **locals())
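One way to sanity-check the new dygraph fast path (or any Cholesky implementation): the lower-triangular factor must reproduce the input matrix. A plain-NumPy sketch, independent of Paddle, that builds a positive-definite matrix slightly differently from the docstring (adding to the diagonal rather than to every entry):

.. code-block:: python

    import numpy as np

    a = np.random.rand(3, 3)
    x = a @ a.T + 1e-3 * np.eye(3)  # symmetric positive-definite by construction
    L = np.linalg.cholesky(x)       # lower-triangular factor, like upper=False above
    assert np.allclose(L @ L.T, x)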
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 893b2cfde819e516d5c2caa391c14b6f4a539805..8827a0dab395db745cc4ee4bd969dff29f125136 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1099,17 +1099,15 @@ def logsumexp(x, dim=None, keepdim=False, name=None):
     return layers.log(sum_out, name)
 
-def inverse(input, name=None):
-    """
-    :alias_main: paddle.inverse
-    :alias: paddle.inverse,paddle.tensor.inverse,paddle.tensor.math.inverse
+def inverse(x, name=None):
+    """
     Takes the inverse of the square matrix. A square matrix is a matrix with
     the same number of rows and columns. The input can be a square matrix
     (2-D Tensor) or batches of square matrices.
 
     Args:
-        input (Variable): The input Variable which holds a Tensor. The last two
+        x (Variable): The input tensor. The last two
             dimensions should be equal. When the number of dimensions is
             greater than 2, it is treated as batches of square matrix. The data
             type can be float32 and float64.
@@ -1118,52 +1116,38 @@ def inverse(x, name=None):
             please refer to :ref:`api_guide_Name`
 
     Returns:
-        Variable: A Tensor holds the inverse of input. The shape and data type
-                  is the same as input.
+        Variable: A Tensor holds the inverse of x. The shape and data type
+                  is the same as x.
 
     Examples:
         .. code-block:: python
 
             import numpy as np
             import paddle
-            import paddle.fluid as fluid
 
             mat_np = np.array([[2, 0], [0, 2]]).astype("float32")
+            paddle.disable_static()
+            mat = paddle.to_variable(mat_np)
+            inv = paddle.inverse(mat)
+            print(inv)  # [[0.5, 0], [0, 0.5]]
 
-            # example for static graph
-            input = fluid.data("input", shape=[2, 2], dtype="float32")
-            out = paddle.inverse(input)
-
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            results = exe.run(feed={"input": mat_np},
-                              fetch_list=[out.name])
-            print(results[0])  # [[0.5, 0], [0, 0.5]]
-
-            # example for dynamic graph
-            with fluid.dygraph.guard():
-                mat = fluid.dygraph.to_variable(mat_np)
-                inv = paddle.inverse(mat)
-                print(inv)  # [[0.5, 0], [0, 0.5]]
     """
     if in_dygraph_mode():
-        return core.ops.inverse(input)
+        return core.ops.inverse(x)
 
-    def _check_input(input):
-        check_variable_and_dtype(input, 'input',
+    def _check_input(x):
+        check_variable_and_dtype(x, 'x',
                                  ['float32', 'float64'], 'inverse')
-        if len(input.shape) < 2:
+        if len(x.shape) < 2:
             raise ValueError(
                 "The input of inverse is expected to be a Tensor whose number "
                 "of dimensions is no less than 2. But reviced: %d, "
-                "input's shape: %s." % (len(input.shape), input.shape))
-
-    _check_input(input)
-
+                "x's shape: %s." % (len(x.shape), x.shape))
+    _check_input(x)
     helper = LayerHelper('inverse', **locals())
-    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
-        type='inverse', inputs={'Input': [input]}, outputs={'Output': [out]})
+        type='inverse', inputs={'Input': [x]}, outputs={'Output': [out]})
     return out
@@ -1177,19 +1161,19 @@ def max(x, axis=None, keepdim=False, name=None):
             float64, int32, int64.
         axis(list|int, optional): The axis along which the maximum is computed.
             If :attr:`None`, compute the maximum over all elements of
-            :attr:`input` and return a Tensor variable with a single element,
+            `x` and return a Tensor variable with a single element,
             otherwise must be in the range :math:`[-x.ndim(x), x.ndim(x))`.
             If :math:`axis[i] < 0`, the axis to reduce is :math:`x.ndim + axis[i]`.
         keepdim(bool, optional): Whether to reserve the reduced dimension in the
             output Tensor. The result tensor will have one fewer dimension
-            than the :attr:`input` unless :attr:`keepdim` is true, default
+            than the `x` unless :attr:`keepdim` is true, default
             value is False.
         name(str, optional): The default value is None. Normally there is no need for
             user to set this property. For more information, please refer to :ref:`api_guide_Name`
 
     Returns:
         Tensor, results of maximum on the specified axis of input tensor,
-        it's data type is the same as input's Tensor.
+        its data type is the same as `x`.
 
     Examples:
        .. code-block:: python
@@ -1232,7 +1216,14 @@ def max(x, axis=None, keepdim=False, name=None):
     """
 
     if axis is not None and not isinstance(axis, list):
-        axis = [axis]
+        if isinstance(axis, tuple):
+            axis = list(axis)
+        elif isinstance(axis, int):
+            axis = [axis]
+        else:
+            raise TypeError(
+                "The type of axis must be int, list or tuple, but received {}".
+                format(type(axis)))
+
     reduce_all = True if axis == None or axis == [] else False
     axis = axis if axis != None and axis != [] else [0]
     if in_dygraph_mode():
@@ -1265,12 +1256,12 @@ def min(x, axis=None, keepdim=False, name=None):
         x(Tensor): A tensor, the data type is float32, float64, int32, int64.
         axis(list|int, optional): The axis along which the minimum is computed.
             If :attr:`None`, compute the minimum over all elements of
-            :attr:`input` and return a Tensor variable with a single element,
+            `x` and return a Tensor variable with a single element,
             otherwise must be in the range :math:`[-x.ndim, x.ndim)`.
             If :math:`axis[i] < 0`, the axis to reduce is :math:`x.ndim + axis[i]`.
         keepdim(bool, optional): Whether to reserve the reduced dimension in the
             output Tensor. The result tensor will have one fewer dimension
-            than the :attr:`input` unless :attr:`keepdim` is true, default
+            than the `x` unless :attr:`keepdim` is true, default
            value is False.
         name(str, optional): The default value is None. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`
@@ -1320,7 +1311,13 @@ def min(x, axis=None, keepdim=False, name=None):
     """
 
     if axis is not None and not isinstance(axis, list):
-        axis = [axis]
+        if isinstance(axis, tuple):
+            axis = list(axis)
+        elif isinstance(axis, int):
+            axis = [axis]
+        else:
+            raise TypeError(
+                "The type of axis must be int, list or tuple, but received {}".
+                format(type(axis)))
     reduce_all = True if axis == None or axis == [] else False
     axis = axis if axis != None and axis != [] else [0]
     if in_dygraph_mode():