Commit 2d8b9a3b authored by lixinqi

1) FixedVector as DimVector; 2) SimplifyBroadcastBinaryShapes

Parent 073d104d
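This commit makes two changes. First, DimVector is backed by the fixed-capacity FixedVector instead of std::vector<int64_t> (the DISABLE_FIXED_SHAPE_VEC escape hatch below is commented out), avoiding heap allocation on shape-handling hot paths. Second, SimplifyBroadcastBinaryShapes fuses adjacent axes that share the same broadcast pattern before a broadcast binary kernel is dispatched, so a high-rank broadcast can run through a lower-NDIMS kernel instantiation. A minimal Python sketch of the fusion idea (an assumption about the algorithm, not the exact C++ in this diff):

    def simplify_broadcast_shapes(y, b):
        # Fuse adjacent axes of output shape `y` and input shape `b` that share
        # the same broadcast pattern; each b[i] is assumed to be 1 or y[i].
        sy, sb = [y[0]], [b[0]]
        for yd, bd in zip(y[1:], b[1:]):
            if (sb[-1] != sy[-1]) == (bd != yd):  # same pattern as previous axis
                sy[-1] *= yd                      # fuse into the previous axis
                sb[-1] *= bd
            else:
                sy.append(yd)                     # pattern flips: start a new axis
                sb.append(bd)
        return sy, sb

    assert simplify_broadcast_shapes((5, 2, 3), (1, 2, 3)) == ([5, 6], [1, 6])
    assert simplify_broadcast_shapes((5, 2, 3), (5, 2, 3)) == ([30], [30])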
......@@ -5,7 +5,7 @@
namespace oneflow {
-#define DISABLE_FIXED_SHAPE_VEC
+//#define DISABLE_FIXED_SHAPE_VEC
#if defined(DISABLE_FIXED_SHAPE_VEC)
......
......@@ -129,6 +129,15 @@ bool IsInplaceAllowed(
const RegstDesc& regst_desc = *exec_node.RegstDesc4BnInOp(bn);
if (regst_desc.NumOfLbi() != 1) { return false; }
}
+const Shape* first_shape = nullptr;
+for (const auto& bn : bns) {
+const BlobDesc& blob_desc = *exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
+if (first_shape == nullptr) {
+first_shape = &blob_desc.shape();
+} else {
+if (*first_shape != blob_desc.shape()) { return false; }
+}
+}
return true;
}
......
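The added loop makes IsInplaceAllowed require that every blob bound to the node has one and the same shape; if any shape differs, a broadcast is involved and reusing an input buffer for the output would be unsafe. A toy restatement of the check (hypothetical helper, not OneFlow API):

    def all_same_shape(shapes):
        # Inplace is only a candidate when the output and all inputs match
        # element-for-element, i.e. no blob is broadcast.
        return all(s == shapes[0] for s in shapes)

    assert all_same_shape([(5, 2), (5, 2), (5, 2)])      # plain elementwise add
    assert not all_same_shape([(5, 2), (1, 2), (5, 2)])  # input `b` is broadcast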
......@@ -6,6 +6,11 @@
namespace oneflow {
+void SimplifyBroadcastBinaryShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y,
+                                   DimVector* simplified_b);
+void SimplifyBroadcastBinaryShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b,
+                                   DimVector* simplified_y, DimVector* simplified_a,
+                                   DimVector* simplified_b);
template<DeviceType device_type, typename T, int NDIMS, template<typename> class binary_func,
typename Enable = void>
struct NdarrayApplyBroadcastBinary;
......@@ -21,7 +26,14 @@ struct NdarrayApplyBroadcastBinary<
using BroadcastBinary =
NdarrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>;
CheckBroadcastable(y, a, b);
-return BroadcastBinary::Apply(ctx, y, a, b);
+DimVector simplified_y_dim;
+DimVector simplified_a_dim;
+DimVector simplified_b_dim;
+SimplifyBroadcastBinaryShapes(y.shape(), a.shape(), b.shape(), &simplified_y_dim,
+                              &simplified_a_dim, &simplified_b_dim);
+return BroadcastBinary::Apply(ctx, XpuVarNdarray<T>(Shape(simplified_y_dim), y.ptr()),
+                              XpuVarNdarray<const T>(Shape(simplified_a_dim), a.ptr()),
+                              XpuVarNdarray<const T>(Shape(simplified_b_dim), b.ptr()));
}
static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
......@@ -29,7 +41,11 @@ struct NdarrayApplyBroadcastBinary<
using BroadcastBinary =
NdarrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>;
CheckBroadcastable(y, reinterpret_cast<const XpuVarNdarray<const T>&>(y), x);
-return BroadcastBinary::InplaceApply(ctx, y, x);
+DimVector simplified_y_dim;
+DimVector simplified_x_dim;
+SimplifyBroadcastBinaryShapes(y.shape(), x.shape(), &simplified_y_dim, &simplified_x_dim);
+return BroadcastBinary::InplaceApply(ctx, XpuVarNdarray<T>(Shape(simplified_y_dim), y.ptr()),
+                                     XpuVarNdarray<const T>(Shape(simplified_x_dim), x.ptr()));
}
private:
......
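Rewrapping the original pointers with the simplified shapes is safe because fusing adjacent axes is a metadata-only change for contiguous buffers: the element order is untouched, so the kernel computes the same result over fewer dimensions. A pure numpy illustration of that equivalence (not OneFlow API):

    import numpy as np

    a = np.random.rand(5, 2, 3).astype(np.float32)
    b = np.random.rand(1, 2, 3).astype(np.float32)
    full = a + b                                                  # 3-axis broadcast
    fused = (a.reshape(5, 6) + b.reshape(1, 6)).reshape(5, 2, 3)  # 2-axis broadcast
    assert np.array_equal(full, fused)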
......@@ -19,8 +19,8 @@ struct NdarrayReduce<
final {
static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& origin_y,
const XpuVarNdarray<const T>& origin_x, const XpuVarNdarray<T>& tmp_storage) {
-std::vector<int64_t> simplified_x_dim;
-std::vector<int64_t> simplified_y_dim;
+DimVector simplified_x_dim;
+DimVector simplified_y_dim;
TrySimplifyDims(origin_x.shape(), origin_y.shape(), &simplified_x_dim, &simplified_y_dim);
XpuVarNdarray<T> y(Shape(simplified_y_dim), origin_y.ptr());
XpuVarNdarray<const T> x(Shape(simplified_x_dim), origin_x.ptr());
......@@ -43,9 +43,8 @@ struct NdarrayReduce<
}
}
-static void TrySimplifyDims(const XpuShape& x, const XpuShape& y,
-                            std::vector<int64_t>* simplified_x,
-                            std::vector<int64_t>* simplified_y) {
+static void TrySimplifyDims(const XpuShape& x, const XpuShape& y, DimVector* simplified_x,
+                            DimVector* simplified_y) {
CHECK_EQ(y.NumAxes(), x.NumAxes());
CHECK(y.At(0) == 1 || y.At(0) == x.At(0));
CHECK(simplified_x->empty());
......
......@@ -158,7 +158,12 @@ struct NdarrayUtil final {
static void ApplyBinary(
DeviceCtx* ctx, const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,
const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
-return NdarrayApplyBinary<device_type, T, binary_func>::Apply(ctx, y, a, b);
+if (a.host_ptr() == y.host_ptr()) {
+CHECK(a.host_shape() == y.host_shape());
+return NdarrayApplyBinary<device_type, T, binary_func>::InplaceApply(ctx, y, b);
+} else {
+return NdarrayApplyBinary<device_type, T, binary_func>::Apply(ctx, y, a, b);
+}
}
template<template<typename> class unary_func>
......
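ApplyBinary now dispatches on buffer aliasing: when input `a` shares the output's pointer, it takes the inplace path (after checking the shapes match) instead of the out-of-place Apply. A numpy sketch of the same dispatch (illustration only, not the C++ API):

    import numpy as np

    def apply_binary_add(y, a, b):
        if y is a:                  # host_ptr equality in the C++ code
            y += b                  # InplaceApply: accumulate into y's own buffer
        else:
            np.add(a, b, out=y)     # out-of-place Apply

    y = np.zeros((5, 2), dtype=np.float32)
    a = np.random.rand(5, 2).astype(np.float32)
    b = np.random.rand(5, 2).astype(np.float32)
    apply_binary_add(y, a, b)   # writes a + b into y
    apply_binary_add(a, a, b)   # aliased output: updates a in place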
......@@ -15,7 +15,7 @@ bool IsScalarBlob(const BlobDesc* blob) {
void BroadcastBinaryOp::InitFromOpConf() {
EnrollInputBn("a");
EnrollInputBn("b");
-EnrollOutputBn("out");
+EnrollOutputBn("out")->set_mutable_inplace_ibn("a");
}
Maybe<void> BroadcastBinaryOp::InferBlobDescs(
......
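Declaring that "out" may mutably reuse input "a"'s buffer lets a chain like the test's `a + b + b` write its second add into the first add's result buffer; the graph-level check added above still rejects the reuse when shapes differ. A numpy analogy for the allocation effect (illustration only):

    import numpy as np

    a = np.random.rand(5, 2).astype(np.float32)
    b = np.random.rand(5, 2).astype(np.float32)
    tmp = a + b   # first add allocates an output buffer
    tmp += b      # second add reuses that buffer in place, no new allocation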
import oneflow as flow
import numpy as np

-@flow.function
-def AddJob(a=flow.input_blob_def((5, 2)), b=flow.input_blob_def((5, 2))):
-    flow.config.default_data_type(flow.float)
-    a + b
-    return a + b + b
-x = np.random.rand(5, 2).astype(np.float32)
-y = np.random.rand(5, 2).astype(np.float32)
-z = None
-z = AddJob(x, y).get()
-print (np.array_equal(z, x + y + y))
+func_config = flow.FunctionConfig()
+func_config.default_data_type(flow.float)
+
+def test_naive(test_case):
+    @flow.function(func_config)
+    def AddJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((5, 2))):
+        return a + b + b
+    x = np.random.rand(5, 2).astype(np.float32)
+    y = np.random.rand(5, 2).astype(np.float32)
+    z = None
+    z = AddJob(x, y).get().ndarray()
+    test_case.assertTrue(np.array_equal(z, x + y + y))
+
+def test_broadcast(test_case):
+    flow.config.enable_debug_mode(True)
+    @flow.function(func_config)
+    def AddJob(a=flow.FixedTensorDef((5, 2)), b=flow.FixedTensorDef((1, 2))):
+        return a + b
+    x = np.random.rand(5, 2).astype(np.float32)
+    y = np.random.rand(1, 2).astype(np.float32)
+    z = None
+    z = AddJob(x, y).get().ndarray()
+    test_case.assertTrue(np.array_equal(z, x + y))