Commit e9b8ebf4 authored by xuwei06

Correctly handle variables with a batch dimension for math ops.

When the second argument contains a batch dimension, the axis should be 0.

Also make elementwise ops more tolerant when handling tensors with trailing
singular (size-1) dimensions.
Parent 7d56c6d0
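As a rough illustration of the behavior this commit fixes, here is a minimal sketch using the Fluid Python API, mirroring the new unit test at the bottom of this diff (variable names are illustrative):

```python
import numpy
import paddle.fluid as fluid

# Both operands carry the batch dimension (-1), so the patched operators
# must emit elementwise_add with axis=0 instead of the default axis.
a = fluid.layers.data(name="a", shape=[1])        # runtime shape: (batch, 1)
b = a + 10                                        # scalar rhs -> fill_constant_batch_size_like
ab = fluid.layers.concat(input=[a, b], axis=1)    # runtime shape: (batch, 2)
d = ab + a                                        # rhs has a batch dimension -> axis=0

exe = fluid.Executor(fluid.CPUPlace())
a_np = numpy.random.random(size=[10, 1]).astype('float32')
d_np, = exe.run(fluid.default_main_program(),
                feed={"a": a_np},
                fetch_list=[d])
```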
...@@ -65,12 +65,17 @@ smaller than or equal to the dimensions of $X$.
There are two cases for this operator:
1. The shape of $Y$ is the same as $X$;
2. The shape of $Y$ is a contiguous subsequence of $X$. The trailing dimensions
of size 1 for $Y$ will be ignored for the consideration of subsequence.
For case 2:
$Y$ will be broadcast to match the shape of $X$, and axis should be
set to the index of the start dimension to broadcast $Y$ onto $X$.
If axis is -1, it is treated as axis=rank(X)-rank(Y).
For example
.. code-block:: python
...@@ -79,6 +84,7 @@ For example
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
......
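To make the axis semantics in the operator comment above concrete, here is a small numpy sketch (illustration only); reshape plus numpy broadcasting stands in for what the elementwise kernels do internally:

```python
import numpy as np

X = np.random.rand(2, 3, 4, 5)

# shape(Y) = (3, 4), axis=1: Y is aligned with dims 1..2 of X.
Y = np.random.rand(3, 4)
out = X + Y.reshape(1, 3, 4, 1)

# shape(Y) = (2, 1), axis=0: the trailing size-1 dim of Y is ignored,
# so Y effectively has shape (2,) and is aligned with dim 0 of X.
Y2 = np.random.rand(2, 1)
out2 = X + Y2.reshape(2, 1, 1, 1)
```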
...@@ -61,6 +61,19 @@ inline void get_mid_dims(const framework::DDim& x_dims,
}
}
inline void trim_trailing_singular_dims(framework::DDim& dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
if (actual_dims_size != dims.size()) {
auto actual_dims = framework::vectorize(dims);
actual_dims.resize(actual_dims_size);
dims = framework::make_ddim(actual_dims);
}
}
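For reference, the trimming performed by trim_trailing_singular_dims can be sketched in Python as follows (an illustrative equivalent of the C++ helper above, not code from this diff):

```python
def trim_trailing_singular_dims(dims):
    """Drop trailing size-1 dimensions, e.g. (2, 1, 1) -> (2,)."""
    n = len(dims)
    while n > 0 and dims[n - 1] == 1:
        n -= 1
    return tuple(dims[:n])

assert trim_trailing_singular_dims((2, 1, 1)) == (2,)
assert trim_trailing_singular_dims((1, 1)) == ()      # a "scalar" Y trims to rank 0
assert trim_trailing_singular_dims((3, 4)) == (3, 4)  # nothing to trim
```

When Y trims down to rank 0, the callers below reset axis to rank(X) so that get_mid_dims treats Y as a pure scalar; this replaces the old special case that appended an extra size-1 dimension to X when Y was a scalar.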
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
...@@ -263,44 +276,6 @@ class TransformFunctor {
} \
}
template <class functor, typename DeviceContext, typename T>
void ElementwiseCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims) {
functor f;
f.template Run<DeviceContext, T>(x, y, z, ctx);
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor f;
f.template RunBroadCast<DeviceContext, T>(x, y, z, ctx, pre, n);
return;
} else {
functor f;
f.template RunBroadCast2<DeviceContext, T>(x, y, z, ctx, pre, n, post);
return;
}
}
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);
...@@ -516,14 +491,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
auto x_dim = x.dims();
auto y_dim = y.dims();
if (y_dim.size() == 1 && y_dim[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dim);
extended_dims.push_back(1);
x_dim = framework::make_ddim(extended_dims);
}
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
trim_trailing_singular_dims(y_dim);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, pre, n, post);
if (post == 1) {
...@@ -591,14 +562,9 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
return;
}
if (y_dims.size() == 1 && y_dims[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dims);
extended_dims.push_back(1);
x_dims = framework::make_ddim(extended_dims);
}
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
trim_trailing_singular_dims(y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
...@@ -633,16 +599,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
return;
}
if (y_dims.size() == 1 && y_dims[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dims);
extended_dims.push_back(1);
x_dims = framework::make_ddim(extended_dims);
}
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
trim_trailing_singular_dims(y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
......
...@@ -14,7 +14,7 @@
import numpy as np
import contextlib
from framework import Program, default_main_program, Variable
from . import core
__all__ = [
...@@ -281,6 +281,8 @@ class Executor(object):
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -53,12 +53,22 @@ def monkey_patch_variable():
value = float(value)
tmp_name = unique_tmp_name()
var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
batch_dim = -1
for i, d in enumerate(ref_var.shape):
if d < 0:
batch_dim = i
break
assert batch_dim != -1
ref_var.block.append_op(
type='fill_constant_batch_size_like',
outputs={'Out': [var]},
inputs={'Input': [ref_var]},
attrs={
'shape': ref_var.shape,
'value': value,
'input_dim_idx': batch_dim,
'output_dim_idx': batch_dim
})
return var
def astype(self, dtype):
...@@ -118,11 +128,20 @@ def monkey_patch_variable():
tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
axis = -1
if other_var.shape[0] == -1:
axis = 0
assert len(self.shape) >= len(other_var.shape), (
"The rank of the first argument of an binary operator cannot "
"be smaller than the rank of its second argument: %s vs %s" %
(len(self.shape), len(other_var.shape)))
self.block.append_op(
type=op_type,
inputs={'X': [self],
'Y': [other_var]},
outputs={'Out': out},
attrs={'axis': axis})
return out
comment = OpProtoHolder.instance().get_op_proto(op_type).comment
...@@ -131,7 +150,7 @@ def monkey_patch_variable():
{0}
Args:
self(Variable): left hand variable
other_var(Variable|float|int): right hand variable
Returns:
Variable
......
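Putting the two math_op_patch changes above together, the axis handed to the elementwise op can be summarized by this simplified sketch (plain shape tuples instead of Variables; an illustration, not code from the diff):

```python
def infer_axis(lhs_shape, rhs_shape):
    # Mirror of the patched binary-operator logic: use axis=0 when the
    # right operand carries the batch dimension (-1); otherwise keep the
    # default axis=-1, which the op treats as rank(X) - rank(Y).
    assert len(lhs_shape) >= len(rhs_shape)
    return 0 if rhs_shape[0] == -1 else -1

assert infer_axis((-1, 2), (-1, 1)) == 0   # rhs has a batch dimension
assert infer_axis((-1, 2), (2,)) == -1     # rhs without a batch dimension
```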
...@@ -50,6 +50,16 @@ class TestElementwiseAddOp_scalar(TestElementwiseOp):
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
class TestElementwiseAddOp_scalar2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1, 1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
class TestElementwiseAddOp_Vector(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
...@@ -115,6 +125,20 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
}
class TestElementwiseAddOp_broadcast_4(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
'Y': np.random.rand(2, 1).astype(np.float32)
}
self.attrs = {'axis': 0}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1, 1)
}
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -23,13 +23,21 @@ class TestMathOpPatches(unittest.TestCase):
def test_add_scalar(self):
a = fluid.layers.data(name="a", shape=[1])
b = a + 10
ab = fluid.layers.concat(input=[a, b], axis=1)
c = ab + 10
d = ab + a
# e = a + ab
place = fluid.CPUPlace()
exe = fluid.Executor(place)
a_np = numpy.random.random(size=[10, 1]).astype('float32')
b_np, c_np, d_np = exe.run(fluid.default_main_program(),
feed={"a": a_np},
fetch_list=[b, c, d])
self.assertTrue(numpy.allclose(a_np + 10, b_np))
ab_np = numpy.concatenate([a_np, b_np], axis=1)
self.assertTrue(numpy.allclose(ab_np + 10, c_np))
d_expected = ab_np + numpy.concatenate([a_np, a_np], axis=1)
self.assertTrue(numpy.allclose(d_expected, d_np))
@decorators.prog_scope()
def test_radd_scalar(self):
......