Unverified commit 2f601282 authored by Jiabin Yang, committed by GitHub

Prim paddle Basic (#49272)

* prototype of composite grad in paddle

* prototype of composite grad in paddle

* refactor composite api with phi

* fix compile error

* support static graph code-gen for squeeze op

* generate static graph code of unsqueeze

* refine op name

* fix compile error

* add extra output in op_compat

* remove debug log

* fix clang compile error

* support prim switch flag

* support prim switch flag

* fix dygraph error

* merge develop

* add code_gen

* add necessary files without codegen

* fix code_gen bug

* add deps

* modify ignore

* add ignore

* delete std cout

* add composite logic for backward.py

* add tanh first order grad composite

* support enable_prim flag for static graph

* throw exception when neither GradOpMaker nor GradCompOpMaker has been registered

* reorganize the directory of prim api tests

* fix windows error

* add eager_utils

* add eager_utils

* modify code gen

* add composite parse

* add unittest for get_grad_op_desc

* code optimize

* fix static test on windows

* support generate static graph code for imag and real op

* fix windows compile error in test_static_prim

* merge develop

* disable test eager in inference

* prim code gen

* disable eager compile in inference

* rm other file

* rm gitignore file

* code_style

* add eager test

* code_style

* merge develop

* remove useless files

* modify static test

* support bool flag from singleton

* merge develop

* recover git ignore

* fix conflict

* recover git ignore for generated op

* fix test compile error

* remove some tests

* add python test

* fix some name issue

* add composite code gen

* modify backward yaml

* fix static composite grad maker code gen

* remove additional files

* add some static funcs unit test

* fix some bugs

* fix composite grad maker register code gen

* optimize some functions
Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Co-authored-by: wangruting <wangruting@baidu.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: charles-hit <wanghao107@baidu.com>
Co-authored-by: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Parent 65d2b4af
......@@ -6,6 +6,7 @@ add_subdirectory(imperative)
add_subdirectory(operators)
add_subdirectory(pybind)
add_subdirectory(eager)
add_subdirectory(prim)
add_subdirectory(jit)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
......@@ -16,7 +16,7 @@ set(eager_deps
custom_operator_node)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
set(eager_deps ${eager_deps} accumulation_node)
set(eager_deps ${eager_deps} accumulation_node prim_utils)
endif()
set(fluid_deps
......
......@@ -7,6 +7,6 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
final_dygraph_node
SRCS nodes.cc ${eager_manual_nodes}
DEPS ${eager_deps})
DEPS ${eager_deps} eager_prim_api)
add_dependencies(final_dygraph_node eager_codegen)
endif()
......@@ -7,6 +7,6 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
final_dygraph_function
SRCS dygraph_functions.cc ${eager_manual_functions}
DEPS ${eager_deps})
DEPS ${eager_deps} final_dygraph_node)
add_dependencies(final_dygraph_function eager_codegen)
endif()
......@@ -403,6 +403,23 @@ def ParseYamlInplaceInfo(string):
return inplace_map
def ParseYamlCompositeInfo(string):
# example: composite: fun(args1, args2, ...)
fname = r'(.*?)'
wspace = r'\s*'
fargs = r'(.*?)'
pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'
m = re.search(pattern, string)
composite_fun_info = []
composite_fun_info.append(m.group(1))
func_args = m.group(2).split(",")
for fun_arg in func_args:
composite_fun_info.append(fun_arg.strip())
return composite_fun_info
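
For orientation, a minimal standalone sketch of what this parser produces. It reuses the regex above; the input string mirrors the tanh_grad composite signature that appears later in this diff and is only an illustrative yaml value:

import re

def parse_yaml_composite_info(string):
    # Same pattern as ParseYamlCompositeInfo above: "func(arg1, arg2, ...)"
    m = re.search(r'(.*?)\s*\(\s*(.*?)\s*\)', string)
    info = [m.group(1)]                                      # composite function name
    info += [arg.strip() for arg in m.group(2).split(",")]   # then its argument names
    return info

print(parse_yaml_composite_info("tanh_grad(out, grad_out, grad_x)"))
# ['tanh_grad', 'out', 'grad_out', 'grad_x']
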
####################
# Generator Base #
####################
......@@ -438,6 +455,7 @@ class FunctionGeneratorBase:
# Special Op Attributes
self.optional_inputs = [] # [name, ...]
self.no_need_buffers = [] # [name, ...]
self.composite_func_info = [] # [func_name, input_name, ...]
self.intermediate_outputs = [] # [name, ...]
self.forward_inplace_map = {} # {name : name, ...}
......@@ -459,6 +477,13 @@ class FunctionGeneratorBase:
name = RemoveSpecialSymbolsInName(name)
self.no_need_buffers.append(name.strip())
def ParseComposite(self):
grad_api_contents = self.grad_api_contents
if 'composite' in grad_api_contents.keys():
composite_str = grad_api_contents['composite']
self.composite_func_info = ParseYamlCompositeInfo(composite_str)
def ParseDispensable(self):
forward_api_contents = self.forward_api_contents
......
......@@ -332,6 +332,9 @@ NODE_CC_FILE_TEMPLATE = """
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/utils/utils.h"
DECLARE_bool(check_nan_inf);
{}
"""
......@@ -546,6 +549,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
# self.forward_outputs_position_map
# self.optional_inputs
# self.no_need_buffers
# self.composite_func_info
# self.intermediate_outputs
# self.forward_inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
......@@ -871,6 +875,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
is_composite_grad_api = (
False if self.composite_func_info == [] else True
)
# Pass Stop Gradient Args
pass_stop_gradient_args_str = self.GetPassStopGradientArgsList(
......@@ -1056,6 +1063,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.ParseBackwardInplaceInfo()
# Parse no_need_buffer
self.ParseNoNeedBuffer()
# Parse composite
self.ParseComposite()
# Parse optional_inputs
self.ParseDispensable()
......@@ -1826,16 +1835,25 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
is_invoke_forward_api = IsInvokeForwardApi(
self.grad_api_contents, self.forward_apis_dict
)
is_composite_grad_api = (
False if self.composite_func_info == [] else True
)
if next_node_generator is not None:
has_higher_order_node = True
return (
has_higher_order_node,
is_invoke_forward_api,
is_composite_grad_api,
next_grad_node_creation_str,
next_grad_node_out_list,
next_node_generator.backward_forward_inputs_map,
)
elif not is_invoke_forward_api:
# TODO(Ruting): Integrate invoke and composite as composite so the rest branch can be covered
# TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future
# if is_composite_grad_api:
# next_grad_node_creation_str = ''
elif not is_invoke_forward_api and not is_composite_grad_api:
next_grad_node_creation_str = f""" if(trace_backward) {{
PADDLE_THROW(phi::errors::Unavailable(
\"The Op {self.backward_api_name} doesn't have any grad\"
......@@ -1845,6 +1863,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
return (
has_higher_order_node,
is_invoke_forward_api,
is_composite_grad_api,
next_grad_node_creation_str,
next_grad_node_out_list,
None,
......@@ -1942,6 +1961,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self,
has_higher_order_node,
is_invoke_forward_api,
is_composite_grad_api,
next_grad_node_creation_str,
next_grad_node_out_list,
backward_forward_inputs_map_next,
......@@ -1949,6 +1969,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
composite_grad_api_name = (
self.composite_func_info[0] if is_composite_grad_api else None
)
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
......@@ -2133,6 +2156,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Grad Function Call String
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
grad_api_namespace = f"paddle::experimental::{namespace}"
composite_grad_api_namespace = f"paddle::prim::{namespace}"
grad_function_prepare_str = f"""
const auto& out_metas = OutputMeta();
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs});
......@@ -2203,6 +2227,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
}}"""
grad_api_args_str = ", ".join(grad_api_args)
composite_grad_api_args_str = ", ".join(grad_api_args)
composite_template_name = "<paddle::experimental::Tensor>"
if is_invoke_forward_api:
autograd_api_out = "auto"
......@@ -2225,6 +2251,17 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
{out_assign_str}}} else {{
{indent}{autograd_api_out} api_output = paddle::experimental::{self.namespace}{self.grad_api_contents['invoke']};
{out_assign_str}{indent}}}
"""
# TODO(Ruting): use composite only when we don't have a backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << paddle::string::Sprintf("composite api %s is called" , "{composite_grad_api_name}");
}}else{{
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});
VLOG(4) << paddle::string::Sprintf("origin api %s is called" , "{backward_api_name}");
}}
"""
else:
grad_function_call_str = f"""
......@@ -2361,6 +2398,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
(
has_higher_order_node,
is_invoke_forward_api,
is_composite_grad_api,
next_grad_node_creation_str,
next_grad_node_out_list,
backward_forward_inputs_map,
......@@ -2371,6 +2409,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.GenerateNodeDefinition(
has_higher_order_node,
is_invoke_forward_api,
is_composite_grad_api,
next_grad_node_creation_str,
next_grad_node_out_list,
backward_forward_inputs_map,
......
cc_library(
performance_benchmark_utils
SRCS benchmark_utils.cc
DEPS ${eager_deps}
${fluid_deps}
${generated_deps}
eager_scale
scale_node
scale_op
matmul_v2_op
dygraph_function)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
performance_benchmark_utils
SRCS benchmark_utils.cc
DEPS ${eager_deps}
${fluid_deps}
${generated_deps}
eager_scale
scale_node
scale_op
matmul_v2_op
dygraph_function
eager_prim_api)
cc_test_old(
test_egr_performance_benchmark_eager_cpu
SRCS
benchmark_eager_cpu.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_fluid_cpu
SRCS
benchmark_fluid_cpu.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_eager_cpu
SRCS
benchmark_eager_cpu.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_fluid_cpu
SRCS
benchmark_fluid_cpu.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_eager_cuda
SRCS
benchmark_eager_cuda.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_fluid_cuda
SRCS
benchmark_fluid_cuda.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_eager_cuda
SRCS
benchmark_eager_cuda.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
cc_test_old(
test_egr_performance_benchmark_fluid_cuda
SRCS
benchmark_fluid_cuda.cc
DEPS
performance_benchmark_utils
${eager_deps}
${fluid_deps})
endif()
......@@ -32,6 +32,7 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/imperative/dygraph_grad_maker.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
namespace paddle {
namespace framework {
......@@ -46,6 +47,7 @@ enum OpInfoFillType {
kInplaceOpInference = 5,
kNoNeedBufferVarsInference = 6,
kGradOpBaseMaker = 7,
kGradCompOpDescMaker = 8,
kUnknown = -1
};
......@@ -61,6 +63,7 @@ using OpRegistryClasses = std::tuple< // NOLINT
TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>, // NOLINT
TypePair<GradOpDescMakerBase, kGradOpDescMaker>, // NOLINT
TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>, // NOLINT
TypePair<prim::GradCompositeOpMakerBase, kGradCompOpDescMaker>, // NOLINT
TypePair<VarTypeInference, kVarTypeInference>, // NOLINT
TypePair<InferShapeBase, kShapeInference>, // NOLINT
TypePair<InplaceOpInference, kInplaceOpInference>, // NOLINT
......@@ -252,6 +255,30 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
}
};
template <typename T>
struct OpInfoFiller<T, kGradCompOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->grad_comp_op_maker_,
nullptr,
platform::errors::AlreadyExists(
"GradCompositeOpMakerBase of %s has been registered", op_type));
info->grad_comp_op_maker_ =
[](const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var,
const BlockDesc* current_block,
const std::vector<BlockDesc*>& grad_block) {
T maker(fwd_op, no_grad_set, grad_to_var, current_block, grad_block);
return maker();
};
// TODO(jiabin): Support this later or just not.
info->use_default_grad_op_desc_maker_ = false;
info->use_empty_grad_op_desc_maker_ = false;
}
};
template <typename T>
struct OpInfoFiller<T, kGradOpBaseMaker> {
void operator()(const char* op_type, OpInfo* info) const {
......
......@@ -200,6 +200,8 @@ class OpDesc {
OperatorDistAttr *MutableDistAttr();
void SetDistAttr(const OperatorDistAttr &dist_attr);
void ResetBlock() { this->block_ = nullptr; }
private:
friend class ProgramDesc;
// Find VarDesc from OpDesc located Block into global Block
......
......@@ -43,6 +43,7 @@ class OpInfo {
public:
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
GradCompositeOpMakerFN grad_comp_op_maker_;
proto::OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_;
......@@ -81,20 +82,10 @@ class OpInfo {
return creator_;
}
const GradOpMakerFN& GradOpMaker() const {
// Normally, proto_ should not be null, except some special operators, such
// as LeaklyReluDoubleGrad op.
std::string type = proto_ ? proto_->type() : "unknown";
PADDLE_ENFORCE_NOT_NULL(
grad_op_maker_,
platform::errors::NotFound(
"Operator %s's GradOpMaker has not been "
"registered.\nPlease check whether (%s) operator has "
"gradient operator.\nIf not, please set stop_gradient to be True "
"for its input and output variables using var.stop_gradient=True.",
type.c_str(),
type.c_str()));
return grad_op_maker_;
const GradOpMakerFN& GradOpMaker() const { return grad_op_maker_; }
const GradCompositeOpMakerFN& GradCompOpMaker() const {
return grad_comp_op_maker_;
}
// some ops don't have grad_op_maker, add check before use GradOpMaker()
......
......@@ -96,6 +96,14 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
std::unordered_map<std::string, std::string>* /*grad_to_var*/,
const std::vector<BlockDesc*>& grad_block)>;
using GradCompositeOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDesc>>(
const OpDesc&,
const std::unordered_set<std::string>& /*no_grad_set*/,
std::unordered_map<std::string, std::string>* /*grad_to_var*/,
const BlockDesc*,
const std::vector<BlockDesc*>& grad_block)>;
using DygraphGradOpMakerFN =
std::function<std::shared_ptr<imperative::GradOpNode>(
const std::string& /*op_type*/,
......
......@@ -95,7 +95,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta)
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta static_prim_api)
register_operators(EXCLUDES py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
......@@ -147,7 +147,7 @@ cc_library(ops_extra_info SRCS ops_extra_info.cc DEPS attribute cudnn_workspace_
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling executor generator)
sequence_pooling executor generator static_prim_api)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc_functor matrix_inverse matrix_solve)
......
......@@ -156,6 +156,13 @@ set(generated_static_files
"${generated_static_argument_mapping_path}"
"${generated_sparse_argument_mapping_path}")
set(generated_static_files
"${generated_op_path}"
"${generated_static_op_path}"
"${generated_sparse_ops_path}"
"${generated_argument_mapping_path}"
"${generated_static_argument_mapping_path}"
"${generated_sparse_argument_mapping_path}")
foreach(generated_static_file ${generated_static_files})
if(EXISTS "${generated_static_file}.tmp" AND EXISTS
"${generated_static_file}")
......
......@@ -132,6 +132,17 @@ def to_int_array_tensors_name(attr):
return to_pascal_case(attr['name']) + 'TensorList'
def to_composite_grad_opmaker_name(backward_op_name):
words = backward_op_name.split("_")
for i in range(len(words)):
words[i] = words[i].strip()
words[i] = words[i].capitalize()
composite_grad_opmaker_name = words[0] + "Composite"
composite_grad_opmaker_name += "".join(word for word in words[1:])
composite_grad_opmaker_name += "OpMaker"
return composite_grad_opmaker_name
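
As a quick illustration of the filter above (a sketch; tanh_grad is used because its composite rule is part of this change, and the result matches the class name pattern used by the Jinja templates below):

def to_composite_grad_opmaker_name(backward_op_name):
    # Same logic as the filter above.
    words = [w.strip().capitalize() for w in backward_op_name.split("_")]
    return words[0] + "Composite" + "".join(words[1:]) + "OpMaker"

print(to_composite_grad_opmaker_name("tanh_grad"))  # TanhCompositeGradOpMaker
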
def cartesian_prod_attrs(attrs):
items = []
for attr in attrs:
......
......@@ -19,6 +19,7 @@ from pathlib import Path
import yaml
from filters import (
cartesian_prod_mapping,
to_composite_grad_opmaker_name,
to_input_name,
to_int_array_tensor_name,
to_int_array_tensors_name,
......@@ -32,6 +33,7 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
from tests import (
is_base_op,
is_composite_op,
is_initializer_list,
is_scalar,
is_vec,
......@@ -57,7 +59,9 @@ env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
......@@ -153,6 +157,27 @@ def process_int_array(op_item, int_array_configs):
]
def parse_composite_info(ops, backward_ops, backward_op_dict):
for op in ops:
if "backward" in op:
op["phi_backward"] = op["backward"]
for backward_op in backward_ops:
if "backward" in backward_op:
backward_op["phi_backward"] = backward_op["backward"]
for backward_op_name, op_dict in backward_op_dict.items():
if "composite" not in op_dict:
continue
op_dict["composite"]["phi_inputs"] = []
op_dict["composite"]["phi_attrs"] = []
op_dict["composite"]["phi_outputs"] = []
for input in op_dict["inputs"]:
op_dict["composite"]["phi_inputs"].append(input['name'])
for attr in op_dict["attrs"]:
op_dict["composite"]["phi_attrs"].append(attr['name'])
for output in op_dict["outputs"]:
op_dict["composite"]["phi_outputs"].append(output['name'])
# replace name of op and params for OpMaker
def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
def get_phi_and_fluid_op_name(op_item):
......@@ -178,6 +203,37 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
)
item['name'] = args_alias_map[item['name'][:-5]] + '_grad'
def add_fluid_info_in_composite(composite_map, args_alias_map):
fluid_input_list = []
fluid_attr_list = []
fluid_output_list = []
# add fluid op inputs
for input in composite_map["phi_inputs"]:
if input in args_alias_map:
fluid_input_list.append(args_alias_map[input])
else:
fluid_input_list.append(input)
# add fluid op attrs
for attr in composite_map["phi_attrs"]:
if attr in args_alias_map:
fluid_attr_list.append(args_alias_map[attr])
else:
fluid_attr_list.append(attr)
# add fluid op outputs
for output in composite_map["phi_outputs"]:
if output in args_alias_map:
fluid_output_list.append(args_alias_map[output])
else:
fluid_output_list.append(output)
composite_map.update(
{
"fluid_inputs": fluid_input_list,
"fluid_attrs": fluid_attr_list,
"fluid_outputs": fluid_output_list,
}
)
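
A minimal sketch of the fallback above: each phi-side name is swapped for its fluid alias when one is registered and kept unchanged otherwise (the alias map here is purely illustrative):

def to_fluid_names(names, args_alias_map):
    # Fall back to the phi name when no fluid alias exists.
    return [args_alias_map.get(name, name) for name in names]

composite_map = {"phi_inputs": ["out", "grad_out"], "phi_attrs": [], "phi_outputs": ["grad_x"]}
alias_map = {"out": "Out"}  # assumed alias, for illustration only
composite_map["fluid_inputs"] = to_fluid_names(composite_map["phi_inputs"], alias_map)
composite_map["fluid_attrs"] = to_fluid_names(composite_map["phi_attrs"], alias_map)
composite_map["fluid_outputs"] = to_fluid_names(composite_map["phi_outputs"], alias_map)
print(composite_map["fluid_inputs"])  # ['Out', 'grad_out']
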
def get_param_list_alias(param_list, args_map):
return [
args_map[param] if param in args_map else param
......@@ -307,6 +363,15 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
continue
backward_op_list = op_args['backward'].split(',')
# add fluid args name in composite map
for backward_op in backward_op_list:
if (
"composite"
in backward_op_dict[backward_op.split('(')[0].strip()]
):
add_fluid_info_in_composite(
backward_op_dict[backward_op]["composite"], args_map
)
_, bw_op_name = get_phi_and_fluid_op_name(backward_op_list[0])
forward_op_item['backward'] = bw_op_name
backward_op_item['op_name'] = bw_op_name
......@@ -406,12 +471,10 @@ def main(
ops = yaml.safe_load(f)
ops = [restruct_io(op) for op in ops]
forward_op_dict = to_named_dict(ops)
with open(backward_yaml_path, "rt") as f:
backward_ops = yaml.safe_load(f)
backward_ops = [restruct_io(op) for op in backward_ops]
backward_op_dict = to_named_dict(backward_ops)
with open(op_version_yaml_path, "rt") as f:
op_versions = yaml.safe_load(f)
# add op version info into op
......@@ -426,6 +489,8 @@ def main(
for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name']
parse_composite_info(ops, backward_ops, backward_op_dict)
replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)
# prepare for invoke case
......@@ -442,21 +507,21 @@ def main(
op_dict = {}
op_dict.update(forward_op_dict)
op_dict.update(backward_op_dict)
if len(ops) == 0 and len(backward_ops) == 0:
if os.path.isfile(output_op_path):
os.remove(output_op_path)
if os.path.isfile(output_arg_map_path):
os.remove(output_arg_map_path)
return
op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
ops=ops, backward_ops=backward_ops, op_dict=op_dict
ops=ops,
backward_ops=backward_ops,
op_dict=op_dict,
composite_gen_flag=True,
)
f.write(msg)
ks_template = env.get_template('ks.c.j2')
with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(ops=ops, backward_ops=backward_ops)
......
......@@ -19,6 +19,7 @@ from pathlib import Path
import yaml
from filters import (
cartesian_prod_mapping,
to_composite_grad_opmaker_name,
to_input_name,
to_int_array_tensor_name,
to_int_array_tensors_name,
......@@ -58,6 +59,7 @@ env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
......@@ -134,7 +136,10 @@ def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
op_template = env.get_template('sparse_op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
ops=ops, backward_ops=backward_ops, op_dict=op_dict
ops=ops,
backward_ops=backward_ops,
op_dict=op_dict,
composite_gen_flag=False,
)
f.write(msg)
......
......@@ -19,6 +19,7 @@ from pathlib import Path
import yaml
from filters import (
cartesian_prod_mapping,
to_composite_grad_opmaker_name,
to_input_name,
to_int_array_tensor_name,
to_int_array_tensors_name,
......@@ -58,6 +59,7 @@ env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
......@@ -111,7 +113,10 @@ def main(
op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
ops=ops, backward_ops=[], op_dict=forward_op_dict
ops=ops,
backward_ops=[],
op_dict=forward_op_dict,
composite_gen_flag=False,
)
f.write(msg)
......
......@@ -289,6 +289,26 @@ def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
return forward_cfg
def parse_composite(
op_name: str,
composite_config: str,
) -> Dict[str, Any]:
# composite_config: func(args1, args2,.....)
fname = r'(.*?)'
wspace = r'\s*'
fargs = r'(.*?)'
pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'
m = re.search(pattern, composite_config)
func_name = m.group(1)
func_args = m.group(2)
composite_dict = {}
composite_dict["func_name"] = func_name
composite_dict["func_args"] = func_args
return composite_dict
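
Unlike ParseYamlCompositeInfo in the eager code generator, this parser keeps the argument list as one raw string. A tiny runnable check of the expected shape (the entry text reuses the tanh_grad composite signature from this diff and is otherwise hypothetical):

import re

m = re.search(r'(.*?)\s*\(\s*(.*?)\s*\)', "tanh_grad(out, grad_out, grad_x)")
assert {"func_name": m.group(1), "func_args": m.group(2)} == {
    "func_name": "tanh_grad",
    "func_args": "out, grad_out, grad_x",
}
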
def check_op_config(op_entry, op_name):
base_key_set = (
'op',
......@@ -306,6 +326,7 @@ def check_op_config(op_entry, op_name):
'intermediate',
'no_need_buffer',
'data_transform',
'composite',
)
infer_meta_key_set = ('func', 'param')
kernel_key_set = ('func', 'param', 'data_type', 'layout', 'backend')
......@@ -331,9 +352,9 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
op_name = op_entry[name_field]
inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
outputs = parse_outputs(op_name, op_entry["output"])
if "composite" in op_entry:
composite_dict = parse_composite(op_name, op_entry["composite"])
check_op_config(op_entry, op_name)
# validate default value of DataType and DataLayout
for attr in attrs:
if "default_value" in attr:
......@@ -441,6 +462,10 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
invoke = parse_invoke(op_name, op_entry["invoke"])
op["invoke"] = invoke
# has composite ?
if "composite" in op_entry:
op.update({"composite": composite_dict})
# backward
if "backward" in op_entry:
backward = op_entry["backward"]
......
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %}
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version, composite_grad_op_maker %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -36,6 +39,11 @@ using paddle::framework::GradVarName;
{% else %}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %}
{% if composite_gen_flag == True %}
{% if op is composite_op %}
{{composite_grad_op_maker(op_dict[op["name"]])}}
{% endif %}
{% endif %}
{% endfor %}
} // namespace operators
} // namespace paddle
......@@ -43,7 +51,7 @@ using paddle::framework::GradVarName;
namespace ops = paddle::operators;
{% for op in ops + backward_ops %}
{% if op is base_op %}
{{register_op_with_components(op)}}
{{register_op_with_components(op, op_dict)}}
{{register_op_version(op)}}
{% endif %}
{% endfor %}
......@@ -315,8 +315,9 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER({{op["op_name"] | to_pascal_case}}NoNeedBuff
{% endif %}
{% endmacro%}
{% macro register_op_with_components(op) %}
{% macro register_op_with_components(op, op_dict) %}
{% set name = op["op_name"] %}
{% set phi_name = op["name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in op %}{# it is a forward op #}
ops::{{name | to_pascal_case}}OpMaker,
......@@ -332,6 +333,9 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if op is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %}
{% if "phi_backward" in op and op["phi_backward"] is not none and "composite" in op_dict[op["phi_backward"]] %}
ops::{{op["phi_backward"] | to_composite_grad_opmaker_name}},
{% endif %}
{% if op is supports_no_need_buffer %}{# no_need_buffer #}
ops::{{name | to_pascal_case}}NoNeedBufferVarInferer,
{% endif %}
......@@ -486,6 +490,155 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
};
{% endmacro %}
{% macro composite_grad_op_maker(composite_op_dict) %}
{% set op_name = composite_op_dict["name"] %}
class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeOpMakerBase {
public:
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
void Apply() override {
//get inputs
{{construct_composite_input(composite_op_dict)}}
//get attr
{{construct_composite_attr(composite_op_dict)}}
//get output
{{construct_composite_output(composite_op_dict)}}
//get output ptr
{{construct_composite_output_ptr(composite_op_dict)}}
//get output original name
{{get_composite_output_orginal_name(composite_op_dict)}}
//call composite backward func
{{call_composite_backward_api(composite_op_dict)}}
//recover output name
{{recover_composite_output_name(composite_op_dict)}}
}
};
{%- endmacro %}
{% macro construct_composite_input(composite_op_dict) %}
{% set inputs = composite_op_dict["composite"]["phi_inputs"] %}
{% set input_dict = composite_op_dict["input_dict"] %}
{% set fluid_inputs = composite_op_dict["composite"]["fluid_inputs"] %}
{% set forward_fluid_inputs = composite_op_dict["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_fluid_outputs = composite_op_dict["forward"]["outputs"] | map(attribute="name") | list %}
{% set inputs_length = inputs | length %}
{% for i in range(inputs_length) %}
{% set input_typename = input_dict[inputs[i]]["typename"] %}
{% set input_optional_flag = input_dict[inputs[i]]["optional"] %}
{% if fluid_inputs[i] in forward_fluid_inputs %}
{% if input_typename == "Tensor" %}
{% if input_optional_flag == True %}
paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleForwardInput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %}
paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleForwardInput("{{fluid_inputs[i]}}");
{% endif %}
{% elif input_typename == "Tensor[]" %}
{% if input_optional_flag == True %}
std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiForwardInput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %}
std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiForwardInput("{{fluid_inputs[i]}}");
{% endif %}
{% endif %}
{% elif fluid_inputs[i] in forward_fluid_outputs %}
{% if input_typename == "Tensor" %}
{% if input_optional_flag == True %}
paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleForwardOutput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %}
paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleForwardOutput("{{fluid_inputs[i]}}");
{% endif %}
{% elif input_typename == "Tensor[]" %}
{% if input_optional_flag == True %}
std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiForwardOutput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %}
std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiForwardOutput("{{fluid_inputs[i]}}");
{% endif %}
{% endif %}
{% elif fluid_inputs[i][:-5] in forward_fluid_outputs %}
{% if input_typename == "Tensor" %}
{% if input_optional_flag == True %}
paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleOutputGrad("{{fluid_inputs[i][:-5]}}");
{% elif input_optional_flag == False %}
paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleOutputGrad("{{fluid_inputs[i][:-5]}}");
{% endif %}
{% elif input_typename == "Tensor[]" %}
{% if input_optional_flag == True %}
std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiOutputGrad("{{fluid_inputs[i][:-5]}}");
{% elif input_optional_flag == False %}
std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiOutputGrad("{{fluid_inputs[i][:-5]}}");
{%- endif %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{% macro construct_composite_attr(composite_op_dict) %}
{% set attrs = composite_op_dict["composite"]["phi_attrs"] %}
{% set fluid_attrs = composite_op_dict["composite"]["fluid_attrs"] %}
{% set fluid_attrs_dict = composite_op_dict["attr_dict"] %}
{% set attrs_length = attrs | length %}
{% for i in range(attrs_length) %}
{% set attrs_data_type = fluid_attrs_dict[fluid_attrs[i]]["typename"] | to_op_attr_type %}
{{attrs_data_type}} {{attrs[i]}} = this->Attr<{{attrs_data_type}}>("{{fluid_attrs[i]}}");
{% endfor %}
{%- endmacro %}
{% macro construct_composite_output(composite_op_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set fluid_outputs = composite_op_dict["composite"]["fluid_outputs"] %}
{% set outputs_dict = composite_op_dict["output_dict"] %}
{% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %}
{% set output_typename = outputs_dict[outputs[i]]["typename"] %}
{% if output_typename == "Tensor" %}
paddle::experimental::Tensor {{outputs[i] + "_t"}} = this->GetSingleInputGrad("{{fluid_outputs[i][:-5]}}");
{% elif output_typename == "Tensor[]" %}
std::vector<paddle::experimental::Tensor> {{outputs[i] + "_t"}} = this->GetMultiInputGrad("{{fluid_outputs[i][:-5]}}");
{%- endif %}
{%- endfor %}
{%- endmacro %}
{% macro construct_composite_output_ptr(composite_op_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set outputs_dict = composite_op_dict["output_dict"] %}
{% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %}
{% set output_typename = outputs_dict[outputs[i]]["typename"] %}
{% if output_typename == "Tensor" %}
paddle::experimental::Tensor* {{outputs[i]}} = this->GetOutputPtr(&{{outputs[i]+ "_t"}});
{% elif output_typename == "Tensor[]" %}
std::vector<paddle::experimental::Tensor*> {{outputs[i]}}({{outputs[i] + "_t"}}.size());
for(size_t i = 0; i < {{outputs[i]}}.size(); ++i){
{{outputs[i]}}[i] = &{{outputs[i] + "_t"}}[i];
}
{{outputs[i]}} = this->GetOutputPtr({{outputs[i]}});
{%- endif %}
{%- endfor %}
{%- endmacro %}
{% macro get_composite_output_orginal_name(composite_op_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set outputs_dict = composite_op_dict["output_dict"] %}
{% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %}
{% set output_typename = outputs_dict[outputs[i]]["typename"] %}
{% if output_typename == "Tensor" %}
std::string {{outputs[i] + "_name"}} = this->GetOutputName({{outputs[i] + "_t"}});
{% elif output_typename == "Tensor[]" %}
std::vector<std::string> {{outputs[i] + "_name"}} = this->GetOutputName({{outputs[i] + "_t"}});
{%- endif %}
{%- endfor %}
{%- endmacro %}
{% macro call_composite_backward_api(composite_op_dict) %}
prim::{{composite_op_dict["composite"]["func_name"]}}<prim::DescTensor>({{composite_op_dict["composite"]["func_args"]}});
{%- endmacro %}
{% macro recover_composite_output_name(composite_op_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %}
this->RecoverOutputName({{outputs[i] + "_t"}}, {{outputs[i] + "_name"}});
{% endfor %}
{%- endmacro %}
{% macro extract_input_from_forward(name,
input_names, output_names,
......
......@@ -46,6 +46,10 @@ def is_base_op(op):
return "kernel" in op and "infer_meta" in op
def is_composite_op(op):
return "composite" in op
def supports_selected_rows_kernel(op):
return is_base_op(op) and len(op["kernel"]["func"]) == 2
......
add_subdirectory(api)
add_subdirectory(utils)
add_subdirectory(tests)
set(static_prim_deps prim_utils static_global_utils static_utils
static_prim_api)
set(eager_prim_deps prim_utils eager_prim_api)
add_subdirectory(manual)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
prim_api
SRCS all.cc
DEPS static_utils static_prim_api eager_prim_api eager_api)
else()
cc_library(
prim_api
SRCS all.cc
DEPS static_utils static_prim_api)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/prim/api/all.h"
namespace paddle {
namespace prim {} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
add_subdirectory(prim_api)
add_subdirectory(utils)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
namespace paddle {
namespace prim {
// This function should have the same signature as phi, which is defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp);
grad_x->set_impl(grad_x_tmp.impl());
}
} // namespace prim
} // namespace paddle
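
For reference, the composite rule above is just the tanh derivative expressed with the primitive ops pow, scale and multiply; in the notation of the function arguments:

\frac{d\,\tanh(x)}{dx} = 1 - \tanh^2(x)
\qquad\Longrightarrow\qquad
\text{grad\_x} = \text{grad\_out} \cdot \bigl(1 - \text{out}^2\bigr)

so pow<T>(out, 2.0) builds out^2, scale<T>(tmp, -1.0, 1.0, true) turns it into 1 - out^2, and the final multiply applies the chain rule with grad_out.
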
cc_library(
static_prim_api
SRCS static_prim_api.cc
DEPS proto_desc static_utils)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
eager_prim_api
SRCS eager_prim_api.cc
DEPS final_dygraph_function eager_utils)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
namespace paddle {
namespace prim {
template <>
Tensor pow<Tensor>(const Tensor& x, const paddle::experimental::Scalar& y) {
return ::pow_ad_func(x, y);
}
template <>
Tensor scale<Tensor>(const Tensor& x,
const paddle::experimental::Scalar& scale,
float bias,
bool bias_after_scale) {
return ::scale_ad_func(x, scale, bias, bias_after_scale);
}
template <>
Tensor multiply<Tensor>(const Tensor& x, const Tensor& y) {
return ::multiply_ad_func(x, y);
}
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"
namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
template <typename T>
Tensor pow(const Tensor& x, const paddle::experimental::Scalar& y);
template <typename T>
Tensor scale(const Tensor& X,
const paddle::experimental::Scalar& scale,
float bias,
bool bias_after_scale);
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y);
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace prim {
template <>
Tensor pow<DescTensor>(const Tensor& x, const paddle::experimental::Scalar& y) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("pow");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("factor", y.to<float>());
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor scale<DescTensor>(const Tensor& x,
const paddle::experimental::Scalar& scale,
float bias,
bool bias_after_scale) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("scale");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("scale", scale.to<float>());
op->SetAttr("bias", bias);
op->SetAttr("bias_after_scale", bias_after_scale);
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
// Grad infershape
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("elementwise_mul");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
} // namespace prim
} // namespace paddle
cc_library(
static_utils
SRCS static_utils.cc
DEPS proto_desc operator static_global_utils)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
eager_utils
SRCS eager_utils.cc
DEPS final_dygraph_function)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace prim {
template <>
Tensor empty<Tensor>(const paddle::experimental::IntArray& shape,
paddle::experimental::DataType dtype,
const paddle::Place& place) {
if (dtype == paddle::experimental::DataType::UNDEFINED) {
dtype = paddle::experimental::DataType::FLOAT32;
}
return empty_ad_func(shape, dtype, place);
}
template <>
Tensor empty_like<Tensor>(const paddle::experimental::Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place) {
if (dtype == paddle::experimental::DataType::UNDEFINED) {
dtype = paddle::experimental::DataType::FLOAT32;
}
return empty_like_ad_func(x, dtype, place);
}
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/utils/data_type.h"
namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
template <>
Tensor empty<DescTensor>(const paddle::experimental::IntArray& shape,
paddle::experimental::DataType dtype,
const paddle::Place& place) {
framework::VarDesc* new_var =
StaticCompositeContext::Instance().GetBlock()->Var(
std::move(StaticCompositeContext::Instance().GenerateUniqueName()));
new_var->SetShape(shape.GetData());
new_var->SetDataType(framework::TransToProtoVarType(dtype));
// Place is not supported in static mode
return Tensor(std::make_shared<prim::DescTensor>(new_var));
}
template <>
Tensor empty_like<DescTensor>(const Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place) {
return empty<prim::DescTensor>(
paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
}
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace prim {
// We put some api like utils here
template <typename T>
paddle::experimental::Tensor empty(const paddle::experimental::IntArray& shape,
paddle::experimental::DataType dtype,
const paddle::Place& place);
template <typename T>
paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place);
} // namespace prim
} // namespace paddle
set(prim_eager_deps
phi_api
phi_dygraph_api
hook_utils
tensor_utils
utils
global_utils
backward
phi_tensor
tracer
layer
autograd_meta
eager_nan_inf_utils
grad_node_info
grad_tensor_holder
custom_operator_node)
set(prim_generated_deps final_dygraph_function final_dygraph_node
dygraph_function dygraph_node)
cc_test_old(
test_static_prim
SRCS
test_static_prim.cc
DEPS
static_utils
static_prim_api
generated_op
prim_utils
operator
elementwise_mul_op
scale_op
activation_op
phi_api
phi_dygraph_api
static_global_utils)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_test_old(
test_eager_prim
SRCS
test_eager_prim.cc
DEPS
${prim_eager_deps}
${prim_generated_deps}
prim_utils
static_global_utils)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/backward.h"
#include "paddle/fluid/eager/tests/test_utils.h"
#include "paddle/fluid/prim/utils/utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
#endif
namespace paddle {
namespace prim {
TEST(EagerPrim, TanhBackwardTest) {
// 1. Initialized
eager_test::InitEnv(paddle::platform::CPUPlace());
// 2. pre
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
paddle::experimental::Tensor tensor0 =
::egr::egr_utils_api::CreateTensorWithValue(ddim,
paddle::platform::CPUPlace(),
phi::DataType::FLOAT32,
phi::DataLayout::NCHW,
5.0 /*value*/,
true /*is_leaf*/);
::egr::egr_utils_api::RetainGradForTensor(tensor0);
paddle::experimental::Tensor tensor1 =
::egr::egr_utils_api::CreateTensorWithValue(ddim,
paddle::platform::CPUPlace(),
phi::DataType::FLOAT32,
phi::DataLayout::NCHW,
5.0 /*value*/,
true /*is_leaf*/);
::egr::egr_utils_api::RetainGradForTensor(tensor1);
// 3. Run Forward once
paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
std::vector<paddle::experimental::Tensor> outs0 = {out0};
// Disable prim
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
// 4. Run Backward
egr::Backward(outs0, {}, false);
paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
// 4. Run Backward
::egr::Backward(outs1, {}, false);
VLOG(7)
<< "Target Grad is: "
<< std::static_pointer_cast<phi::DenseTensor>(
::egr::EagerUtils::unsafe_autograd_meta(tensor0)->Grad().impl())
->data<float>()[0];
VLOG(7)
<< "Result Grad is: "
<< std::static_pointer_cast<phi::DenseTensor>(
::egr::EagerUtils::unsafe_autograd_meta(tensor1)->Grad().impl())
->data<float>()[0];
// Examine Backward Grad
eager_test::CompareGradTensorWithValue<float>(
tensor1,
std::static_pointer_cast<phi::DenseTensor>(
::egr::EagerUtils::unsafe_autograd_meta(tensor0)->Grad().impl())
->data<float>()[0]);
}
TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
}
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
DECLARE_bool(prim_enabled);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
#endif
namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
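// Helper that assembles a small ProgramDesc by hand so the static composite
// grad makers can be exercised without building a full graph.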
struct TestBaseProgram {
public:
const framework::ProgramDesc& main_program() { return program_; }
std::string unique_name() { return "tmp_" + std::to_string(idx_++); }
framework::VarDesc* lod_tensor(std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
framework::proto::VarType::Type data_type =
framework::proto::VarType::FP32) {
auto* var = program_.MutableBlock(0)->Var(name);
var->SetType(framework::proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
framework::VarDesc* unary_op(std::string type,
framework::VarDesc* x,
framework::VarDesc* out = nullptr,
const framework::AttributeMap* attrs = nullptr) {
if (!out) {
out = lod_tensor(unique_name());
}
framework::OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
if (attrs) {
for (auto& iter : *attrs) {
op->SetAttr(iter.first, iter.second);
}
}
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
return out;
}
framework::VarDesc* tanh(framework::VarDesc* x,
framework::VarDesc* out = nullptr) {
return unary_op("tanh", x, out);
}
framework::BlockDesc* GetBlock(std::size_t id) {
return program_.MutableBlock(id);
}
void concat(std::vector<framework::VarDesc*> inputs,
int axis,
framework::VarDesc* out) {
framework::OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("concat");
std::vector<std::string> input_names(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
input_names[i] = inputs[i]->Name();
}
op->SetInput("X", input_names);
op->SetOutput("Out", {out->Name()});
op->SetAttr("axis", axis);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
}
void split(framework::VarDesc* input,
int num,
int axis,
std::vector<framework::VarDesc*> outputs) {
framework::OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("split");
const std::string input_name = input->Name();
std::vector<std::string> output_names(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
output_names[i] = outputs[i]->Name();
}
op->SetInput("X", {input_name});
op->SetOutput("Out", output_names);
op->SetAttr("num", num);
op->SetAttr("axis", axis);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
}
private:
framework::ProgramDesc program_;
int idx_{0};
};
class TestGradCompositeGradMaker : public GradCompositeOpMakerBase {
public:
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
void Apply() override {}
};
TEST(StaticPrim, TanhBackwardComposite) {
TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0);
// Prepare for forward tanh
std::vector<int64_t> shape = {2, 2};
StaticCompositeContext::Instance().SetBlock(target_block);
Tensor x = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
Tensor out = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* x_desc =
static_cast<prim::DescTensor*>(x.impl().get())->get_ptr();
target_block->RenameVar(x_desc->Name(), "a");
framework::VarDesc* out_desc =
static_cast<prim::DescTensor*>(out.impl().get())->get_ptr();
target_block->RenameVar(out_desc->Name(), "b");
  // TODO(jiabin): Grad out should be created by full; we can test that later
base_program.tanh(target_block->FindVar("a"), target_block->FindVar("b"));
ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X")[0], "a");
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out").size(),
std::size_t(1));
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
ASSERT_EQ(target_block->AllVars().size(), static_cast<std::size_t>(2));
ASSERT_EQ(target_block->AllVars()[0]->Name(), "a");
ASSERT_EQ(target_block->AllVars()[1]->Name(), "b");
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops =
std::move(framework::OpInfoMap::Instance()
.Get(forward_opdesc->Type())
.GradCompOpMaker()(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
target_block,
grad_sub_block));
ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(3));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X")[0], "a");
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
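  // With prim enabled, tanh's backward is decomposed as
  //   x_grad = out_grad * (1 - out * out)
  // which lowers to pow(out, 2) -> scale(scale=-1, bias=1) -> elementwise_mul,
  // as the assertions below verify.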
ASSERT_EQ(grad_ops[0]->Type(), "pow");
ASSERT_EQ(grad_ops[0]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[0]->Inputs().at("X")[0], "b");
ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[0]->GetAttr("factor")),
static_cast<float>(2.0));
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Type(), "scale");
ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0],
grad_ops[0]->Outputs().at("Out")[0]);
ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[1]->GetAttr("scale")),
static_cast<float>(-1.0));
ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[1]->GetAttr("bias")),
static_cast<float>(1.0));
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y")[0],
grad_ops[1]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
}
TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0);
std::vector<int64_t> shape = {2, 2};
std::vector<int64_t> shape_out = {4, 2};
StaticCompositeContext::Instance().SetBlock(target_block);
Tensor x0 = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
Tensor x1 = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
Tensor out = prim::empty<prim::DescTensor>(
shape_out, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* x0_desc =
static_cast<prim::DescTensor*>(x0.impl().get())->get_ptr();
target_block->RenameVar(x0_desc->Name(), "x0");
framework::VarDesc* x1_desc =
static_cast<prim::DescTensor*>(x1.impl().get())->get_ptr();
target_block->RenameVar(x1_desc->Name(), "x1");
framework::VarDesc* out_desc =
static_cast<prim::DescTensor*>(out.impl().get())->get_ptr();
target_block->RenameVar(out_desc->Name(), "out");
std::vector<framework::VarDesc*> inputs = {target_block->FindVar("x0"),
target_block->FindVar("x1")};
framework::VarDesc* output = target_block->FindVar("out");
base_program.concat(inputs, 0, output);
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
auto test = TestGradCompositeGradMaker(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
target_block,
grad_sub_block);
test();
std::vector<paddle::experimental::Tensor> muti_fw_input =
test.GetMultiForwardInput("X");
std::vector<paddle::optional<paddle::experimental::Tensor>>
opt_muti_fw_input = test.GetOptionalMultiForwardInput("X");
paddle::experimental::Tensor fw_out = test.GetSingleForwardOutput("Out");
paddle::experimental::Tensor* fw_out_ptr = test.GetOutputPtr(&fw_out);
std::string fw_out_name = test.GetOutputName(fw_out);
ASSERT_EQ(muti_fw_input.size(), static_cast<std::size_t>(2));
ASSERT_EQ(
static_cast<prim::DescTensor*>(muti_fw_input[0].impl().get())->Name(),
"x0");
ASSERT_EQ(
static_cast<prim::DescTensor*>(muti_fw_input[1].impl().get())->Name(),
"x1");
ASSERT_EQ(opt_muti_fw_input.size(), static_cast<std::size_t>(2));
ASSERT_EQ(static_cast<prim::DescTensor*>(
opt_muti_fw_input[0].get_ptr()->impl().get())
->Name(),
"x0");
ASSERT_EQ(static_cast<prim::DescTensor*>(
opt_muti_fw_input[1].get_ptr()->impl().get())
->Name(),
"x1");
ASSERT_EQ(&fw_out, fw_out_ptr);
ASSERT_EQ(fw_out_name, "out");
}
TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0);
std::vector<int64_t> shape = {4, 2};
std::vector<int64_t> shape_out = {2, 2};
StaticCompositeContext::Instance().SetBlock(target_block);
Tensor x = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
Tensor out1 = prim::empty<prim::DescTensor>(
shape_out, phi::DataType::FLOAT32, paddle::Place());
Tensor out2 = prim::empty<prim::DescTensor>(
shape_out, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* x_desc =
static_cast<prim::DescTensor*>(x.impl().get())->get_ptr();
target_block->RenameVar(x_desc->Name(), "x");
framework::VarDesc* out1_desc =
static_cast<prim::DescTensor*>(out1.impl().get())->get_ptr();
target_block->RenameVar(out1_desc->Name(), "out1");
framework::VarDesc* out2_desc =
static_cast<prim::DescTensor*>(out2.impl().get())->get_ptr();
target_block->RenameVar(out2_desc->Name(), "out2");
framework::VarDesc* input = target_block->FindVar("x");
std::vector<framework::VarDesc*> outputs = {target_block->FindVar("out1"),
target_block->FindVar("out2")};
base_program.split(input, 2, 0, outputs);
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
auto test = TestGradCompositeGradMaker(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
target_block,
grad_sub_block);
test();
paddle::experimental::Tensor fw_input = test.GetSingleForwardInput("X");
paddle::optional<paddle::experimental::Tensor> opt_fw_input =
test.GetOptionalSingleForwardInput("X");
std::vector<paddle::experimental::Tensor> fw_out =
test.GetMultiForwardOutput("Out");
std::vector<paddle::experimental::Tensor*> fw_out_ptr(fw_out.size());
for (size_t i = 0; i < fw_out.size(); ++i) {
fw_out_ptr[i] = &fw_out[i];
}
fw_out_ptr = test.GetOutputPtr(fw_out_ptr);
std::vector<std::string> fw_out_name = test.GetOutputName(fw_out);
ASSERT_EQ(static_cast<prim::DescTensor*>(fw_input.impl().get())->Name(), "x");
ASSERT_EQ(static_cast<prim::DescTensor*>(opt_fw_input.get_ptr()->impl().get())
->Name(),
"x");
ASSERT_EQ(fw_out.size(), static_cast<std::size_t>(2));
ASSERT_EQ(fw_out_ptr[0], &fw_out[0]);
ASSERT_EQ(fw_out_ptr[1], &fw_out[1]);
ASSERT_EQ(fw_out_name[0], "out1");
ASSERT_EQ(fw_out_name[1], "out2");
}
TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
}
} // namespace prim
} // namespace paddle
USE_OP_ITSELF(tanh);
USE_OP_ITSELF(tanh_grad);
USE_OP_ITSELF(pow);
USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(scale);
add_subdirectory(eager)
add_subdirectory(static)
cc_library(
prim_utils
SRCS utils.cc
DEPS static_global_utils)
cc_library(
static_global_utils
SRCS static_global_utils.cc
DEPS proto_desc)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/extended_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/utils/any.h"
namespace paddle {
namespace prim {
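// DescTensor adapts a static-graph VarDesc to the phi Tensor interface, so
// composite grad rules written against the Tensor API can also be replayed
// while a static program is being built.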
class DescTensor : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, DescTensor> {
public:
explicit DescTensor(framework::VarDesc* desc)
: desc_ptr_(desc), dims_(phi::make_ddim(desc->GetShape())) {}
static const char* name() { return "DescTensor"; }
std::string Name() const { return desc_ptr_->Name(); }
std::vector<int64_t> shape() const { return desc_ptr_->GetShape(); }
const phi::DDim& dims() const override {
dims_ = phi::make_ddim(desc_ptr_->GetShape());
return dims_;
}
DataType dtype() const override {
return paddle::framework::TransToPhiDataType(desc_ptr_->GetDataType());
}
framework::VarDesc* get_ptr() { return desc_ptr_; }
// TODO(jiabin): override more operators here.
private:
  // The VarDesc's lifetime is held by its block and program, so we only wrap
  // its functions rather than managing its lifetime.
framework::VarDesc* desc_ptr_;
  // TODO(jiabin): This is ugly, but we have to hold dims here so that we can
  // inherit from ExtendedTensor. Remove this once VarDesc behaves the same
  // way as Tensor, or once Tensor's dims become more lightweight.
mutable phi::DDim dims_;
};
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
namespace paddle {
namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext();
thread_local bool StaticCompositeContext::enable_prim_ = false;
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
namespace paddle {
namespace prim {
class UniqueNameGenerator {
public:
explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
std::string Generate(std::string key = "") {
return prefix_ + key + "_" + std::to_string(id_++);
}
private:
std::atomic<int> id_{0};
std::string prefix_;
};
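// Process-wide singleton that tracks the block currently being composed and
// the thread-local switch controlling whether composite (prim) grad rules are
// used.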
class StaticCompositeContext {
public:
static StaticCompositeContext& Instance() {
return *static_composite_context_;
}
framework::BlockDesc* GetBlock() { return current_block_desc_; }
void SetBlock(framework::BlockDesc* new_block) {
current_block_desc_ = new_block;
}
std::string GenerateUniqueName(std::string key = "composite_tmp") {
return generator_->Generate(key);
}
void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
bool IsPrimEnabled() { return enable_prim_; }
private:
StaticCompositeContext()
: current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {}
framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
static thread_local bool enable_prim_;
static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/core/flags.h"
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle {
namespace prim {
bool PrimCommonUtils::IsPrimEnabled() {
return StaticCompositeContext::Instance().IsPrimEnabled();
}
void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
}
} // namespace prim
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace prim {
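// Thin wrapper around StaticCompositeContext for toggling and querying the
// prim switch from both the eager and static code paths.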
class PrimCommonUtils {
public:
static bool IsPrimEnabled();
static void SetPrimEnabled(bool enabled);
};
} // namespace prim
} // namespace paddle
......@@ -41,7 +41,8 @@ set(PYBIND_DEPS
new_profiler
auto_parallel
jit_layer
jit_property)
jit_property
prim_utils)
if(WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
......
......@@ -66,6 +66,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/prim/utils/utils.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/memory/allocation/cuda_ipc_allocator.h"
#endif
......@@ -645,6 +646,8 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str();
});
m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
m.def("set_num_threads", &platform::SetNumThreads);
m.def("disable_signal_handler", &DisableSignalHandler);
......@@ -1221,11 +1224,56 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &no_grad_set,
const std::vector<BlockDesc *> &grad_sub_block) {
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(
op_desc, no_grad_set, &grad_to_var, grad_sub_block);
auto op_info = framework::OpInfoMap::Instance().Get(op_desc.Type());
auto grad_op_maker = op_info.GradOpMaker();
auto grad_comp_op_maker = op_info.GradCompOpMaker();
if ((grad_op_maker == nullptr) && (grad_comp_op_maker == nullptr)) {
            // Normally, proto_ should not be null, except for some special
            // operators, such as the LeakyReluDoubleGrad op.
std::string type =
op_info.proto_ ? op_info.proto_->type() : "unknown";
PADDLE_THROW(platform::errors::NotFound(
"Neither operator %s's GradOpMaker nor GradCompOpMaker has "
"been registered.\nPlease check whether (%s) operator has "
"gradient operator.\nIf not, please set stop_gradient to be "
"True for its input and output variables using "
"var.stop_gradient=True.",
type.c_str(),
type.c_str()));
}
          // When prim is enabled, GradCompOpMaker takes priority over
          // GradOpMaker, since we need to split the first-order grad operator
          // into primitive operators for the compiler. When prim is disabled,
          // GradOpMaker takes priority over GradCompOpMaker for better
          // performance.
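          // Summary of the selection performed below:
          //   prim on  + composite maker registered -> composite grad ops
          //   prim on  + only GradOpMaker           -> original grad op
          //   prim off + GradOpMaker registered     -> original grad op
          //   prim off + only composite maker       -> composite grad ops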
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
if (grad_comp_op_maker != nullptr) {
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
&grad_to_var,
op_desc.Block(),
grad_sub_block);
} else {
grad_op_descs = grad_op_maker(
op_desc, no_grad_set, &grad_to_var, grad_sub_block);
}
} else {
if (grad_op_maker != nullptr) {
grad_op_descs = grad_op_maker(
op_desc, no_grad_set, &grad_to_var, grad_sub_block);
} else {
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
&grad_to_var,
op_desc.Block(),
grad_sub_block);
}
}
std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
std::transform(
grad_op_descs.begin(),
......
......@@ -1322,6 +1322,7 @@
param : [out]
kernel :
func : tanh_grad
composite : tanh_grad(out, out_grad, x_grad)
backward : tanh_double_grad
inplace : (out_grad -> x_grad)
......
......@@ -57,7 +57,6 @@ class TensorBase {
/// \brief Test whether the storage is allocated.
/// \return Whether the storage is allocated.
virtual bool initialized() const = 0;
// TODO(Aurelius84): This interface is under intermediate state now.
// We will remove DataType argument in the future. Please DO NOT
// rely on Datatype too much when designing and implementing other features.
......
......@@ -305,6 +305,8 @@ try:
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name
from .libpaddle import set_prim_enabled
from .libpaddle import is_prim_enabled
if sys.platform != 'win32':
from .libpaddle import _set_process_pids
......
......@@ -837,6 +837,7 @@ add_subdirectory(dygraph_to_static)
add_subdirectory(rnn)
add_subdirectory(autograd)
add_subdirectory(distribution)
add_subdirectory(prim)
if(NOT WIN32 OR NOT WITH_GPU)
add_subdirectory(fft)
......
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
add_subdirectory(api)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
add_subdirectory(comp)
add_subdirectory(prim)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
add_subdirectory(eager_test)
add_subdirectory(static_test)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_eager_tanh_grad_comp PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import autograd
import autograd.numpy
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
@param.parameterized_class(
('primal', 'cotangent', 'dtype'),
[
(np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
],
)
class TestTanhGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.dtype)
cls.cotangent = cls.cotangent.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_tanh_grad_comp(self):
def actual(primal, cotangent):
paddle.disable_static()
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
x.stop_gradient = False
v = paddle.to_tensor(
cotangent, dtype='float32', stop_gradient=False
)
y = paddle.tanh(x)
x_cotangent = paddle.grad(
y, x, v, create_graph=True, retain_graph=True
)
return x_cotangent[0]
def desired(primal, cotangent):
return autograd.make_vjp(autograd.numpy.tanh)(primal)[0](cotangent)
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent),
desired=desired(self.primal, self.cotangent),
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_tanh_grad_comp PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.fluid import core
core.set_prim_enabled(True)
import autograd
import autograd.numpy
import numpy as np
import parameterized as param
import paddle
@param.parameterized_class(
('primal', 'cotangent', 'dtype'),
[
(np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
],
)
class TestTanhGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.dtype)
cls.cotangent = cls.cotangent.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_tanh_grad_comp(self):
def actual(primal, cotangent):
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
x.stop_gradient = False
v = paddle.static.data(
'cotangent', cotangent.shape, cotangent.dtype
)
y = paddle.tanh(x)
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
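            # The program's last op is the final primitive of the composite
            # tanh grad, so fetch its output as the computed x gradient.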
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=mp.blocks[0].ops[-1].output('Out')[0],
)[0]
def desired(primal, cotangent):
return autograd.make_vjp(autograd.numpy.tanh)(primal)[0](cotangent)
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent),
desired=desired(self.primal, self.cotangent),
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.fluid import core
core.set_prim_enabled(False)
import parameterized as param
import paddle
from paddle.fluid import core, framework
@param.parameterized_class(
(
'fwd_type',
'inputs',
'outputs',
'no_grad_var',
'grad_sub_block',
'desired_ops',
),
(
('tanh', {'X': ['x']}, {'Out': ['y']}, set(), tuple(), ('tanh_grad',)),
('empty', {}, {'Out': ['y']}, set(), tuple(), tuple()),
),
)
class TestGetGradOpDescPrimEnabled(unittest.TestCase):
@classmethod
def setUpClass(cls):
paddle.enable_static()
block = framework.Block(framework.Program(), 0)
block.append_op(
type=cls.fwd_type,
inputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in cls.inputs.items()
},
outputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in cls.outputs.items()
},
)
cls.fwd = block.ops[0].desc
@classmethod
def tearDownClass(cls):
paddle.disable_static()
def test_get_grad_op_desc(self):
actual = tuple(
desc.type()
for desc in core.get_grad_op_desc(
self.fwd, self.no_grad_var, self.grad_sub_block
)[0]
)
        self.assertEqual(actual, self.desired_ops)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.fluid import core
core.set_prim_enabled(True)
import parameterized as param
import paddle
from paddle.fluid import core, framework
@param.parameterized_class(
(
'fwd_type',
'inputs',
'outputs',
'no_grad_var',
'grad_sub_block',
'desired_ops',
),
(
(
'tanh',
{'X': ['x']},
{'Out': ['y']},
set(),
tuple(),
('pow', 'scale', 'elementwise_mul'),
),
('empty', {}, {'Out': ['y']}, set(), tuple(), tuple()),
),
)
class TestGetGradOpDescPrimEnabled(unittest.TestCase):
@classmethod
def setUpClass(cls):
paddle.enable_static()
block = framework.Block(framework.Program(), 0)
block.append_op(
type=cls.fwd_type,
inputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in cls.inputs.items()
},
outputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in cls.outputs.items()
},
)
cls.fwd = block.ops[0].desc
@classmethod
def tearDownClass(cls):
paddle.disable_static()
def test_get_grad_op_desc(self):
actual = tuple(
desc.type()
for desc in core.get_grad_op_desc(
self.fwd, self.no_grad_var, self.grad_sub_block
)[0]
)
        self.assertEqual(actual, self.desired_ops)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()