Commit d0c80f43 authored by cxxly, committed by Xiaoxu Chen

[prim] enable dygraph_to_static to support custom_vjp

Parent 539d05c6
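In rough terms, the change lets programs converted with `paddle.jit.to_static` pick up composite (custom VJP) gradient rules, such as the dropout rule added below, instead of always falling back to the ordinary grad kernels. A minimal usage sketch, assuming a Paddle development build containing this commit and the prim switch helper `core._set_prim_all_enabled` used by its tests:

```python
# Hypothetical end-to-end sketch; names not introduced by this commit are assumptions.
import paddle
from paddle.fluid import core

core._set_prim_all_enabled(True)  # assumed helper: enable forward + backward prim lowering


@paddle.jit.to_static
def net(x):
    # dropout now carries a composite grad rule (DropoutCompositeGradOpMaker below)
    return paddle.nn.functional.dropout(x, p=0.5).sum()


x = paddle.rand([8, 8])
x.stop_gradient = False
loss = net(x)
loss.backward()  # backward is assembled from primitive ops via the dropout_grad composite rule
print(x.grad.shape)
```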
......@@ -91,6 +91,8 @@ class OpInfo {
// some ops don't have grad_op_maker, add check before use GradOpMaker()
bool HasGradOpMaker() const { return grad_op_maker_ != nullptr; }
bool HasCompGradOpMaker() const { return grad_comp_op_maker_ != nullptr; }
bool HasNonEmptyGradOpMaker() const {
return grad_op_maker_ != nullptr && !use_empty_grad_op_desc_maker_;
}
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
......@@ -158,6 +160,26 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class DropoutCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
auto mask = this->GetSingleForwardOutput("Mask");
auto out_grad = this->GetSingleOutputGrad("Out");
auto x_grad = this->GetSingleInputGrad("X");
auto x_grad_p = this->GetOutputPtr(&x_grad);
auto x_grad_name = this->GetOutputName(x_grad);
auto p = this->Attr<float>("dropout_prob");
auto is_test = this->Attr<bool>("is_test");
auto mode = this->Attr<std::string>("dropout_implementation");
prim::dropout_grad<prim::DescTensor>(
mask, out_grad, p, is_test, mode, x_grad_p);
VLOG(3) << "Running dropout_grad composite func";
this->RecoverOutputName(x_grad, x_grad_name);
}
};
class DropoutNdOpMaker : public DropoutOpMaker {
public:
void Make() override {
......@@ -195,6 +217,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(dropout,
REGISTER_OPERATOR(dropout,
ops::DropoutOp,
ops::DropoutOpMaker,
ops::DropoutCompositeGradOpMaker,
ops::DropoutGradOpMaker<paddle::framework::OpDesc>,
ops::DropoutGradOpMaker<paddle::imperative::OpBase>,
DropoutInferShapeFunctor);
......
......@@ -944,5 +944,32 @@ void maximum_grad(const Tensor& x,
}
}
void dropout_grad(const Tensor& mask,
const Tensor& out_grad,
const Scalar& p,
bool is_test,
const std::string& mode,
Tensor* x_grad) {
if (!x_grad) return;
if (is_test) {
if (mode == "upscale_in_train") {
by_pass<T>(out_grad, x_grad);
} else {
set_output<T>(out_grad * (1.0 - p.to<float>()), x_grad);
}
} else {
if (mode == "upscale_in_train") {
if (p.to<float>() == 1.0f) {
set_output<T>(out_grad * 0.0, x_grad);
} else {
set_output<T>(
out_grad * cast<T>(mask, out_grad.dtype()) / (1.0 - p.to<float>()),
x_grad);
}
} else {
set_output<T>(out_grad * cast<T>(mask, out_grad.dtype()), x_grad);
}
}
}
} // namespace prim
} // namespace paddle
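For reference, the composite rule above reproduces the standard dropout gradient; writing dout for the incoming gradient, m for the saved mask, and p for the dropout probability, the code computes:

```latex
dx =
\begin{cases}
dout                      & \text{is\_test, upscale\_in\_train} \\
(1-p)\,dout               & \text{is\_test, downscale\_in\_infer} \\
0                         & \text{training, upscale\_in\_train},\ p = 1 \\
\dfrac{m \odot dout}{1-p} & \text{training, upscale\_in\_train},\ p < 1 \\
m \odot dout              & \text{training, downscale\_in\_infer}
\end{cases}
```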
......@@ -1472,6 +1472,9 @@ All parameter, weight, gradient are variables in Paddle.
[](std::unique_ptr<OpDesc> &p) { return p.release(); });
return std::make_pair(grad_op_desc_ptrs, grad_to_var);
});
m.def("has_comp_grad_op_maker", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasCompGradOpMaker();
});
m.def("has_grad_op_maker", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
});
......
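The new binding mirrors the existing `has_grad_op_maker` query, so Python code (including the dygraph-to-static cache below) can ask whether an op type ships a composite VJP. A small sketch, assuming a build that includes this commit:

```python
from paddle.fluid import core

for op_type in ("dropout", "matmul_v2"):
    # has_comp_grad_op_maker is the binding added above; has_grad_op_maker already existed
    print(op_type,
          core.has_grad_op_maker(op_type),
          core.has_comp_grad_op_maker(op_type))
```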
......@@ -645,6 +645,7 @@ def _lower_composite(block, blacklist=frozenset()):
else:
none_vars_to_remove.add(orig_out.name)
else:
op_desc = block.desc.append_op()
op_desc.copy_from(op.desc)
block._sync_with_cpp()
# Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove):
......
......@@ -956,13 +956,6 @@ class ConcreteProgram:
self.function = function
self.kwargs = kwargs
@switch_to_static_graph
def _to_prim(self):
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd.primapi import to_prim
to_prim(self.main_program.blocks)
@staticmethod
@switch_to_static_graph
def from_func_spec(
......@@ -1188,10 +1181,29 @@ class ProgramCache:
var.name, var.shape
)
)
if not _in_amp_guard() and not _in_pure_fp16_guard():
concrete_program._to_prim()
return concrete_program, partial_program_from(concrete_program)
custom_vjps = set()
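# Ops that provide a composite (custom) VJP are collected first so the initial
# forward lowering can leave them intact; their composite grad rules only apply
# while the original ops still exist when the backward is constructed.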
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
custom_vjps = {
op.type
for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
if core._is_fwd_prim_enabled():
if not _in_amp_guard() and not _in_pure_fp16_guard():
_to_prim(
concrete_program.main_program.blocks, exclude=custom_vjps
)
partial_program = partial_program_from(concrete_program)
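# With the backward now generated from the composite VJPs, the forward program
# can be lowered a second time, covering the ops excluded above.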
if core._is_fwd_prim_enabled() and len(custom_vjps) != 0:
if not _in_amp_guard() and not _in_pure_fp16_guard():
_to_prim(partial_program.forward_program.blocks)
return concrete_program, partial_program
def __getitem__(self, item):
if not isinstance(item, CacheKey):
......@@ -1660,3 +1672,11 @@ def enable_to_static(enable_to_static_bool):
)
_program_trans = ProgramTranslator()
_program_trans.enable(enable_to_static_bool)
@switch_to_static_graph
def _to_prim(blocks, exclude=frozenset()):
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi
primapi.to_prim(blocks, exclude=exclude)