Commit 5dda91a8 authored by cxxly, committed by Xiaoxu Chen

Fix dtype mapping error in the cast prim and its VJP

Parent ece6837f
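The fix targets dtype handling in the cast primitive and its VJP. As a hedged illustration of the invariant involved, using only the public paddle Python API rather than the prim internals patched below: the gradient of cast(x, dtype) should come back in x's original dtype.

import paddle

# Minimal sketch of the behavior the fix guards: the VJP of cast maps the
# incoming gradient back to the input's dtype, so a dtype-mapping error
# would surface as dx carrying the wrong dtype.
x = paddle.ones([2, 3], dtype="float32")
x.stop_gradient = False
y = paddle.cast(x, "float64")
(dx,) = paddle.grad(y.sum(), x)
print(dx.dtype)  # expected: paddle.float32 (matches x), not the cast target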
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
 #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/phi/core/utils/data_type.h"
 namespace paddle {
 namespace operators {
...
@@ -944,6 +944,7 @@ void maximum_grad(const Tensor& x,
   }
 }
+template <typename T>
 void dropout_grad(const Tensor& mask,
                   const Tensor& out_grad,
                   const Scalar& p,
...
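Making dropout_grad a template lets the same composite rule serve both the eager Tensor and the static DescTensor paths. Mathematically the rule is dx = out_grad * mask / (1 - p) in the default upscale_in_train mode; a hedged check of that through the public API:

import paddle

# Hedged sketch: with upscale_in_train dropout, kept elements are scaled by
# 1 / (1 - p), so d(out)/d(x) is 1 / (1 - p) where the mask kept a value and
# 0 where it dropped one.
p = 0.5
x = paddle.rand([4, 4])
x.stop_gradient = False
y = paddle.nn.functional.dropout(x, p=p, training=True)
(dx,) = paddle.grad(y.sum(), x)
print(paddle.unique(dx))  # values drawn from {0.0, 1 / (1 - p)}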
@@ -244,22 +244,22 @@ class TestCompositeBatchNorm(unittest.TestCase):
             atol=attrs.get_atol("forward"),
         )
-    # def test_forward(self):
-    #     for i in self.training:
-    #         for j in self.dtypes:
-    #             for m in self.momentum:
-    #                 attrs.set_training(i)
-    #                 attrs.set_dtype(j)
-    #                 attrs.set_momentum(m)
-    #                 self.compare_forward()
-    #     for n in self.shapes:
-    #         for s in self.data_formats:
-    #             for t in self.use_global_stats:
-    #                 attrs.set_shape(n)
-    #                 attrs.set_data_format(s)
-    #                 attrs.set_use_global_stats(t)
-    #                 self.compare_forward()
+    def test_forward(self):
+        for i in self.training:
+            for j in self.dtypes:
+                for m in self.momentum:
+                    attrs.set_training(i)
+                    attrs.set_dtype(j)
+                    attrs.set_momentum(m)
+                    self.compare_forward()
+        for n in self.shapes:
+            for s in self.data_formats:
+                for t in self.use_global_stats:
+                    attrs.set_shape(n)
+                    attrs.set_data_format(s)
+                    attrs.set_use_global_stats(t)
+                    self.compare_forward()
 def apply_to_static(net, use_cinn):
...
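The re-enabled test_forward sweeps training, dtype, momentum, shape, data_format, and use_global_stats through compare_forward. The apply_to_static helper only appears here as trailing context; a plausible sketch of such a helper (the BuildStrategy wiring is an assumption based on similar Paddle tests, not this file):

import paddle

def apply_to_static(net, use_cinn):
    # Assumed implementation: convert the layer to static graph mode and
    # toggle the CINN compiler pass through BuildStrategy.
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)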
@@ -226,7 +226,7 @@ def to_prim(blocks, exclude=frozenset()):
     if not core._is_fwd_prim_enabled():
         return
     if isinstance(blocks, paddle.fluid.framework.Block):
-        logging.debug("Atomize composite op to primitive ops begin.")
+        logging.info("Atomize composite op to primitive ops begin.")
         main_program = blocks.program
     elif isinstance(blocks, typing.Sequence):
         for item in blocks:
@@ -245,9 +245,9 @@ def to_prim(blocks, exclude=frozenset()):
         )
     with framework.program_guard(main_program):
-        logging.debug("Lowering composite forward ops begin...")
+        print("Lowering composite forward ops begin...")
         primx._lower_composite(
             blocks, prim_config["forward_blacklist"] | exclude
         )
         replace_ops = prim_config["composite_ops_record"]
-        logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
+        print(f"Lowering composite forward ops finish: {replace_ops}")
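to_prim only lowers composite ops when the forward prim flag is on; the same flag gates the ProgramCache change below. A hedged sketch of flipping it (only _is_fwd_prim_enabled is confirmed by this diff; the setter name is an assumption):

from paddle.fluid import core

# Assumed toggle for the forward prim flag; pair it with the checker that
# to_prim itself consults.
core._set_prim_forward_enabled(True)
assert core._is_fwd_prim_enabled()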
@@ -1218,7 +1218,11 @@ class ProgramCache:
             return infer_program
         partial_program = partial_program_from(concrete_program)
-        if not _in_amp_guard() and not _in_pure_fp16_guard():
+        if (
+            core._is_fwd_prim_enabled()
+            and not _in_amp_guard()
+            and not _in_pure_fp16_guard()
+        ):
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
...
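With the added guard, the PrimHooker is attached only when forward prim is enabled and the call is not under an AMP or pure-fp16 guard. A hedged end-to-end sketch of the path this affects, assuming the usual dygraph-to-static flow:

import paddle
from paddle.fluid import core

core._set_prim_forward_enabled(True)  # assumed toggle, as in the sketch above

net = paddle.jit.to_static(paddle.nn.Linear(4, 4))
out = net(paddle.rand([2, 4]))  # prim on, no AMP guard: hooker attached

with paddle.amp.auto_cast():  # under the AMP guard the hooker is skipped
    out_amp = net(paddle.rand([2, 4]))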