diff --git a/oneflow/core/common/shape_vec.h b/oneflow/core/common/shape_vec.h index 54df0e8d5059e2731aa0255be28e31ccbbf61ac7..03ddd98f9cb598ff412d95d7199cf9d7459474b2 100644 --- a/oneflow/core/common/shape_vec.h +++ b/oneflow/core/common/shape_vec.h @@ -5,7 +5,7 @@ namespace oneflow { -#define DISABLE_FIXED_SHAPE_VEC +//#define DISABLE_FIXED_SHAPE_VEC #if defined(DISABLE_FIXED_SHAPE_VEC) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 6ad38a08be60590fc7cf222b2866efb9287222d1..a5a1a56db7364a663fb0c23903f8b21a72de0819 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -129,6 +129,15 @@ bool IsInplaceAllowed( const RegstDesc& regst_desc = *exec_node.RegstDesc4BnInOp(bn); if (regst_desc.NumOfLbi() != 1) { return false; } } + const Shape* first_shape = nullptr; + for (const auto& bn : bns) { + const BlobDesc& blob_desc = *exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc(); + if (first_shape == nullptr) { + first_shape = &blob_desc.shape(); + } else { + if (*first_shape != blob_desc.shape()) { return false; } + } + } return true; } diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h index ccae19c22a92beae2ef1f881c664317ea831bdc2..f3b278c5d8e1ab9ca396c62a5536686f3818522c 100644 --- a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h @@ -6,6 +6,11 @@ namespace oneflow { +void SimplifyBroadcastBinaryShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y, + DimVector* simplified_b); +void SimplifyBroadcastBinaryShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b, + DimVector* simplified_y, DimVector* simplified_a, + DimVector* simplified_b); template class binary_func, typename Enable = void> struct NdarrayApplyBroadcastBinary; @@ -21,7 +26,14 @@ struct NdarrayApplyBroadcastBinary< using BroadcastBinary = NdarrayApplyBroadcastBinaryCoreWrapper; CheckBroadcastable(y, a, b); - return BroadcastBinary::Apply(ctx, y, a, b); + DimVector simplified_y_dim; + DimVector simplified_a_dim; + DimVector simplified_b_dim; + SimplifyBroadcastBinaryShapes(y.shape(), a.shape(), b.shape(), &simplified_y_dim, + &simplified_a_dim, &simplified_b_dim); + return BroadcastBinary::Apply(ctx, XpuVarNdarray(Shape(simplified_y_dim), y.ptr()), + XpuVarNdarray(Shape(simplified_a_dim), a.ptr()), + XpuVarNdarray(Shape(simplified_b_dim), b.ptr())); } static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray& y, @@ -29,7 +41,11 @@ struct NdarrayApplyBroadcastBinary< using BroadcastBinary = NdarrayApplyBroadcastBinaryCoreWrapper; CheckBroadcastable(y, reinterpret_cast&>(y), x); - return BroadcastBinary::InplaceApply(ctx, y, x); + DimVector simplified_y_dim; + DimVector simplified_x_dim; + SimplifyBroadcastBinaryShapes(y.shape(), x.shape(), &simplified_y_dim, &simplified_x_dim); + return BroadcastBinary::InplaceApply(ctx, XpuVarNdarray(Shape(simplified_y_dim), y.ptr()), + XpuVarNdarray(Shape(simplified_x_dim), x.ptr())); } private: diff --git a/oneflow/core/ndarray/ndarray_reduce.h b/oneflow/core/ndarray/ndarray_reduce.h index fe44fc6096db2ca14a67c554039562ad66e77706..e3f60f33c5575416d7a0dde83730281145b14ed5 100644 --- a/oneflow/core/ndarray/ndarray_reduce.h +++ b/oneflow/core/ndarray/ndarray_reduce.h @@ -19,8 +19,8 @@ struct NdarrayReduce< final { static void Reduce(DeviceCtx* ctx, const XpuVarNdarray& origin_y, const XpuVarNdarray& origin_x, const XpuVarNdarray& tmp_storage) { - std::vector simplified_x_dim; - std::vector simplified_y_dim; + DimVector simplified_x_dim; + DimVector simplified_y_dim; TrySimplifyDims(origin_x.shape(), origin_y.shape(), &simplified_x_dim, &simplified_y_dim); XpuVarNdarray y(Shape(simplified_y_dim), origin_y.ptr()); XpuVarNdarray x(Shape(simplified_x_dim), origin_x.ptr()); @@ -43,9 +43,8 @@ struct NdarrayReduce< } } - static void TrySimplifyDims(const XpuShape& x, const XpuShape& y, - std::vector* simplified_x, - std::vector* simplified_y) { + static void TrySimplifyDims(const XpuShape& x, const XpuShape& y, DimVector* simplified_x, + DimVector* simplified_y) { CHECK_EQ(y.NumAxes(), x.NumAxes()); CHECK(y.At(0) == 1 || y.At(0) == x.At(0)); CHECK(simplified_x->empty()); diff --git a/oneflow/core/ndarray/ndarray_util.h b/oneflow/core/ndarray/ndarray_util.h index a124b65ce6de22cae104e7089c025d963a3df321..9e6d9a9949e62e7c848115d66ae39e732760d743 100644 --- a/oneflow/core/ndarray/ndarray_util.h +++ b/oneflow/core/ndarray/ndarray_util.h @@ -158,7 +158,12 @@ struct NdarrayUtil final { static void ApplyBinary( DeviceCtx* ctx, const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { - return NdarrayApplyBinary::Apply(ctx, y, a, b); + if (a.host_ptr() == y.host_ptr()) { + CHECK(a.host_shape() == y.host_shape()); + return NdarrayApplyBinary::InplaceApply(ctx, y, b); + } else { + return NdarrayApplyBinary::Apply(ctx, y, a, b); + } } template class unary_func> diff --git a/oneflow/core/operator/broadcast_binary_op.cpp b/oneflow/core/operator/broadcast_binary_op.cpp index f8fe248315ddc3a1be4b95bf28bcf4c5be786d4d..3f2517e67e284fc2ae0236896ae7cd738ae75e92 100644 --- a/oneflow/core/operator/broadcast_binary_op.cpp +++ b/oneflow/core/operator/broadcast_binary_op.cpp @@ -15,7 +15,7 @@ bool IsScalarBlob(const BlobDesc* blob) { void BroadcastBinaryOp::InitFromOpConf() { EnrollInputBn("a"); EnrollInputBn("b"); - EnrollOutputBn("out"); + EnrollOutputBn("out")->set_mutable_inplace_ibn("a"); } Maybe BroadcastBinaryOp::InferBlobDescs( diff --git a/oneflow/python/test/ops/test_add.py b/oneflow/python/test/ops/test_add.py index a84de2f201ac63fcef7a02228e079aaac2ecd11a..df7b2665e01fc6d6084ffeec4b6ab1ca5312f796 100644 --- a/oneflow/python/test/ops/test_add.py +++ b/oneflow/python/test/ops/test_add.py @@ -1,17 +1,28 @@ import oneflow as flow import numpy as np -@flow.function -def AddJob(a=flow.input_blob_def((5, 2)), b=flow.input_blob_def((5, 2))): - flow.config.default_data_type(flow.float) - a + b - return a + b + b +func_config = flow.FunctionConfig() +func_config.default_data_type(flow.float) +def test_naive(test_case): + @flow.function(func_config) + def AddJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((5, 2))): + return a + b + b -x = np.random.rand(5, 2).astype(np.float32) -y = np.random.rand(5, 2).astype(np.float32) -z = None + x = np.random.rand(5, 2).astype(np.float32) + y = np.random.rand(5, 2).astype(np.float32) + z = None + z = AddJob(x, y).get().ndarray() + test_case.assertTrue(np.array_equal(z, x + y + y)) -z = AddJob(x, y).get() +def test_broadcast(test_case): + flow.config.enable_debug_mode(True) + @flow.function(func_config) + def AddJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((1, 2))): + return a + b -print (np.array_equal(z, x + y + y)) + x = np.random.rand(5, 2).astype(np.float32) + y = np.random.rand(1, 2).astype(np.float32) + z = None + z = AddJob(x, y).get().ndarray() + test_case.assertTrue(np.array_equal(z, x + y))