未验证 提交 ab198b45 编写于 作者: R Ruibiao Chen 提交者: GitHub

Set more attrs in ReplaceScaleLossGradOp (#44576)

* Set more attrs in ReplaceScaleLossGradOp

* Fix typos

* Fix CI errors

* Add UT
上级 6198ff24
......@@ -44,6 +44,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
~ScaleLossGradOpHandle() final;
proto::VarType::Type DType() const { return out_dtype_; }
std::string Name() const override;
platform::Place GetPlace() const { return place_; }
......
......@@ -67,7 +67,7 @@ cc_library(
cc_library(
graph_helper
SRCS graph_helper.cc
DEPS graph)
DEPS graph scale_loss_grad_op_handle)
cc_library(
pass
SRCS pass.cc
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <stack>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool(convert_all_blocks);
......@@ -469,11 +470,23 @@ void RemoveControlDepInputAndOuput(OpDesc *op_desc) {
static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
desc->SetType("fill_constant");
desc->SetAttr("shape", std::vector<int64_t>({1}));
desc->SetAttr("value", 1.0f);
if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
desc->SetAttr(
"dtype",
dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->DType());
}
desc->SetAttr("force_cpu", false);
desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f);
desc->SetAttr("shape", std::vector<int64_t>({1}));
// TODO(Ruibiao) : Set OpDeviceAttrName when needed
std::vector<std::string> output_names;
for (auto out : node.outputs) {
output_names.emplace_back(out->Name());
......@@ -503,6 +516,7 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
// create fill_constant op
if (n->Name() == "scale_loss_grad") {
VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
......
......@@ -7,6 +7,8 @@ if(WITH_GPU OR APPLE)
# Compiling shared library will cost some time, but running process is very fast.
set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_custom_relu_op_setup
PROPERTIES ENVIRONMENT FLAGS_CONVERT_GRAPH_TO_PROGRAM=1)
set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册