提交 26ab4538 编写于 作者: F fengjiayi

enum ==> enum class

上级 5e378724
......@@ -358,7 +358,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
3UL /* external input number */
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
//- 1UL /*ignoreGradient varable number*/
+ 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add */
......
......@@ -8,9 +8,9 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_proto.pb.h"
......@@ -23,10 +23,10 @@ class OpRegistry;
using VarIndexMap = std::unordered_map<std::string, int>;
enum OpArgType { IN, OUT };
enum class OpArgType { IN, OUT };
static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
std::string key = type == IN ? "input_format" : "output_name";
std::string key = type == OpArgType::IN ? "input_format" : "output_name";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
......@@ -34,7 +34,7 @@ static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
static const std::vector<int>* GetOpFormat(const OperatorBase* op,
const OpArgType& type) {
std::string key = type == IN ? "input_format" : "output_name";
std::string key = type == OpArgType::IN ? "input_format" : "output_name";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
......@@ -44,14 +44,15 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type,
int& idx, bool is_grad) {
const std::vector<std::string>& src_inout =
src_type == IN ? src_op->inputs_ : src_op->outputs_;
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
std::vector<std::string>& dst_inout =
dst_type == IN ? dst_op->inputs_ : dst_op->outputs_;
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs();
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
std::string src_name = arg.name();
......@@ -83,19 +84,20 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
if (GetOpFormat(op, OUT) != nullptr) {
if (GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["output_format"] = std::vector<int>({0});
}
if (GetOpFormat(op, IN) != nullptr || GetOpFormat(op, OUT) != nullptr) {
if (GetOpFormat(op, OpArgType::IN) != nullptr ||
GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["input_format"] = std::vector<int>({0});
}
grad_op->in_out_idxs_.reset(new VarIndexMap());
int in_idx = 0;
int out_idx = 0;
TransOpArg(op, grad_op, IN, IN, in_idx, false); // I
TransOpArg(op, grad_op, OUT, IN, in_idx, false); // G
TransOpArg(op, grad_op, OUT, IN, in_idx, true); // OG
TransOpArg(op, grad_op, IN, OUT, out_idx, true); // IG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false); // G
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true); // OG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
return grad_op;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册