Commit 926666c3 authored by hjchen2

Reduce memory usage for broadcast binary ops backward

Parent 2e96920b
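Summary of the change: instead of unconditionally saving x, y, and the output for every broadcast binary op, each gradient function now fills a per-op BroadcastBinaryCaptureState that records which inputs require grad, whether each input was actually broadcast (its shape differs from the output's), and the index of every tensor it chose to save. Inputs whose shape already matches the output are never retained, because their gradient is just the output gradient passed through, so less memory stays alive between forward and backward. The sketch below is a minimal standalone illustration of that capture pattern; the types and functions (Tensor, CaptureState, CaptureBroadcastAdd) are hypothetical stand-ins, not the OneFlow API.

```cpp
// Minimal sketch of the "save only what backward needs, and remember its index" pattern.
// All names here are illustrative stand-ins for the real OneFlow types.
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct Tensor {                       // stand-in for a real tensor
  std::vector<int64_t> shape;
  bool requires_grad = false;
};
using TensorPtr = std::shared_ptr<Tensor>;

struct CaptureState {                 // mirrors the role of BroadcastBinaryCaptureState
  int x_index = -1;
  int y_index = -1;
  bool x_requires_grad = false;
  bool y_requires_grad = false;
  bool broadcast_x = false;
  bool broadcast_y = false;
  std::vector<TensorPtr> saved;

  // Returns the index of the saved tensor so the backward pass can look it up later.
  int SaveTensorForBackward(const TensorPtr& t) {
    saved.push_back(t);
    return static_cast<int>(saved.size()) - 1;
  }
};

// Capture step for a broadcast add: an input is saved only if it requires grad
// AND was actually broadcast (its shape differs from the output's).
void CaptureBroadcastAdd(CaptureState* ctx, const TensorPtr& x, const TensorPtr& y,
                         const TensorPtr& out) {
  ctx->x_requires_grad = x->requires_grad;
  ctx->y_requires_grad = y->requires_grad;
  ctx->broadcast_x = (x->shape != out->shape);
  ctx->broadcast_y = (y->shape != out->shape);
  if (ctx->x_requires_grad && ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(x); }
  if (ctx->y_requires_grad && ctx->broadcast_y) { ctx->y_index = ctx->SaveTensorForBackward(y); }
}

int main() {
  auto x = std::make_shared<Tensor>(Tensor{{4, 4}, true});
  auto y = std::make_shared<Tensor>(Tensor{{1, 4}, true});   // y is broadcast along dim 0
  auto out = std::make_shared<Tensor>(Tensor{{4, 4}, false});
  CaptureState ctx;
  CaptureBroadcastAdd(&ctx, x, y, out);
  assert(ctx.saved.size() == 1);                  // only the broadcast input y is retained
  assert(ctx.x_index == -1 && ctx.y_index == 0);  // x was not saved at all
  return 0;
}
```

The gradient functions in the diff below follow the same idea: BroadcastAdd and BroadcastSub save an input only when it requires grad and was broadcast; BroadcastMul and BroadcastDiv additionally save the other operand (and, for div, the output z) only when the corresponding gradient actually needs it.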
@@ -21,36 +21,74 @@ limitations under the License.
namespace oneflow {
namespace one {
class BroadcastBinaryGrad : public OpExprGradFunction<AutoGradCaptureState> {
struct BroadcastBinaryCaptureState : public AutoGradCaptureState {
int x_index = -1;
int y_index = -1;
int z_index = -1;
bool x_requires_grad = false;
bool y_requires_grad = false;
bool broadcast_x = false;
bool broadcast_y = false;
};
class BroadcastBinaryGrad : public OpExprGradFunction<BroadcastBinaryCaptureState> {
public:
BroadcastBinaryGrad() = default;
virtual ~BroadcastBinaryGrad() = default;
virtual Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(inputs.at(1));
ctx->SaveTensorForBackward(outputs.at(0));
return Maybe<void>::Ok();
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->y_requires_grad = inputs.at(1)->requires_grad();
ctx->broadcast_x = (*inputs.at(0)->shape() != *outputs.at(0)->shape());
ctx->broadcast_y = (*inputs.at(1)->shape() != *outputs.at(0)->shape());
// ctx->broadcast_x = true;
// ctx->broadcast_y = true;
return SaveTensorForBackward(ctx, inputs, outputs);
}
protected:
virtual Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs) const = 0;
};
class BroadcastAdd : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
in_grads->resize(2);
if (x->requires_grad()) {
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
} else {
in_grads->at(0) = out_grads.at(0);
}
}
if (ctx->y_requires_grad) {
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), y));
} else {
in_grads->at(1) = out_grads.at(0);
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad && ctx->broadcast_x) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
}
if (y->requires_grad()) {
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), y));
if (ctx->y_requires_grad && ctx->broadcast_y) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
@@ -60,17 +98,37 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd);
class BroadcastSub : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
in_grads->resize(2);
if (x->requires_grad()) {
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
} else {
in_grads->at(0) = out_grads.at(0);
}
}
if (y->requires_grad()) {
if (ctx->y_requires_grad) {
const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), Scalar(-1.f), false));
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(grad, y));
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(grad, y));
} else {
in_grads->at(1) = grad;
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad && ctx->broadcast_x) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
}
if (ctx->y_requires_grad && ctx->broadcast_y) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
@@ -80,18 +138,46 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub);
class BroadcastMul : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
in_grads->resize(2);
if (x->requires_grad()) {
if (ctx->x_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& x_grad = JUST(functional::Mul(out_grads.at(0), y));
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
} else {
in_grads->at(0) = x_grad;
}
}
if (y->requires_grad()) {
if (ctx->y_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y_grad = JUST(functional::Mul(out_grads.at(0), x));
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));
} else {
in_grads->at(1) = y_grad;
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }
}
if (ctx->y_requires_grad) {
if (ctx->x_index == -1 /*x has not been saved*/) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
}
if (ctx->broadcast_y && ctx->y_index == -1 /*y has not been saved*/) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
}
return Maybe<void>::Ok();
}
@@ -101,17 +187,40 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul);
class BroadcastDiv : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
const auto& z = ctx->SavedTensors().at(2);
in_grads->resize(2);
if (x->requires_grad()) {
if (ctx->x_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& x_grad = JUST(functional::Div(out_grads.at(0), y));
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
} else {
in_grads->at(0) = x_grad;
}
}
if (ctx->y_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& z = ctx->SavedTensors().at(ctx->z_index);
in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), z, y));
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }
}
if (ctx->y_requires_grad) {
if (ctx->y_index == -1 /*y has not been saved*/) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
}
if (y->requires_grad()) { in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), z, y)); }
return Maybe<void>::Ok();
}
};
@@ -120,67 +229,92 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv);
class BroadcastPow : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
const auto& z = ctx->SavedTensors().at(2);
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& z = ctx->SavedTensors().at(ctx->z_index);
in_grads->resize(2);
if (x->requires_grad()) {
if (ctx->x_requires_grad) {
in_grads->at(0) = JUST(functional::BroadcastPowXGrad(out_grads.at(0), x, y, z));
}
if (y->requires_grad()) {
if (ctx->y_requires_grad) {
in_grads->at(1) = JUST(functional::BroadcastPowYGrad(out_grads.at(0), x, y, z));
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_pow", BroadcastPow);
class BroadcastMinMax : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
const auto& out = ctx->SavedTensors().at(2);
const auto& out_shape = *(out->shape());
const auto& out_shape = *(out_grads.at(0)->shape());
in_grads->resize(2);
if (x->requires_grad() || y->requires_grad()) {
const auto& x_shape = *(x->shape());
const auto& y_shape = *(y->shape());
if (ctx->x_requires_grad || ctx->y_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
auto broad_x_ = x;
auto broad_y_ = y;
if (x_shape != out_shape) {
if (ctx->broadcast_x) {
const auto& x_shape = *(x->shape());
const Shape& left_extended_x_shape =
CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> x_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_x_ = JUST(functional::BroadcastLike(x, out, x_axis));
broad_x_ = JUST(functional::BroadcastLike(x, out_grads.at(0), x_axis));
}
if (y_shape != out_shape) {
if (ctx->broadcast_y) {
const auto& y_shape = *(y->shape());
const Shape& left_extended_y_shape =
CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> y_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_y_ = JUST(functional::BroadcastLike(y, out, y_axis));
broad_y_ = JUST(functional::BroadcastLike(y, out_grads.at(0), y_axis));
}
const auto& broad_grads =
JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));
if (x->requires_grad()) {
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(0), x));
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(0), x));
} else {
in_grads->at(0) = broad_grads->at(0);
}
}
if (y->requires_grad()) {
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(1), y));
if (ctx->y_requires_grad) {
if (ctx->broadcast_y) {
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(1), y));
} else {
in_grads->at(1) = broad_grads->at(1);
}
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad || ctx->y_requires_grad) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&)>
elementwise_grad_functor_;