提交 26ab4538 编写于 作者: F fengjiayi

enum ==> enum class

上级 5e378724
...@@ -358,7 +358,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -358,7 +358,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
3UL /* external input number */ 3UL /* external input number */
+ 1UL /* external output number*/ + 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/ + 1UL /* number of gradient of external output*/
//- 1UL /*ignoreGradient variable number*/
+ 2U /* internal variable number*/); + 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/ EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add */ + 2UL /* input number of rowwise_add */
......
...@@ -8,9 +8,9 @@ You may obtain a copy of the License at ...@@ -8,9 +8,9 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
See the License for the specific language governing permissions and express or implied. See the License for the specific language governing
limitations under the License. */ permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
...@@ -23,10 +23,10 @@ class OpRegistry; ...@@ -23,10 +23,10 @@ class OpRegistry;
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
enum OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) { static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
std::string key = type == IN ? "input_format" : "output_name"; std::string key = type == OpArgType::IN ? "input_format" : "output_name";
return op->attrs_.count(key) return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key)) ? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr; : nullptr;
...@@ -34,7 +34,7 @@ static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) { ...@@ -34,7 +34,7 @@ static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
static const std::vector<int>* GetOpFormat(const OperatorBase* op, static const std::vector<int>* GetOpFormat(const OperatorBase* op,
const OpArgType& type) { const OpArgType& type) {
std::string key = type == IN ? "input_format" : "output_name"; std::string key = type == OpArgType::IN ? "input_format" : "output_name";
return op->attrs_.count(key) return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key)) ? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr; : nullptr;
...@@ -44,14 +44,15 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, ...@@ -44,14 +44,15 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type, const OpArgType& src_type, const OpArgType& dst_type,
int& idx, bool is_grad) { int& idx, bool is_grad) {
const std::vector<std::string>& src_inout = const std::vector<std::string>& src_inout =
src_type == IN ? src_op->inputs_ : src_op->outputs_; src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = GetOpFormat(src_op, src_type); const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
std::vector<std::string>& dst_inout = std::vector<std::string>& dst_inout =
dst_type == IN ? dst_op->inputs_ : dst_op->outputs_; dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type); std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_); const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs(); const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
std::string src_name = arg.name(); std::string src_name = arg.name();
...@@ -83,19 +84,20 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -83,19 +84,20 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
grad_op->attrs_ = op->attrs_; grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format"); grad_op->attrs_.erase("output_format");
if (GetOpFormat(op, OUT) != nullptr) { if (GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["output_format"] = std::vector<int>({0}); grad_op->attrs_["output_format"] = std::vector<int>({0});
} }
if (GetOpFormat(op, IN) != nullptr || GetOpFormat(op, OUT) != nullptr) { if (GetOpFormat(op, OpArgType::IN) != nullptr ||
GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["input_format"] = std::vector<int>({0}); grad_op->attrs_["input_format"] = std::vector<int>({0});
} }
grad_op->in_out_idxs_.reset(new VarIndexMap()); grad_op->in_out_idxs_.reset(new VarIndexMap());
int in_idx = 0; int in_idx = 0;
int out_idx = 0; int out_idx = 0;
TransOpArg(op, grad_op, IN, IN, in_idx, false); // I TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false); // I
TransOpArg(op, grad_op, OUT, IN, in_idx, false); // G TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false); // G
TransOpArg(op, grad_op, OUT, IN, in_idx, true); // OG TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true); // OG
TransOpArg(op, grad_op, IN, OUT, out_idx, true); // IG TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
return grad_op; return grad_op;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册