Unverified · Commit eac2c3cf authored by emailweixu, committed by GitHub

Merge pull request #8505 from emailweixu/math_op

Correctly handle variables with a batch dimension for math ops.
@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) {
 }
 template <>
-Dim<1> make_dim<1>(const int64_t* d) {
-  return Dim<1>(*d);
+Dim<0> make_dim<0>(const int64_t* d) {
+  return Dim<0>(*d);
 }
 void make_ddim(DDim& ddim, const int64_t* dims, int n) {
   switch (n) {
+    case 0:
+      ddim = make_dim<0>(dims);
+      break;
     case 1:
       ddim = make_dim<1>(dims);
       break;
@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
     this->operator()(t.tail);
   }
-  void operator()(const Dim<1>& t) { vector.push_back(t.head); }
+  void operator()(const Dim<0>& t) {}
 };
 /// @endcond
@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
     }
   }
-  void operator()(const Dim<1>& dim) {
-    PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
-    vector.push_back(dim.head);
+  void operator()(const Dim<0>& dim) {
+    PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
   }
 };
......
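The hunks above teach DDim about rank-0 shapes: `make_dim<0>` ignores its input, `make_ddim` accepts `n == 0`, and vectorizing a `Dim<0>` contributes no entries. A rough numpy analogy of the intended semantics (numpy here is an illustration, not part of the patch):

```python
import numpy as np

# A rank-0 array is a scalar: its shape vectorizes to an empty list,
# which is exactly what VectorizeVisitor now produces for Dim<0>.
scalar = np.array(3.0)
assert scalar.shape == ()        # zero dimensions
assert list(scalar.shape) == []  # vectorize(Dim<0>) -> []
```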
@@ -30,8 +30,8 @@ namespace framework {
  * The number of dimensions must be between [1, 9].
  */
 struct DDim {
-  typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>,
-                         Dim<8>, Dim<9>>
+  typedef boost::variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
+                         Dim<7>, Dim<8>, Dim<9>>
       DDimVar;
   DDimVar var;
......
@@ -72,38 +72,36 @@ struct Dim {
 // Base case specialization
 template <>
-struct Dim<1> {
-  static constexpr int dimensions = 1;
+struct Dim<0> {
+  static constexpr int dimensions = 0;
   HOSTDEVICE
-  Dim(int64_t _head) : head(_head) {}
+  Dim(int64_t _head) {}
   HOSTDEVICE
-  Dim() : head(0) {}
+  Dim() {}
   HOSTDEVICE
-  Dim(int idx, const Dim<1>& size) : head(idx) {
+  Dim(int idx, const Dim<0>& size) {
 #ifndef __CUDA_ARCH__
-    if (idx >= size.head) {
+    if (idx > 0) {
       throw std::invalid_argument("Index out of range.");
     }
 #else
-    PADDLE_ASSERT(idx < size.head);
+    PADDLE_ASSERT(idx == 0);
 #endif
   }
   HOSTDEVICE
-  bool operator==(const Dim<1>& o) const { return (head == o.head); }
+  bool operator==(const Dim<0>& o) const { return true; }
   HOSTDEVICE
-  bool operator!=(const Dim<1>& o) const { return !(*this == o); }
+  bool operator!=(const Dim<0>& o) const { return false; }
   HOSTDEVICE
   int64_t& operator[](int idx);
   HOSTDEVICE
   int64_t operator[](int idx) const;
-  int64_t head;
 };
 namespace {
@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
 }
 template <>
-HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
+HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
 #ifndef __CUDA_ARCH__
-  if (idx != 0) {
-    throw std::invalid_argument("Invalid index");
-  }
+  throw std::invalid_argument("Invalid index");
 #else
-  PADDLE_ASSERT(idx == 0);
+  PADDLE_ASSERT(false);
 #endif
-  return dim.head;
+  static int64_t head = 0;
+  return head;
 }
 template <int D>
@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
 }
 template <>
-HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
+HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
 #ifndef __CUDA_ARCH__
-  if (idx != 0) {
-    throw std::invalid_argument("Invalid index");
-  }
+  throw std::invalid_argument("Invalid index");
 #else
-  PADDLE_ASSERT(idx == 0);
+  PADDLE_ASSERT(false);
 #endif
-  return dim.head;
+  static int64_t head = 0;
+  return head;
 }
 }  // namespace
@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
 }
 // Dynamic access to constant Dim
-inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
+inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
   return indexer(*this, i);
 }
 // Dynamic access to mutable Dim
-inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
+inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
   return indexer(*this, i);
 }
@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
 // Base case dot product of two Dims
 // Notice it is inline because it is no longer a template
 template <>
-HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
-  return a.head * b.head;
+HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
+  return 0;
 }
 // Product of a Dim
@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
 // Base case product of a Dim
 // Notice it is inline because it is no longer a template
 template <>
-HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
-  return prod * a.head;
+HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
+  return prod;
 }
 // Is 0 <= idx_i < size_i for all i?
@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
 // Base case of is 0 <= idx_i < size_i ?
 // Notice it is inline because it is no longer a template
 template <>
-HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
-  return ((0 <= idx.head) && (idx.head < size.head));
+HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
+  return true;
 }
 /**
@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
 // Base case of ex_prefix_mul
 // Notice it is inline because it is no longer a template
 template <>
-HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
-  return Dim<1>(mul);
+HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) {
+  return Dim<0>();
 }
 ///\endcond
@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
 // Base case
 template <>
-HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) {
-  return Dim<1>(a.head + b.head);
+HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
+  return Dim<0>();
 }
 template <int i>
@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
 // Base case
 template <>
-HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) {
-  return Dim<1>(a.head * b.head);
+HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
+  return Dim<0>();
 }
 template <int i>
@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
 ///\cond HIDDEN
 template <>
-HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size,
-                                           const Dim<1>& stride) {
-  int norm_stride = size.head == 1 ? 0 : stride.head;
-  return Dim<1>(norm_stride);
+HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
+                                           const Dim<0>& stride) {
+  return Dim<0>();
 }
 ///\endcond
@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
   return os;
 }
+inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
+  return os;
+}
 template <int i>
 HOST std::string Dim<i>::to_string() const {
   std::stringstream stream;
......
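All the `Dim<0>` base cases above return the corresponding identity: `product` returns the accumulator unchanged (an empty shape holds exactly one element), `linearize` contributes nothing to an offset, and `contained` is vacuously true. A minimal Python sketch of those identities (my own restatement, not Paddle code):

```python
from functools import reduce
from operator import mul

def product(dims, prod=1):
    # Mirrors the recursive product() in the diff: the Dim<0> base case
    # returns the accumulator, so an empty shape has product 1.
    return reduce(mul, dims, prod)

def linearize(idx, strides):
    # Dot product of index and strides; a rank-0 index adds offset 0.
    return sum(i * s for i, s in zip(idx, strides))

assert product(()) == 1        # a rank-0 tensor stores one element
assert linearize((), ()) == 0  # and that element lives at offset 0
```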
@@ -24,6 +24,29 @@ namespace detail {
 template <typename T, int Rank>
 struct StridedMemcpyFunctor;
+template <typename T>
+struct StridedMemcpyFunctor<T, 0> {
+  void operator()(const platform::DeviceContext& dev_ctx, const T* src,
+                  framework::Dim<0> src_stride, framework::Dim<0> dst_dim,
+                  framework::Dim<0> dst_stride, T* dst) const {
+    auto place = dev_ctx.GetPlace();
+    if (platform::is_cpu_place(place)) {
+      auto& cpu_place = boost::get<platform::CPUPlace>(place);
+      memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T));
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      auto& gpu_place = boost::get<platform::CUDAPlace>(place);
+      auto& cuda_ctx =
+          reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
+      memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T),
+                   cuda_ctx.stream());
+#else
+      PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+    }
+  }
+};
 template <typename T>
 struct StridedMemcpyFunctor<T, 1> {
   void operator()(const platform::DeviceContext& dev_ctx, const T* src,
......
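For rank 0 there is nothing to stride over, so the new `StridedMemcpyFunctor<T, 0>` degenerates to copying a single element, `sizeof(T)` bytes, between the two pointers. In numpy terms (an analogy only, not the Paddle API):

```python
import numpy as np

src = np.array(7.0)   # rank-0 tensor: exactly one element
dst = np.empty(())
dst[()] = src[()]     # the rank-0 copy is a one-element memcpy
assert dst == src
```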
@@ -65,12 +65,17 @@ smaller than or equal to the dimensions of $X$.
 There are two cases for this operator:
 1. The shape of $Y$ is same with $X$;
-2. The shape of $Y$ is a subset of $X$.
+2. The shape of $Y$ is a contiguous subsequence of $X$. The trailing dimensions
+   of size 1 for $Y$ will be ignored for the consideration of subsequence.
 For case 2:
 $Y$ will be broadcasted to match the shape of $X$ and axis should be
 set to index of the start dimension to broadcast $Y$ onto $X$.
+If axis is -1, it is treated as axis=rank(X)-rank(Y).
 For example
   .. code-block:: python
@@ -79,6 +84,7 @@ For example
     shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
     shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
     shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
+    shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
 Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
 information. However, the output only shares the LoD information with input $X$.
......
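The newly documented case, `shape(Y) = (2, 1)` with `axis=0`, can be reproduced with plain numpy (an illustration of the intended broadcasting, not Paddle code):

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
y = np.random.rand(2, 1).astype(np.float32)

# With axis=0, Y's trailing size-1 dimension is ignored, so Y acts like
# shape (2,) aligned at dimension 0 of X, i.e. reshaped to (2, 1, 1, 1).
out = x + y.reshape(2, 1, 1, 1)
assert out.shape == (2, 3, 4, 5)
```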
@@ -62,6 +62,19 @@ inline void get_mid_dims(const framework::DDim& x_dims,
   }
 }
+inline void trim_trailing_singular_dims(framework::DDim& dims) {
+  // Remove trailing dimensions of size 1 for y
+  auto actual_dims_size = dims.size();
+  for (; actual_dims_size != 0; --actual_dims_size) {
+    if (dims[actual_dims_size - 1] != 1) break;
+  }
+  if (actual_dims_size != dims.size()) {
+    auto actual_dims = framework::vectorize(dims);
+    actual_dims.resize(actual_dims_size);
+    dims = framework::make_ddim(actual_dims);
+  }
+}
 template <typename T, typename DeviceContext>
 class RowwiseTransformIterator;
 template <typename T, typename DeviceContext>
@@ -264,44 +277,6 @@ class TransformFunctor {
   }                                                  \
 }
-template <class functor, typename DeviceContext, typename T>
-void ElementwiseCompute(const framework::ExecutionContext& ctx) {
-  using Tensor = framework::Tensor;
-  auto* x = ctx.Input<Tensor>("X");
-  auto* y = ctx.Input<Tensor>("Y");
-  auto* z = ctx.Output<Tensor>("Out");
-  z->mutable_data<T>(ctx.GetPlace());
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
-                    "Rank of first input must >= rank of second input.");
-  if (x_dims == y_dims) {
-    functor f;
-    f.template Run<DeviceContext, T>(x, y, z, ctx);
-    return;
-  }
-  int axis = ctx.Attr<int>("axis");
-  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
-                 "Axis should be in range [0, x_dims)");
-  int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
-  if (post == 1) {
-    functor f;
-    f.template RunBroadCast<DeviceContext, T>(x, y, z, ctx, pre, n);
-    return;
-  } else {
-    functor f;
-    f.template RunBroadCast2<DeviceContext, T>(x, y, z, ctx, pre, n, post);
-    return;
-  }
-}
 #define EIGEN_ADD(x, y) ((x) + (y))
 EIGEN_FUNCTOR(Add, EIGEN_ADD);
@@ -496,14 +471,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
   auto x_dim = x.dims();
   auto y_dim = y.dims();
-  if (y_dim.size() == 1 && y_dim[0] == 1) {
-    // y is a scalar
-    auto extended_dims = framework::vectorize(x_dim);
-    extended_dims.push_back(1);
-    x_dim = framework::make_ddim(extended_dims);
-  }
   axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
+  trim_trailing_singular_dims(y_dim);
+  axis = (y_dim.size() == 0) ? x_dim.size() : axis;
   int pre, n, post;
   get_mid_dims(x_dim, y_dim, axis, pre, n, post);
   if (post == 1) {
@@ -571,14 +542,9 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
     return;
   }
-  if (y_dims.size() == 1 && y_dims[0] == 1) {
-    // y is a scalar
-    auto extended_dims = framework::vectorize(x_dims);
-    extended_dims.push_back(1);
-    x_dims = framework::make_ddim(extended_dims);
-  }
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+  trim_trailing_singular_dims(y_dims);
+  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
   int pre, n, post;
   get_mid_dims(x_dims, y_dims, axis, pre, n, post);
@@ -613,16 +579,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
     return;
   }
-  if (y_dims.size() == 1 && y_dims[0] == 1) {
-    // y is a scalar
-    auto extended_dims = framework::vectorize(x_dims);
-    extended_dims.push_back(1);
-    x_dims = framework::make_ddim(extended_dims);
-  }
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                  "Axis should be in range [0, x_dims)");
+  trim_trailing_singular_dims(y_dims);
+  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
   int pre, n, post;
   get_mid_dims(x_dims, y_dims, axis, pre, n, post);
......
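The compute and gradient paths now share one rule: trim Y's trailing size-1 dimensions, then split X into a `(pre, n, post)` triple around the block that Y broadcasts over. A Python sketch of that bookkeeping (`mid_dims` is a simplified stand-in for `get_mid_dims`, which additionally enforces that the middle block of X matches Y):

```python
from functools import reduce
from operator import mul

def trim_trailing_singular_dims(dims):
    # Sketch of the new helper: drop trailing 1s, e.g. (2, 1) -> (2,).
    n = len(dims)
    while n > 0 and dims[n - 1] == 1:
        n -= 1
    return dims[:n]

def mid_dims(x_dims, y_dims, axis):
    # Split x around y: pre * n * post elements, with n covering y.
    prod = lambda d: reduce(mul, d, 1)
    pre = prod(x_dims[:axis])
    n = prod(y_dims)
    post = prod(x_dims[axis + len(y_dims):])
    return pre, n, post

y_dims = trim_trailing_singular_dims((2, 1))          # -> (2,)
assert mid_dims((2, 3, 4, 5), y_dims, 0) == (1, 2, 60)
```

When `post == 1` the kernel can walk Y row-wise; otherwise it takes the second broadcast path, which is why both branches remain after the old `ElementwiseCompute` is deleted.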
@@ -14,7 +14,7 @@
 import numpy as np
 import contextlib
-from framework import Program, default_main_program
+from framework import Program, default_main_program, Variable
 from . import core
 __all__ = [
@@ -281,6 +281,8 @@ class Executor(object):
         if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
             for i, var in enumerate(fetch_list):
+                assert isinstance(var, Variable) or isinstance(var, str), (
+                    "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
                 global_block.append_op(
                     type='fetch',
                     inputs={'X': [var]},
......
 # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -53,12 +53,22 @@ def monkey_patch_variable():
         value = float(value)
         tmp_name = unique_tmp_name()
         var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
+        batch_dim = -1
+        for i, d in enumerate(ref_var.shape):
+            if d < 0:
+                batch_dim = i
+                break
+        assert batch_dim != -1
         ref_var.block.append_op(
             type='fill_constant_batch_size_like',
             outputs={'Out': [var]},
             inputs={'Input': [ref_var]},
-            attrs={'shape': ref_var.shape,
-                   'value': value})
+            attrs={
+                'shape': ref_var.shape,
+                'value': value,
+                'input_dim_idx': batch_dim,
+                'output_dim_idx': batch_dim
+            })
         return var
     def astype(self, dtype):
@@ -118,11 +128,20 @@ def monkey_patch_variable():
         tmp_name = unique_tmp_name()
         out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
+        axis = -1
+        if other_var.shape[0] == -1:
+            axis = 0
+        assert len(self.shape) >= len(other_var.shape), (
+            "The rank of the first argument of a binary operator cannot "
+            "be smaller than the rank of its second argument: %s vs %s" %
+            (len(self.shape), len(other_var.shape)))
         self.block.append_op(
             type=op_type,
             inputs={'X': [self],
                     'Y': [other_var]},
-            outputs={'Out': out})
+            outputs={'Out': out},
+            attrs={'axis': axis})
         return out
     comment = OpProtoHolder.instance().get_op_proto(op_type).comment
@@ -131,7 +150,7 @@ def monkey_patch_variable():
         {0}
         Args:
             self(Variable): left hand variable
             other_var(Variable|float|int): right hand variable
         Returns:
             Variable
......
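The monkey-patched operators now pick the broadcast axis from the batch dimension: a right-hand operand whose leading dimension is -1 (the batch dimension) is aligned at axis 0, while fully static shapes keep the default trailing alignment. A sketch of the rule (`infer_axis` is a hypothetical helper name, not part of the patch):

```python
def infer_axis(x_shape, y_shape):
    # Mirrors the patched logic: a batch-first Y (leading dim -1) aligns
    # at axis 0; otherwise axis stays -1, i.e. rank(X) - rank(Y) in the op.
    assert len(x_shape) >= len(y_shape)
    return 0 if y_shape[0] == -1 else -1

# ab: (-1, 2), a: (-1, 1) -> align at the batch axis, as in the new test.
assert infer_axis((-1, 2), (-1, 1)) == 0
# A static (4, 5) Y against a (2, 3, 4, 5) X keeps trailing alignment.
assert infer_axis((2, 3, 4, 5), (4, 5)) == -1
```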
@@ -50,6 +50,16 @@ class TestElementwiseAddOp_scalar(TestElementwiseOp):
         self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
+class TestElementwiseAddOp_scalar2(TestElementwiseOp):
+    def setUp(self):
+        self.op_type = "elementwise_add"
+        self.inputs = {
+            'X': np.random.rand(2, 3, 4).astype(np.float32),
+            'Y': np.random.rand(1, 1).astype(np.float32)
+        }
+        self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
 class TestElementwiseAddOp_Vector(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_add"
@@ -115,6 +125,20 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
         }
+class TestElementwiseAddOp_broadcast_4(TestElementwiseOp):
+    def setUp(self):
+        self.op_type = "elementwise_add"
+        self.inputs = {
+            'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
+            'Y': np.random.rand(2, 1).astype(np.float32)
+        }
+        self.attrs = {'axis': 0}
+        self.outputs = {
+            'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1, 1)
+        }
 class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_add"
......
 # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -23,13 +23,21 @@ class TestMathOpPatches(unittest.TestCase):
     def test_add_scalar(self):
         a = fluid.layers.data(name="a", shape=[1])
         b = a + 10
+        ab = fluid.layers.concat(input=[a, b], axis=1)
+        c = ab + 10
+        d = ab + a
+        # e = a + ab
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         a_np = numpy.random.random(size=[10, 1]).astype('float32')
-        b_np = exe.run(fluid.default_main_program(),
-                       feed={"a": a_np},
-                       fetch_list=[b])
+        b_np, c_np, d_np = exe.run(fluid.default_main_program(),
+                                   feed={"a": a_np},
+                                   fetch_list=[b, c, d])
         self.assertTrue(numpy.allclose(a_np + 10, b_np))
+        ab_np = numpy.concatenate([a_np, b_np], axis=1)
+        self.assertTrue(numpy.allclose(ab_np + 10, c_np))
+        d_expected = ab_np + numpy.concatenate([a_np, a_np], axis=1)
+        self.assertTrue(numpy.allclose(d_expected, d_np))
     @decorators.prog_scope()
     def test_radd_scalar(self):
......