Unverified commit d422a1ed, authored by Zhanlue Yang, committed by GitHub

Handled special sum_grad_op code gen in Eager Dygraph (#38573)

* Handled special sum_grad_op code gen in Eager Dygraph

* Fixed merge issues
Parent 89c0877e
@@ -27,6 +27,8 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#define NUM_CREATED_DUP_INPUTS 4
namespace paddle {
namespace framework {
@@ -46,6 +48,62 @@ static std::string LegalizeVariableName(const std::string& var_name) {
return ret;
}
static bool IgnoreGradAttribute(const std::string& op_type,
const std::string& attr_name) {
// Attributes in operators_with_attrs are created manually during code
// generation
// We should ignore these arbitrary attrs when setting up grad attribute map
if (operators_with_attrs.count(op_type)) {
if (operators_with_attrs[op_type].count(attr_name)) {
return true;
}
}
return false;
}
static void PrepareAttrMapForOps() {
// Handle "run_program_op"
static framework::ProgramDesc fake_prog;
operators_with_attrs["run_program"] = {};
operators_with_attrs["run_program"]["global_block"] =
fake_prog.MutableBlock(0);
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
// Handle "c_split"
operators_with_attrs["c_split"] = {};
operators_with_attrs["c_split"]["nranks"] = 1;
}
/* --- Helper Objects --- */
class ForwardGenerationInfo {
public:
@@ -136,6 +194,13 @@ class GradNodeGenerationInfo {
return &grad_outs_;
}
const paddle::framework::AttributeMap& GetGradAttrs() const {
return grad_attrs_;
}
paddle::framework::AttributeMap* GetMutableGradAttrs() {
return &grad_attrs_;
}
private:
std::string op_base_type_;
std::map<std::string, std::string> grad_outs_slotname_map_;
@@ -147,6 +212,7 @@ class GradNodeGenerationInfo {
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs_;
paddle::framework::AttributeMap grad_attrs_;
};
public:
@@ -677,6 +743,25 @@ static bool CollectGradInformationFromOpInfo(
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VarBase>>>
ins;
if (op_proto.inputs().size() == 1 && op_proto.outputs().size() == 1 &&
op_proto.inputs()[0].duplicable() &&
!op_proto.outputs()[0].duplicable()) {
VLOG(6) << "Handle op with special op_bases: " << op_type;
// @special case (sum_op): for ops with a single duplicable input and a single
// non-duplicable output, feed in NUM_CREATED_DUP_INPUTS inputs to detect
// this special scenario.
const std::string& in_name = op_proto.inputs()[0].name();
ins[in_name] = {};
for (size_t i = 0; i < NUM_CREATED_DUP_INPUTS; i++) {
ins[in_name].emplace_back(std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name + "_" +
std::to_string(i))));
ins[in_name][i]->SetOverridedStopGradient(false);
ins[in_name][i]->MutableVar()->GetMutable<framework::LoDTensor>();
}
} else {
for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& in_name = input.name();
@@ -694,11 +779,13 @@ static bool CollectGradInformationFromOpInfo(
// We don't know the exact number of inputs expected,
// but we only need to identify the slot name order,
// so filling in a single input VarBase is enough in this scenario
ins[in_name] = {std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name))};
ins[in_name][0]->SetOverridedStopGradient(false);
ins[in_name][0]->MutableVar()->GetMutable<framework::LoDTensor>();
}
}
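// Note (grounded in the special-case handling added further below): downstream,
// GenerateGradNodeCCContents() checks whether an op with a single duplicable
// input and a single non-duplicable output ends up with exactly
// NUM_CREATED_DUP_INPUTS OpBases. If so, it is treated as the sum_grad-style
// special case: only the first OpBase is kept and is replayed once per input
// member at runtime.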
VLOG(6) << "Prepared Forward Ins Map, size = " << ins.size();
/* ------ Prepare "outs" ------ */
@@ -725,7 +812,6 @@ static bool CollectGradInformationFromOpInfo(
VLOG(6) << "Checking AttributeMap Settings";
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
default_attrs = attr_checker->GetDefaultAttrMap();
VLOG(6) << "AttributeMap Checking Passed";
} else {
VLOG(6) << "Detected Null Attribute Checker, use empty default_attrs";
}
@@ -797,13 +883,13 @@ static bool CollectGradInformationFromOpInfo(
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}
/* ------ Get Grad ins/outs ---- */
// In case of multiple OpBase, stitch all the respective ins/outs into one
/* ------ Get Grad ins/outs/attrs ---- */
VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter);
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
auto* op_base_grad_attrs = (*op_base_infos)[index].GetMutableGradAttrs();
const paddle::imperative::OpBase& op_base = *iter;
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
@@ -811,6 +897,8 @@ static bool CollectGradInformationFromOpInfo(
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
g_outs = op_base.GetOutsMap();
*op_base_grad_attrs = op_base.Attrs();
for (const auto& it : g_ins) {
if (!op_base_grad_ins->count(it.first))
(*op_base_grad_ins)[it.first] = {};
@@ -1395,84 +1483,29 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
return {fwd_function_str, dygraph_function_declaration_str};
}
/* ---------------------------------------------- */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static std::string GenerateGradNodeCCContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info --- */
const std::string& fwd_op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
VLOG(6) << "Generating Grad Node CC";
/* [Outline]
vector<vector<Tensor>> GradNodeXXX::operator()(vector<vector<Tensor>>& grads)
{
const std::shared_ptr<Tracer>& tracer = imperative::GetCurrentTracer();
// Comes from "grad_ins"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> ins =
{
"X" : this->"X", "Y" : this->"Y",
"Out0@Grad":
SyncToVars(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out0@Grad"]]"]),
"Out1@Grad":
TensorsToVarBases(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out1@Grad"]]"])
};
// Comes from "grad_outs"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> outs =
{
"X@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["X@Grad"]]"].Size()),
"Y@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["Y@Grad"]]"].Size())
};
// Visit each OpBase
for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) {
// Simply pass entire attribute map to kernels
egr::legacy::RunOp("iter->Type()", ins, outs, this->attr_map_,
egr::Controller::Instance().ExpectedPlace(), false, {});
}
vector<vector<egr::EagerTensor>> outputs(outs.size());
for(auto& kv : outs) {
outputs["fwd_inputs_name_pos_map[grad_outs_slotname_map[kv.first]]"] =
GetOutputs(outs["kv.first"]);
}
return outputs;
}
*/
static std::string GenerateSingleOpBase(
const std::string& fwd_op_type, const std::string& op_base_type,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::vector<proto::OpProto::Var>& in_vars,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const paddle::framework::AttributeMap& grad_attrs,
bool is_op_base_per_duplicable_input, size_t* outs_size) {
std::string generated_grad_function_body = "";
size_t outs_size = 0;
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
for (size_t i = 0; i < op_base_infos.size(); i++) {
const auto& op_base_info = op_base_infos[i];
const auto& grad_ins_fwd_slotname_map =
op_base_info.GetGradInsFwdSlotnameMap();
const auto& grad_ins_grad_slotname_map =
op_base_info.GetGradInsGradSlotnameMap();
const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap();
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const std::string& op_base_type = op_base_info.GetOpBaseType();
const std::string& ins_name = "ins" + std::to_string(i);
const std::string& outs_name = "outs" + std::to_string(i);
outs_size += grad_outs.size();
const std::string& ins_name = "ins" + std::to_string(*outs_size);
const std::string& outs_name = "outs" + std::to_string(*outs_size);
const std::string& attrs_name = "attrs_map" + std::to_string(*outs_size);
// [Generation] Get Ins Map
std::string ins_contents_str = "";
@@ -1499,9 +1532,8 @@ static std::string GenerateGradNodeCCContents(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_GRAD_CONTENT_TEMPLATE,
grad_input_name, fwd_output_position);
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
@@ -1517,8 +1549,8 @@ static std::string GenerateGradNodeCCContents(
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string ins_map_str = paddle::string::Sprintf(
BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
std::string ins_map_str =
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
@@ -1535,8 +1567,7 @@ static std::string GenerateGradNodeCCContents(
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name =
grad_outs_slotname_map.at(grad_output_name);
const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
@@ -1569,8 +1600,7 @@ static std::string GenerateGradNodeCCContents(
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(
fwd_outputs_name_pos_map.count(fwd_name),
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
@@ -1584,13 +1614,13 @@ static std::string GenerateGradNodeCCContents(
} else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name)) {
if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grad_output_name, fwd_input_position);
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
@@ -1623,15 +1653,32 @@ static std::string GenerateGradNodeCCContents(
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map
const char* ATTRS_TEMPLATE = " auto %s = this->attr_map_;\n";
std::string grad_attrs_str =
paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
for (const auto& iter : grad_attrs) {
if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue;
std::pair<std::string, std::string> type_val =
GetAttrType(iter.second, false /*is_arg*/);
const char* GRAD_ATTRS_TEMPLATE =
" %s %s = %s;\n"
" %s[\"%s\"] = %s;\n";
std::string var_name = iter.first + std::to_string(*outs_size);
grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second,
attrs_name, iter.first, var_name);
}
generated_grad_function_body += grad_attrs_str;
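// For reference, the snippet produced by the two templates above looks roughly
// like the following (the attr name "axis" and its value are illustrative only):
//   auto attrs_map0 = this->attr_map_;
//   int axis0 = 0;
//   attrs_map0["axis"] = axis0;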
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::legacy::RunOp(\"%s\", %s, %s, this->attr_map_,\n"
" egr::legacy::RunOp(\"%s\", %s, %s, %s,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
std::string trace_opbase_str = paddle::string::Sprintf(
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name);
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name, attrs_name);
generated_grad_function_body += trace_opbase_str;
@@ -1646,10 +1693,19 @@ static std::string GenerateGradNodeCCContents(
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (!is_op_base_per_duplicable_input) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
} else {
const char* BWD_OUTPUT_TEMPLATE =
" "
"outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]"
");\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name,
grad_out_name);
}
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
@@ -1668,15 +1724,127 @@ static std::string GenerateGradNodeCCContents(
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str +=
paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, num_appended_outputs,
outs_name, grad_out_name);
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, num_appended_outputs, outs_name, grad_out_name);
num_appended_outputs++;
}
}
generated_grad_function_body += outputs_str;
generated_grad_function_body += "\n";
*outs_size += grad_outs.size();
return generated_grad_function_body;
}
/* ---------------------------------------------- */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static std::string GenerateGradNodeCCContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info --- */
const std::string& fwd_op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
VLOG(6) << "Generating Grad Node CC";
/* [Outline]
vector<vector<Tensor>> GradNodeXXX::operator()(vector<vector<Tensor>>& grads)
{
const std::shared_ptr<Tracer>& tracer = imperative::GetCurrentTracer();
// Comes from "grad_ins"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> ins =
{
"X" : this->"X", "Y" : this->"Y",
"Out0@Grad":
SyncToVars(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out0@Grad"]]"]),
"Out1@Grad":
TensorsToVarBases(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out1@Grad"]]"])
};
// Comes from "grad_outs"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> outs =
{
"X@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["X@Grad"]]"].Size()),
"Y@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["Y@Grad"]]"].Size())
};
// Visit each OpBase
for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) {
// Simply pass entire attribute map to kernels
egr::legacy::RunOp("iter->Type()", ins, outs, this->attr_map_,
egr::Controller::Instance().ExpectedPlace(), false, {});
}
vector<vector<egr::EagerTensor>> outputs(outs.size());
for(auto& kv : outs) {
outputs["fwd_inputs_name_pos_map[grad_outs_slotname_map[kv.first]]"] =
GetOutputs(outs["kv.first"]);
}
return outputs;
}
*/
// This is a Copy
auto op_base_infos = bwd_info.GetOpBaseInfos();
/* Special Case: ops such as sum_grad_op are implemented abnormally:
they unpack the duplicable GradX and create one OpBase
corresponding to each member GradX[i]
*/
bool is_op_base_per_duplicable_input = false;
if (in_vars.size() == 1 && out_vars.size() == 1 && in_vars[0].duplicable() &&
!out_vars[0].duplicable() &&
op_base_infos.size() == NUM_CREATED_DUP_INPUTS) {
is_op_base_per_duplicable_input = true;
// Only keep the first op_base
auto op_base_info = op_base_infos[0];
op_base_infos.clear();
op_base_infos.emplace_back(std::move(op_base_info));
}
std::string generated_grad_function_body = "";
size_t outs_size = 0;
for (size_t i = 0; i < op_base_infos.size(); i++) {
const auto& op_base_info = op_base_infos[i];
const auto& grad_ins_fwd_slotname_map =
op_base_info.GetGradInsFwdSlotnameMap();
const auto& grad_ins_grad_slotname_map =
op_base_info.GetGradInsGradSlotnameMap();
const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap();
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const auto& grad_attrs = op_base_info.GetGradAttrs();
const std::string& op_base_type = op_base_info.GetOpBaseType();
generated_grad_function_body += GenerateSingleOpBase(
fwd_op_type, op_base_type, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, in_vars, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs,
grad_attrs, is_op_base_per_duplicable_input, &outs_size);
}
if (is_op_base_per_duplicable_input) {
const char* OP_BASE_PER_DUP_INPUT_TEMPLATE =
" for(int i = 0; i < this->OutputMeta()[0].Size(); i++) {\n"
" %s\n"
" }\n";
generated_grad_function_body = paddle::string::Sprintf(
OP_BASE_PER_DUP_INPUT_TEMPLATE, generated_grad_function_body);
}
const char* BWD_RETURN_TEMPLATE =
@@ -2045,47 +2213,6 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
}
static void PrepareAttrMapForOps() {
// Handle "run_program_op"
static framework::ProgramDesc fake_prog;
operators_with_attrs["run_program"] = {};
operators_with_attrs["run_program"]["global_block"] =
fake_prog.MutableBlock(0);
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
// Handle "c_split"
operators_with_attrs["c_split"] = {};
operators_with_attrs["c_split"]["nranks"] = 1;
}
} // namespace framework
} // namespace paddle
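To make the sum_grad-style special case concrete, below is a rough, hand-written sketch of the kind of grad-node body this path now generates. It assumes a hypothetical op "foo" with a single duplicable input X and a single non-duplicable output Out, whose grad maker creates one "foo_grad" OpBase per X[i]. The class name GradNodefoo, the op type string "foo_grad", the slot names "Out@GRAD"/"X@GRAD", and the simplified outs construction are placeholders assembled from the code templates above; this is not literal generator output.
// Illustrative sketch only (names and details are assumptions, see note above).
std::vector<std::vector<egr::EagerTensor>> GradNodefoo::operator()(
    std::vector<std::vector<egr::EagerTensor>>& grads) {
  std::vector<std::vector<egr::EagerTensor>> outputs(1);
  // The single kept "foo_grad" OpBase is replayed once per forward input slot.
  for (int i = 0; i < this->OutputMeta()[0].Size(); i++) {
    // Comes from "grad_ins": the incoming gradient of Out
    std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins0 =
        {{"Out@GRAD", egr::EagerUtils::SyncToVars(grads[0])}};
    // Comes from "grad_outs": simplified here to a single fresh output slot
    std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs0 =
        {{"X@GRAD", egr::EagerUtils::ConstructDuplicableOutput(1)}};
    // Recorded grad attrs (minus the manually injected ones) are assigned here
    auto attrs_map0 = this->attr_map_;
    egr::legacy::RunOp("foo_grad", ins0, outs0, attrs_map0,
                       egr::Controller::Instance().GetExpectedPlace(),
                       &this->default_attr_map_, false, {});
    // Append each per-input gradient to the single duplicable output slot
    outputs[0].emplace_back(egr::EagerUtils::GetOutputs(outs0["X@GRAD"])[0]);
  }
  return outputs;
}
The design choice reflected here is to keep a single recorded OpBase and replay it once per forward input at runtime, rather than emitting NUM_CREATED_DUP_INPUTS near-identical blocks, since the real number of duplicable inputs is only known when the grad node actually runs.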
@@ -321,7 +321,7 @@ class TestImperative(unittest.TestCase):
with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled())
def test_sum_op(self):
def func_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs = []
@@ -338,7 +338,7 @@ class TestImperative(unittest.TestCase):
tmp = paddle.to_tensor(x)
tmp.stop_gradient = False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
ret2 = paddle.add_n(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
loss2.backward()
@@ -349,6 +349,11 @@ class TestImperative(unittest.TestCase):
a = inputs2[0].gradient()
self.assertTrue(np.allclose(inputs2[0].gradient(), x))
def test_sum_op(self):
with _test_eager_guard():
self.func_sum_op()
self.func_sum_op()
def func_empty_var(self):
with fluid.dygraph.guard():
cur_program = fluid.Program()