Commit 2d8b9a3b authored by lixinqi

1) FixedVector as DimVector; 2) SimplifyBroadcastBinaryShapes

Parent 073d104d
@@ -5,7 +5,7 @@
 namespace oneflow {
-#define DISABLE_FIXED_SHAPE_VEC
+//#define DISABLE_FIXED_SHAPE_VEC
 #if defined(DISABLE_FIXED_SHAPE_VEC)
...
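The commit title's "FixedVector as DimVector" suggests that commenting out DISABLE_FIXED_SHAPE_VEC switches DimVector from a heap-allocated std::vector to a fixed-capacity inline vector. The actual definitions are outside this diff, so the following is only a plausible sketch; the fixed_vector class and the SHAPE_MAX_AXIS_SIZE constant here are assumed names, not confirmed OneFlow code.

#include <array>
#include <cstddef>
#include <cstdint>

// Minimal sketch of a fixed-capacity vector: dim data lives inline, so
// building a DimVector on a hot kernel-launch path never touches the heap.
template<typename T, size_t max_size>
class fixed_vector {
 public:
  void push_back(const T& v) { data_[size_++] = v; }
  T& back() { return data_[size_ - 1]; }
  size_t size() const { return size_; }
  const T* begin() const { return data_.data(); }
  const T* end() const { return data_.data() + size_; }

 private:
  std::array<T, max_size> data_{};
  size_t size_ = 0;
};

// Assumed capacity; tensor shapes rarely exceed a handful of axes.
constexpr size_t SHAPE_MAX_AXIS_SIZE = 20;
using DimVector = fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE>;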
@@ -129,6 +129,15 @@ bool IsInplaceAllowed(
     const RegstDesc& regst_desc = *exec_node.RegstDesc4BnInOp(bn);
     if (regst_desc.NumOfLbi() != 1) { return false; }
   }
+  const Shape* first_shape = nullptr;
+  for (const auto& bn : bns) {
+    const BlobDesc& blob_desc = *exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
+    if (first_shape == nullptr) {
+      first_shape = &blob_desc.shape();
+    } else {
+      if (*first_shape != blob_desc.shape()) { return false; }
+    }
+  }
   return true;
 }
...
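The new loop tightens the inplace veto: besides each regst holding a single lbi, every bound blob must now share one shape. Combined with the set_mutable_inplace_ibn("a") request that BroadcastBinaryOp makes later in this commit, inplace reuse only fires in the degenerate broadcast case where all shapes already match. A standalone restatement of the guard (illustrative names, not OneFlow code):

#include <cstdint>
#include <vector>

// Inplace is rejected as soon as any blob's shape differs from the first,
// so a genuine broadcast like (5,2) + (1,2) -> (5,2) can never write "out"
// over an input buffer.
bool AllShapesEqual(const std::vector<std::vector<int64_t>>& shapes) {
  if (shapes.empty()) { return true; }
  for (const auto& shape : shapes) {
    if (shape != shapes.front()) { return false; }
  }
  return true;
}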
@@ -6,6 +6,11 @@
 namespace oneflow {
+void SimplifyBroadcastBinaryShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y,
+                                   DimVector* simplified_b);
+void SimplifyBroadcastBinaryShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b,
+                                   DimVector* simplified_y, DimVector* simplified_a,
+                                   DimVector* simplified_b);
 template<DeviceType device_type, typename T, int NDIMS, template<typename> class binary_func,
          typename Enable = void>
 struct NdarrayApplyBroadcastBinary;
@@ -21,7 +26,14 @@ struct NdarrayApplyBroadcastBinary<
     using BroadcastBinary =
         NdarrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>;
     CheckBroadcastable(y, a, b);
-    return BroadcastBinary::Apply(ctx, y, a, b);
+    DimVector simplified_y_dim;
+    DimVector simplified_a_dim;
+    DimVector simplified_b_dim;
+    SimplifyBroadcastBinaryShapes(y.shape(), a.shape(), b.shape(), &simplified_y_dim,
+                                  &simplified_a_dim, &simplified_b_dim);
+    return BroadcastBinary::Apply(ctx, XpuVarNdarray<T>(Shape(simplified_y_dim), y.ptr()),
+                                  XpuVarNdarray<const T>(Shape(simplified_a_dim), a.ptr()),
+                                  XpuVarNdarray<const T>(Shape(simplified_b_dim), b.ptr()));
   }
   static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
@@ -29,7 +41,11 @@ struct NdarrayApplyBroadcastBinary<
     using BroadcastBinary =
         NdarrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>;
     CheckBroadcastable(y, reinterpret_cast<const XpuVarNdarray<const T>&>(y), x);
-    return BroadcastBinary::InplaceApply(ctx, y, x);
+    DimVector simplified_y_dim;
+    DimVector simplified_x_dim;
+    SimplifyBroadcastBinaryShapes(y.shape(), x.shape(), &simplified_y_dim, &simplified_x_dim);
+    return BroadcastBinary::InplaceApply(ctx, XpuVarNdarray<T>(Shape(simplified_y_dim), y.ptr()),
+                                         XpuVarNdarray<const T>(Shape(simplified_x_dim), x.ptr()));
   }
  private:
...
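The diff only declares SimplifyBroadcastBinaryShapes; its definition is elsewhere in the commit. The likely idea, sketched below as an assumption rather than the actual implementation, is to fuse neighboring axes whose broadcast pattern is identical, so the kernel iterates over fewer, larger dimensions: y=(2,3,4) with b=(2,3,1) collapses to y=(6,4), b=(6,1). The sketch uses plain std::vector<int64_t> in place of XpuShape/DimVector to stay self-contained.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hedged sketch of the two-operand overload. Adjacent axes are fused
// whenever they share the same broadcast behavior (b broadcast along the
// axis, or not); a flip in the pattern starts a new simplified axis.
void SimplifyBroadcastBinaryShapes(const std::vector<int64_t>& y,
                                   const std::vector<int64_t>& b,
                                   std::vector<int64_t>* simplified_y,
                                   std::vector<int64_t>* simplified_b) {
  simplified_y->push_back(y.at(0));
  simplified_b->push_back(b.at(0));
  for (size_t i = 1; i < y.size(); ++i) {
    const bool prev_bcast = simplified_b->back() == 1 && simplified_y->back() != 1;
    const bool cur_bcast = b.at(i) == 1 && y.at(i) != 1;
    if (prev_bcast == cur_bcast) {
      simplified_y->back() *= y.at(i);   // same pattern: fuse into previous axis
      simplified_b->back() *= b.at(i);
    } else {
      simplified_y->push_back(y.at(i));  // pattern flips: start a new axis
      simplified_b->push_back(b.at(i));
    }
  }
}

Fewer axes means smaller NDIMS instantiations cover more cases, and a fully matching pair of shapes degenerates to a single flat loop.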
@@ -19,8 +19,8 @@ struct NdarrayReduce<
     final {
   static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& origin_y,
                      const XpuVarNdarray<const T>& origin_x, const XpuVarNdarray<T>& tmp_storage) {
-    std::vector<int64_t> simplified_x_dim;
-    std::vector<int64_t> simplified_y_dim;
+    DimVector simplified_x_dim;
+    DimVector simplified_y_dim;
     TrySimplifyDims(origin_x.shape(), origin_y.shape(), &simplified_x_dim, &simplified_y_dim);
     XpuVarNdarray<T> y(Shape(simplified_y_dim), origin_y.ptr());
     XpuVarNdarray<const T> x(Shape(simplified_x_dim), origin_x.ptr());
@@ -43,9 +43,8 @@ struct NdarrayReduce<
     }
   }
-  static void TrySimplifyDims(const XpuShape& x, const XpuShape& y,
-                              std::vector<int64_t>* simplified_x,
-                              std::vector<int64_t>* simplified_y) {
+  static void TrySimplifyDims(const XpuShape& x, const XpuShape& y, DimVector* simplified_x,
+                              DimVector* simplified_y) {
     CHECK_EQ(y.NumAxes(), x.NumAxes());
     CHECK(y.At(0) == 1 || y.At(0) == x.At(0));
     CHECK(simplified_x->empty());
...
@@ -158,7 +158,12 @@ struct NdarrayUtil final {
   static void ApplyBinary(
       DeviceCtx* ctx, const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,
       const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
-    return NdarrayApplyBinary<device_type, T, binary_func>::Apply(ctx, y, a, b);
+    if (a.host_ptr() == y.host_ptr()) {
+      CHECK(a.host_shape() == y.host_shape());
+      return NdarrayApplyBinary<device_type, T, binary_func>::InplaceApply(ctx, y, b);
+    } else {
+      return NdarrayApplyBinary<device_type, T, binary_func>::Apply(ctx, y, a, b);
+    }
   }
   template<template<typename> class unary_func>
...
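ApplyBinary now detects aliasing itself: when operand a points at the same buffer as output y, it dispatches to the inplace kernel and insists the shapes match, so callers no longer have to pick the inplace variant by hand. The decision in isolation (an illustrative restatement, not OneFlow code):

#include <cassert>

// Alias -> inplace kernel; otherwise the ordinary out-of-place path.
enum class Route { kApply, kInplaceApply };

Route PickRoute(const void* y_ptr, const void* a_ptr, bool shapes_match) {
  if (a_ptr == y_ptr) {
    assert(shapes_match);  // mirrors CHECK(a.host_shape() == y.host_shape())
    return Route::kInplaceApply;
  }
  return Route::kApply;
}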
@@ -15,7 +15,7 @@ bool IsScalarBlob(const BlobDesc* blob) {
 void BroadcastBinaryOp::InitFromOpConf() {
   EnrollInputBn("a");
   EnrollInputBn("b");
-  EnrollOutputBn("out");
+  EnrollOutputBn("out")->set_mutable_inplace_ibn("a");
 }
 Maybe<void> BroadcastBinaryOp::InferBlobDescs(
...
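With set_mutable_inplace_ibn("a"), every broadcast binary op now advertises that "out" may overwrite input "a"; the IsInplaceAllowed guard above then grants the request only when no broadcasting actually occurs. A compact runnable illustration of that interplay, under assumed semantics (the Python tests below exercise both outcomes):

#include <cstdint>
#include <cstdio>
#include <vector>

using Dims = std::vector<int64_t>;

// The op *requests* inplace; the planner *grants* it only when every blob
// shape matches (the shape-equality veto added to IsInplaceAllowed).
bool InplaceGranted(const Dims& a, const Dims& b, const Dims& out) {
  return a == b && b == out;
}

int main() {
  std::printf("%d\n", InplaceGranted({5, 2}, {5, 2}, {5, 2}));  // 1: out reuses a
  std::printf("%d\n", InplaceGranted({5, 2}, {1, 2}, {5, 2}));  // 0: broadcast
  return 0;
}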
 import oneflow as flow
 import numpy as np
+
+func_config = flow.FunctionConfig()
+func_config.default_data_type(flow.float)
 
-@flow.function
-def AddJob(a=flow.input_blob_def((5, 2)), b=flow.input_blob_def((5, 2))):
-  flow.config.default_data_type(flow.float)
-  a + b
-  return a + b + b
+def test_naive(test_case):
+  @flow.function(func_config)
+  def AddJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((5, 2))):
+    return a + b + b
 
-x = np.random.rand(5, 2).astype(np.float32)
-y = np.random.rand(5, 2).astype(np.float32)
-z = None
-z = AddJob(x, y).get()
-print (np.array_equal(z, x + y + y))
+  x = np.random.rand(5, 2).astype(np.float32)
+  y = np.random.rand(5, 2).astype(np.float32)
+  z = None
+  z = AddJob(x, y).get().ndarray()
+  test_case.assertTrue(np.array_equal(z, x + y + y))
+
+def test_broadcast(test_case):
+  flow.config.enable_debug_mode(True)
+
+  @flow.function(func_config)
+  def AddJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((1, 2))):
+    return a + b
+
+  x = np.random.rand(5, 2).astype(np.float32)
+  y = np.random.rand(1, 2).astype(np.float32)
+  z = None
+  z = AddJob(x, y).get().ndarray()
+  test_case.assertTrue(np.array_equal(z, x + y))