未验证 提交 ab198b45 编写于 作者: R Ruibiao Chen 提交者: GitHub

Set more attrs in ReplaceScaleLossGradOp (#44576)

* Set more attrs in ReplaceScaleLossGradOp

* Fix typos

* Fix CI errors

* Add UT
上级 6198ff24
...@@ -44,6 +44,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ...@@ -44,6 +44,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
~ScaleLossGradOpHandle() final; ~ScaleLossGradOpHandle() final;
proto::VarType::Type DType() const { return out_dtype_; }
std::string Name() const override; std::string Name() const override;
platform::Place GetPlace() const { return place_; } platform::Place GetPlace() const { return place_; }
......
...@@ -67,7 +67,7 @@ cc_library( ...@@ -67,7 +67,7 @@ cc_library(
cc_library( cc_library(
graph_helper graph_helper
SRCS graph_helper.cc SRCS graph_helper.cc
DEPS graph) DEPS graph scale_loss_grad_op_handle)
cc_library( cc_library(
pass pass
SRCS pass.cc SRCS pass.cc
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <stack> #include <stack>
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool(convert_all_blocks); DECLARE_bool(convert_all_blocks);
...@@ -469,11 +470,23 @@ void RemoveControlDepInputAndOuput(OpDesc *op_desc) { ...@@ -469,11 +470,23 @@ void RemoveControlDepInputAndOuput(OpDesc *op_desc) {
static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) { static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
desc->SetType("fill_constant"); desc->SetType("fill_constant");
desc->SetAttr("shape", std::vector<int64_t>({1}));
desc->SetAttr("value", 1.0f);
if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
desc->SetAttr(
"dtype",
dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->DType());
}
desc->SetAttr("force_cpu", false);
desc->SetAttr( desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(), OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss))); (static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f); // TODO(Ruibiao) : Set OpDeviceAttrName when needed
desc->SetAttr("shape", std::vector<int64_t>({1}));
std::vector<std::string> output_names; std::vector<std::string> output_names;
for (auto out : node.outputs) { for (auto out : node.outputs) {
output_names.emplace_back(out->Name()); output_names.emplace_back(out->Name());
...@@ -503,6 +516,7 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes, ...@@ -503,6 +516,7 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
// create fill_constant op // create fill_constant op
if (n->Name() == "scale_loss_grad") { if (n->Name() == "scale_loss_grad") {
VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
ops->emplace_back(); ops->emplace_back();
auto &desc = ops->back(); auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc); ReplaceScaleLossGradOp(*n, &desc);
......
...@@ -7,6 +7,8 @@ if(WITH_GPU OR APPLE) ...@@ -7,6 +7,8 @@ if(WITH_GPU OR APPLE)
# Compiling shared library will cost some time, but running process is very fast. # Compiling shared library will cost some time, but running process is very fast.
set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250) set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_custom_relu_op_setup
PROPERTIES ENVIRONMENT FLAGS_CONVERT_GRAPH_TO_PROGRAM=1)
set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180) set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180) set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180) set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册