提交 5dda91a8 编写于 作者: C cxxly 提交者: Xiaoxu Chen

fix cast prim and vjp dtype mapping error bug

上级 ece6837f
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
namespace paddle {
namespace operators {
......
......@@ -944,6 +944,7 @@ void maximum_grad(const Tensor& x,
}
}
template <typename T>
void dropout_grad(const Tensor& mask,
const Tensor& out_grad,
const Scalar& p,
......
......@@ -244,22 +244,22 @@ class TestCompositeBatchNorm(unittest.TestCase):
atol=attrs.get_atol("forward"),
)
# def test_forward(self):
# for i in self.training:
# for j in self.dtypes:
# for m in self.momentum:
# attrs.set_training(i)
# attrs.set_dtype(j)
# attrs.set_momentum(m)
# self.compare_forward()
# for n in self.shapes:
# for s in self.data_formats:
# for t in self.use_global_stats:
# attrs.set_shape(n)
# attrs.set_data_format(s)
# attrs.set_use_global_stats(t)
# self.compare_forward()
def test_forward(self):
for i in self.training:
for j in self.dtypes:
for m in self.momentum:
attrs.set_training(i)
attrs.set_dtype(j)
attrs.set_momentum(m)
self.compare_forward()
for n in self.shapes:
for s in self.data_formats:
for t in self.use_global_stats:
attrs.set_shape(n)
attrs.set_data_format(s)
attrs.set_use_global_stats(t)
self.compare_forward()
def apply_to_static(net, use_cinn):
......
......@@ -226,7 +226,7 @@ def to_prim(blocks, exclude=frozenset()):
if not core._is_fwd_prim_enabled():
return
if isinstance(blocks, paddle.fluid.framework.Block):
logging.debug("Atomize composite op to primitive ops begin.")
logging.info("Atomize composite op to primitive ops begin.")
main_program = blocks.program
elif isinstance(blocks, typing.Sequence):
for item in blocks:
......@@ -245,9 +245,9 @@ def to_prim(blocks, exclude=frozenset()):
)
with framework.program_guard(main_program):
logging.debug("Lowering composite forward ops begin...")
print("Lowering composite forward ops begin...")
primx._lower_composite(
blocks, prim_config["forward_blacklist"] | exclude
)
replace_ops = prim_config["composite_ops_record"]
logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
print(f"Lowering composite forward ops finish: {replace_ops}")
......@@ -1218,7 +1218,11 @@ class ProgramCache:
return infer_program
partial_program = partial_program_from(concrete_program)
if not _in_amp_guard() and not _in_pure_fp16_guard():
if (
core._is_fwd_prim_enabled()
and not _in_amp_guard()
and not _in_pure_fp16_guard()
):
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册