提交 e9b8ebf4 编写于 作者: X xuwei06

Correctly handle variables with a batch dimension for math ops.

When the second argument contains a batch dimension, the axis should be 0.

Also makes elementwise ops more tolerant of tensors with trailing
singular dimensions.
上级 7d56c6d0
......@@ -65,12 +65,17 @@ smaller than or equal to the dimensions of $X$.
There are two cases for this operator:
1. The shape of $Y$ is same with $X$;
2. The shape of $Y$ is a subset of $X$.
2. The shape of $Y$ is a contiguous subsequence of $X$. The trailing dimensions
of size 1 for $Y$ will be ignored for the consideration of subsequence.
For case 2:
$Y$ will be broadcast to match the shape of $X$ and axis should be
set to the index of the start dimension to broadcast $Y$ onto $X$.
If axis is -1, it is treated as axis=rank(X)-rank(Y).
For example
.. code-block:: python
......@@ -79,6 +84,7 @@ For example
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
......
......@@ -61,6 +61,19 @@ inline void get_mid_dims(const framework::DDim& x_dims,
}
}
// Drops every trailing dimension of size 1 from `dims`, modifying it in place.
// If all dimensions are 1 the result has rank 0; if the last dimension is not
// 1 (or `dims` is already empty) `dims` is left untouched.
inline void trim_trailing_singular_dims(framework::DDim& dims) {
  const auto original_rank = dims.size();

  // Walk backwards from the last dimension while it is singular.
  auto trimmed_rank = original_rank;
  while (trimmed_rank != 0 && dims[trimmed_rank - 1] == 1) {
    --trimmed_rank;
  }

  // Nothing to strip: avoid rebuilding the DDim.
  if (trimmed_rank == original_rank) {
    return;
  }

  auto kept_dims = framework::vectorize(dims);
  kept_dims.resize(trimmed_rank);
  dims = framework::make_ddim(kept_dims);
}
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
......@@ -263,44 +276,6 @@ class TransformFunctor {
} \
}
// Forward computation for a binary elementwise operator.
//
// Reads inputs "X" and "Y" and writes output "Out" from `ctx`. When both
// inputs have identical dims the functor's `Run` is used directly. Otherwise
// the "axis" attribute is read (-1 is resolved to rank(X) - rank(Y)),
// `get_mid_dims` fills `pre`, `n` and `post`, and the call is dispatched to
// `RunBroadCast` when `post == 1` or to `RunBroadCast2` otherwise.
//
// Enforces rank(X) >= rank(Y) and, in the broadcast path, 0 <= axis < rank(X).
template <class functor, typename DeviceContext, typename T>
void ElementwiseCompute(const framework::ExecutionContext& ctx) {
  using Tensor = framework::Tensor;

  auto* x = ctx.Input<Tensor>("X");
  auto* y = ctx.Input<Tensor>("Y");
  auto* z = ctx.Output<Tensor>("Out");
  // Allocate the output buffer on the place this kernel runs on.
  z->mutable_data<T>(ctx.GetPlace());

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                    "Rank of first input must >= rank of second input.");

  // Same shape: plain elementwise evaluation, no broadcasting required.
  if (x_dims == y_dims) {
    functor same_shape_op;
    same_shape_op.template Run<DeviceContext, T>(x, y, z, ctx);
    return;
  }

  // Resolve the default axis (-1) so that Y aligns with the tail of X.
  int axis = ctx.Attr<int>("axis");
  if (axis == -1) {
    axis = x_dims.size() - y_dims.size();
  }
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, pre, n, post);

  functor broadcast_op;
  if (post == 1) {
    broadcast_op.template RunBroadCast<DeviceContext, T>(x, y, z, ctx, pre, n);
  } else {
    broadcast_op.template RunBroadCast2<DeviceContext, T>(x, y, z, ctx, pre, n,
                                                          post);
  }
}
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);
......@@ -516,14 +491,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
auto x_dim = x.dims();
auto y_dim = y.dims();
if (y_dim.size() == 1 && y_dim[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dim);
extended_dims.push_back(1);
x_dim = framework::make_ddim(extended_dims);
}
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
trim_trailing_singular_dims(y_dim);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, pre, n, post);
if (post == 1) {
......@@ -591,14 +562,9 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
return;
}
if (y_dims.size() == 1 && y_dims[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dims);
extended_dims.push_back(1);
x_dims = framework::make_ddim(extended_dims);
}
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
trim_trailing_singular_dims(y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
......@@ -633,16 +599,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
return;
}
if (y_dims.size() == 1 && y_dims[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dims);
extended_dims.push_back(1);
x_dims = framework::make_ddim(extended_dims);
}
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
trim_trailing_singular_dims(y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
......
......@@ -14,7 +14,7 @@
import numpy as np
import contextlib
from framework import Program, default_main_program
from framework import Program, default_main_program, Variable
from . import core
__all__ = [
......@@ -281,6 +281,8 @@ class Executor(object):
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
......
......@@ -53,12 +53,22 @@ def monkey_patch_variable():
value = float(value)
tmp_name = unique_tmp_name()
var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
batch_dim = -1
for i, d in enumerate(ref_var.shape):
if d < 0:
batch_dim = i
break
assert batch_dim != -1
ref_var.block.append_op(
type='fill_constant_batch_size_like',
outputs={'Out': [var]},
inputs={'Input': [ref_var]},
attrs={'shape': ref_var.shape,
'value': value})
attrs={
'shape': ref_var.shape,
'value': value,
'input_dim_idx': batch_dim,
'output_dim_idx': batch_dim
})
return var
def astype(self, dtype):
......@@ -118,11 +128,20 @@ def monkey_patch_variable():
tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
axis = -1
if other_var.shape[0] == -1:
axis = 0
assert len(self.shape) >= len(other_var.shape), (
"The rank of the first argument of an binary operator cannot "
"be smaller than the rank of its second argument: %s vs %s" %
(len(self.shape), len(other_var.shape)))
self.block.append_op(
type=op_type,
inputs={'X': [self],
'Y': [other_var]},
outputs={'Out': out})
outputs={'Out': out},
attrs={'axis': axis})
return out
comment = OpProtoHolder.instance().get_op_proto(op_type).comment
......
......@@ -50,6 +50,16 @@ class TestElementwiseAddOp_scalar(TestElementwiseOp):
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
class TestElementwiseAddOp_scalar2(TestElementwiseOp):
    # elementwise_add with X of shape (2, 3, 4) and Y of shape (1, 1).
    # The expected output is the NumPy broadcast sum X + Y.
    def setUp(self):
        self.op_type = "elementwise_add"
        # Draw X first, then Y, so the random stream is consumed in a fixed order.
        x = np.random.rand(2, 3, 4).astype(np.float32)
        y = np.random.rand(1, 1).astype(np.float32)
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': x + y}
class TestElementwiseAddOp_Vector(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
......@@ -115,6 +125,20 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
}
class TestElementwiseAddOp_broadcast_4(TestElementwiseOp):
    # elementwise_add with X of shape (2, 3, 4, 5), Y of shape (2, 1) and axis=0.
    # The expected output reshapes Y to (2, 1, 1, 1) before adding it to X.
    def setUp(self):
        self.op_type = "elementwise_add"
        # Draw X first, then Y, so the random stream is consumed in a fixed order.
        x = np.random.rand(2, 3, 4, 5).astype(np.float32)
        y = np.random.rand(2, 1).astype(np.float32)
        self.inputs = {'X': x, 'Y': y}
        self.attrs = {'axis': 0}
        self.outputs = {'Out': x + y.reshape(2, 1, 1, 1)}
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
......
......@@ -23,13 +23,21 @@ class TestMathOpPatches(unittest.TestCase):
def test_add_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = a + 10
ab = fluid.layers.concat(input=[a, b], axis=1)
c = ab + 10
d = ab + a
# e = a + ab
place = fluid.CPUPlace()
exe = fluid.Executor(place)
a_np = numpy.random.random(size=[10, 1]).astype('float32')
b_np = exe.run(fluid.default_main_program(),
b_np, c_np, d_np = exe.run(fluid.default_main_program(),
feed={"a": a_np},
fetch_list=[b])
fetch_list=[b, c, d])
self.assertTrue(numpy.allclose(a_np + 10, b_np))
ab_np = numpy.concatenate([a_np, b_np], axis=1)
self.assertTrue(numpy.allclose(ab_np + 10, c_np))
d_expected = ab_np + numpy.concatenate([a_np, a_np], axis=1)
self.assertTrue(numpy.allclose(d_expected, d_np))
@decorators.prog_scope()
def test_radd_scalar(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册