diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 97afd366387e9ba6476be59a4d73d53a38834d0e..05e423b8a52962d47a6615d48243444374b470e3 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -26,12 +26,15 @@ Dim make_dim(const int64_t* d) { } template <> -Dim<1> make_dim<1>(const int64_t* d) { - return Dim<1>(*d); +Dim<0> make_dim<0>(const int64_t* d) { + return Dim<0>(*d); } void make_ddim(DDim& ddim, const int64_t* dims, int n) { switch (n) { + case 0: + ddim = make_dim<0>(dims); + break; case 1: ddim = make_dim<1>(dims); break; @@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> { this->operator()(t.tail); } - void operator()(const Dim<1>& t) { vector.push_back(t.head); } + void operator()(const Dim<0>& t) {} }; /// @endcond @@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> { } } - void operator()(const Dim<1>& dim) { - PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound."); - vector.push_back(dim.head); + void operator()(const Dim<0>& dim) { + PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound."); } }; diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 5aff10d3b95902fdb9fe432d9f31830304dd3d07..f05b5ee3faee856a41f1376e5952710b550e7c42 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -30,8 +30,8 @@ namespace framework { * The number of dimensions must be between [1, 9]. 
*/ struct DDim { - typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, - Dim<8>, Dim<9>> + typedef boost::variant, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, + Dim<7>, Dim<8>, Dim<9>> DDimVar; DDimVar var; diff --git a/paddle/fluid/framework/dim.h b/paddle/fluid/framework/dim.h index 08b708006aadc4769bde7b37347ac1adfeca2bf7..8d288120e30035673be0ec5dc6230f607dfd1ebe 100644 --- a/paddle/fluid/framework/dim.h +++ b/paddle/fluid/framework/dim.h @@ -72,38 +72,36 @@ struct Dim { // Base case specialization template <> -struct Dim<1> { - static constexpr int dimensions = 1; +struct Dim<0> { + static constexpr int dimensions = 0; HOSTDEVICE - Dim(int64_t _head) : head(_head) {} + Dim(int64_t _head) {} HOSTDEVICE - Dim() : head(0) {} + Dim() {} HOSTDEVICE - Dim(int idx, const Dim<1>& size) : head(idx) { + Dim(int idx, const Dim<0>& size) { #ifndef __CUDA_ARCH__ - if (idx >= size.head) { + if (idx > 0) { throw std::invalid_argument("Index out of range."); } #else - PADDLE_ASSERT(idx < size.head); + PADDLE_ASSERT(idx == 0); #endif } HOSTDEVICE - bool operator==(const Dim<1>& o) const { return (head == o.head); } + bool operator==(const Dim<0>& o) const { return true; } HOSTDEVICE - bool operator!=(const Dim<1>& o) const { return !(*this == o); } + bool operator!=(const Dim<0>& o) const { return false; } HOSTDEVICE int64_t& operator[](int idx); HOSTDEVICE int64_t operator[](int idx) const; - - int64_t head; }; namespace { @@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim& dim, int idx) { } template <> -HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) { +HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) { #ifndef __CUDA_ARCH__ - if (idx != 0) { - throw std::invalid_argument("Invalid index"); - } + throw std::invalid_argument("Invalid index"); #else - PADDLE_ASSERT(idx == 0); + PADDLE_ASSERT(false); #endif - return dim.head; + static int64_t head = 0; + return head; } template @@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim& dim, 
int idx) { } template <> -HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) { +HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) { #ifndef __CUDA_ARCH__ - if (idx != 0) { - throw std::invalid_argument("Invalid index"); - } + throw std::invalid_argument("Invalid index"); #else - PADDLE_ASSERT(idx == 0); + PADDLE_ASSERT(false); #endif - return dim.head; + static int64_t head = 0; + return head; } } // namespace @@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim::operator[](int i) { } // Dynamic access to constant Dim -inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const { +inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim -inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) { +inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) { return indexer(*this, i); } @@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim& a, const Dim& b) { // Base case dot product of two Dims // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) { - return a.head * b.head; +HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) { + return 0; } // Product of a Dim @@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim& a, int prod = 1) { // Base case product of a Dim // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) { - return prod * a.head; +HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) { + return prod; } // Is 0 <= idx_i < size_i for all i? @@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim& idx, const Dim& size) { // Base case of is 0 <= idx_i < size_i ? 
// Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { - return ((0 <= idx.head) && (idx.head < size.head)); +HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) { + return true; } /** @@ -294,8 +290,8 @@ HOSTDEVICE Dim ex_prefix_mul(const Dim& src, int mul = 1) { // Base case of ex_prefix_mul // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { - return Dim<1>(mul); +HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) { + return Dim<0>(); } ///\endcond @@ -309,8 +305,8 @@ HOSTDEVICE Dim dim_plus(const Dim& a, const Dim& b) { // Base case template <> -HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) { - return Dim<1>(a.head + b.head); +HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) { + return Dim<0>(); } template @@ -328,8 +324,8 @@ HOSTDEVICE Dim dim_mult(const Dim& a, const Dim& b) { // Base case template <> -HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) { - return Dim<1>(a.head * b.head); +HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) { + return Dim<0>(); } template @@ -356,10 +352,9 @@ HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { ///\cond HIDDEN template <> -HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size, - const Dim<1>& stride) { - int norm_stride = size.head == 1 ? 
0 : stride.head; - return Dim<1>(norm_stride); +HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size, + const Dim<0>& stride) { + return Dim<0>(); } ///\endcond @@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<( return os; } +inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) { + return os; +} + template HOST std::string Dim::to_string() const { std::stringstream stream; diff --git a/paddle/fluid/operators/detail/strided_memcpy.h b/paddle/fluid/operators/detail/strided_memcpy.h index bac5cdc99c0133b1e6da3f6a23bc0512ca4177f5..0b7c470fe72eb4270b8d5b2d227642d85683c16d 100644 --- a/paddle/fluid/operators/detail/strided_memcpy.h +++ b/paddle/fluid/operators/detail/strided_memcpy.h @@ -24,6 +24,29 @@ namespace detail { template struct StridedMemcpyFunctor; +template +struct StridedMemcpyFunctor { + void operator()(const platform::DeviceContext& dev_ctx, const T* src, + framework::Dim<0> src_stride, framework::Dim<0> dst_dim, + framework::Dim<0> dst_stride, T* dst) const { + auto place = dev_ctx.GetPlace(); + if (platform::is_cpu_place(place)) { + auto& cpu_place = boost::get(place); + memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T)); + } else { +#ifdef PADDLE_WITH_CUDA + auto& gpu_place = boost::get(place); + auto& cuda_ctx = + reinterpret_cast(dev_ctx); + memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T), + cuda_ctx.stream()); +#else + PADDLE_THROW("Paddle is not compiled with GPU"); +#endif + } + } +}; + template struct StridedMemcpyFunctor { void operator()(const platform::DeviceContext& dev_ctx, const T* src, diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index 06bcd0be646e7dff72b46b1c9031464de21b3c6a..fe31bbaed44fced68b7b51dd2c2031950ec4247d 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -65,12 +65,17 @@ smaller than or equal to the dimensions of $X$. There are two cases for this operator: 1. 
The shape of $Y$ is same with $X$; -2. The shape of $Y$ is a subset of $X$. +2. The shape of $Y$ is a contiguous subsequence of $X$. The trailing dimensions + of size 1 for $Y$ will be ignored for the consideration of subsequence. + For case 2: + $Y$ will be broadcasted to match the shape of $X$ and axis should be set to index of the start dimension to broadcast $Y$ onto $X$. +If axis is -1, it is treated as axis=rank(X)-rank(Y). + For example .. code-block:: python @@ -79,6 +84,7 @@ For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 + shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details) information. However, the output only shares the LoD information with input $X$. diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index 600524936b079fb59e4774f477d272d92c06bdf9..ffda53a383ced411415e528886a23f28f6a62648 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -62,6 +62,19 @@ inline void get_mid_dims(const framework::DDim& x_dims, } } +inline void trim_trailing_singular_dims(framework::DDim& dims) { + // Remove trailing dimensions of size 1 for y + auto actual_dims_size = dims.size(); + for (; actual_dims_size != 0; --actual_dims_size) { + if (dims[actual_dims_size - 1] != 1) break; + } + if (actual_dims_size != dims.size()) { + auto actual_dims = framework::vectorize(dims); + actual_dims.resize(actual_dims_size); + dims = framework::make_ddim(actual_dims); + } +} + template class RowwiseTransformIterator; template @@ -264,44 +277,6 @@ class TransformFunctor { } \ } -template -void ElementwiseCompute(const framework::ExecutionContext& ctx) { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = 
ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); - - if (x_dims == y_dims) { - functor f; - f.template Run(x, y, z, ctx); - return; - } - - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - get_mid_dims(x_dims, y_dims, axis, pre, n, post); - if (post == 1) { - functor f; - f.template RunBroadCast(x, y, z, ctx, pre, n); - return; - } else { - functor f; - f.template RunBroadCast2(x, y, z, ctx, pre, n, post); - return; - } -} - #define EIGEN_ADD(x, y) ((x) + (y)) EIGEN_FUNCTOR(Add, EIGEN_ADD); @@ -496,14 +471,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx, auto x_dim = x.dims(); auto y_dim = y.dims(); - if (y_dim.size() == 1 && y_dim[0] == 1) { - // y is a scalar - auto extended_dims = framework::vectorize(x_dim); - extended_dims.push_back(1); - x_dim = framework::make_ddim(extended_dims); - } - axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); + trim_trailing_singular_dims(y_dim); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + int pre, n, post; get_mid_dims(x_dim, y_dim, axis, pre, n, post); if (post == 1) { @@ -571,14 +542,9 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx, return; } - if (y_dims.size() == 1 && y_dims[0] == 1) { - // y is a scalar - auto extended_dims = framework::vectorize(x_dims); - extended_dims.push_back(1); - x_dims = framework::make_ddim(extended_dims); - } - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + trim_trailing_singular_dims(y_dims); + axis = (y_dims.size() == 0) ? 
x_dims.size() : axis; int pre, n, post; get_mid_dims(x_dims, y_dims, axis, pre, n, post); @@ -613,16 +579,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx, return; } - if (y_dims.size() == 1 && y_dims[0] == 1) { - // y is a scalar - auto extended_dims = framework::vectorize(x_dims); - extended_dims.push_back(1); - x_dims = framework::make_ddim(extended_dims); - } - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), "Axis should be in range [0, x_dims)"); + trim_trailing_singular_dims(y_dims); + axis = (y_dims.size() == 0) ? x_dims.size() : axis; int pre, n, post; get_mid_dims(x_dims, y_dims, axis, pre, n, post); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index e2749593057067ec04284940fff8f6a5284806ef..e5fb0e5d628b2df14355aec2718cf46aa641b6cf 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -14,7 +14,7 @@ import numpy as np import contextlib -from framework import Program, default_main_program +from framework import Program, default_main_program, Variable from . import core __all__ = [ @@ -281,6 +281,8 @@ class Executor(object): if not has_fetch_operators(global_block, fetch_list, fetch_var_name): for i, var in enumerate(fetch_list): + assert isinstance(var, Variable) or isinstance(var, str), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) global_block.append_op( type='fetch', inputs={'X': [var]}, diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index faccc3ddf827e4211c9f2e61da7138e5d43f1d11..08a0184c2c2ad5f3c3792fd0a12f0ab0c746849b 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -1,11 +1,11 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
-# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -53,12 +53,22 @@ def monkey_patch_variable(): value = float(value) tmp_name = unique_tmp_name() var = ref_var.block.create_var(name=tmp_name, dtype=dtype) + batch_dim = -1 + for i, d in enumerate(ref_var.shape): + if d < 0: + batch_dim = i + break + assert batch_dim != -1 ref_var.block.append_op( type='fill_constant_batch_size_like', outputs={'Out': [var]}, inputs={'Input': [ref_var]}, - attrs={'shape': ref_var.shape, - 'value': value}) + attrs={ + 'shape': ref_var.shape, + 'value': value, + 'input_dim_idx': batch_dim, + 'output_dim_idx': batch_dim + }) return var def astype(self, dtype): @@ -118,11 +128,20 @@ def monkey_patch_variable(): tmp_name = unique_tmp_name() out = self.block.create_var(name=tmp_name, dtype=lhs_dtype) + axis = -1 + if other_var.shape[0] == -1: + axis = 0 + assert len(self.shape) >= len(other_var.shape), ( + "The rank of the first argument of an binary operator cannot " + "be smaller than the rank of its second argument: %s vs %s" % + (len(self.shape), len(other_var.shape))) + self.block.append_op( type=op_type, inputs={'X': [self], 'Y': [other_var]}, - outputs={'Out': out}) + outputs={'Out': out}, + attrs={'axis': axis}) return out comment = OpProtoHolder.instance().get_op_proto(op_type).comment @@ -131,7 +150,7 @@ def monkey_patch_variable(): {0} Args: self(Variable): left hand variable - other_var(Variable|float|int): right hand variable + other_var(Variable|float|int): right hand variable Returns: Variable diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py 
b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index c8e930dad762b867d0148ebcdb3637b8cc9560ce..5b2384e94d788342c692fcb8e33f3a2ff663ab53 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -50,6 +50,16 @@ class TestElementwiseAddOp_scalar(TestElementwiseOp): self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} +class TestElementwiseAddOp_scalar2(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_add" + self.inputs = { + 'X': np.random.rand(2, 3, 4).astype(np.float32), + 'Y': np.random.rand(1, 1).astype(np.float32) + } + self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} + + class TestElementwiseAddOp_Vector(TestElementwiseOp): def setUp(self): self.op_type = "elementwise_add" @@ -115,6 +125,20 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp): } +class TestElementwiseAddOp_broadcast_4(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_add" + self.inputs = { + 'X': np.random.rand(2, 3, 4, 5).astype(np.float32), + 'Y': np.random.rand(2, 1).astype(np.float32) + } + + self.attrs = {'axis': 0} + self.outputs = { + 'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1, 1) + } + + class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp): def setUp(self): self.op_type = "elementwise_add" diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index 6864d271e795026d59525e9f1e4d86e32df980bf..852a80261e02f5ed19e7fbe608d490be1f7798a9 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -1,11 +1,11 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -23,13 +23,21 @@ class TestMathOpPatches(unittest.TestCase): def test_add_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = a + 10 + ab = fluid.layers.concat(input=[a, b], axis=1) + c = ab + 10 + d = ab + a + # e = a + ab place = fluid.CPUPlace() exe = fluid.Executor(place) a_np = numpy.random.random(size=[10, 1]).astype('float32') - b_np = exe.run(fluid.default_main_program(), - feed={"a": a_np}, - fetch_list=[b]) + b_np, c_np, d_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b, c, d]) self.assertTrue(numpy.allclose(a_np + 10, b_np)) + ab_np = numpy.concatenate([a_np, b_np], axis=1) + self.assertTrue(numpy.allclose(ab_np + 10, c_np)) + d_expected = ab_np + numpy.concatenate([a_np, a_np], axis=1) + self.assertTrue(numpy.allclose(d_expected, d_np)) @decorators.prog_scope() def test_radd_scalar(self):