Unverified commit deee91d8, authored by kangguangli and committed by GitHub

[NewIR] register set_value in new ir (#56436)

* register set_value in new ir

* fix

* register set_value_grad

* fix

* fix

* remove debug info

* add unittest

* fix

* fix

* fix

* fix

* fix

* resolve comments
Parent 38b8e8a5
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
......@@ -49,6 +50,10 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::StrAttribute::type_id());
}
static ir::Attribute get(ir::IrContext *ctx, phi::Scalar scalar) {
return TransToIrAttribute(scalar, ctx);
}
phi::Scalar data();
};
......
......@@ -329,7 +329,6 @@
view: null
backward: null
- name: shadow_feed
inputs:
- typename: Tensor
......@@ -355,3 +354,72 @@
force_backend: null
inplace: null
backward: null
- name : set_value
inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
- {typename: 'int64_t[]', name: shape}
- {typename: 'Scalar[]', name: values}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
infer_meta:
func: SetValueInferMeta
param: [x]
kernel:
func: [set_value]
param: [x, starts, ends, steps, axes, decrease_axes, none_axes, shape, values]
inplace: {out: x}
backward: set_value_grad
- name : set_value_with_tensor
inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
- {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
infer_meta:
func: SetValueInferMeta
param: [x]
kernel:
func: [set_value_with_tensor]
param: [x, values, starts, ends, steps, axes, decrease_axes, none_axes]
inplace: {out: x}
backward: set_value_grad
- name : set_value_grad
inputs:
- {typename: Tensor, name: out_grad, optional: false, no_need_buffer: false, data_transform: {} }
- {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
outputs:
- {typename: Tensor, name: x_grad, optional: false, intermediate: false}
- {typename: Tensor, name: values_grad, optional: false, intermediate: false}
infer_meta:
func: SetValueGradInferMeta
param: [out_grad, values]
kernel:
func: [set_value_grad]
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
inplace: null
backward: null
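
The entries above register set_value (values carried as a `Scalar[]` attribute), set_value_with_tensor (values supplied as a tensor input), and the shared backward set_value_grad; both forward variants are declared inplace on x. As a quick illustration of the path they enable, the sketch below builds a legacy static-graph program with `paddle.static.setitem` and translates it to the new IR, mirroring the unit tests added at the end of this PR (the `paddle.ir` import path and APIs are assumed as used in those tests, not prescribed here).

```python
# Minimal sketch, assuming the setitem/translate APIs used by the unit
# tests in this PR.
import paddle
from paddle import ir

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.ones(shape=[2, 3, 4], dtype="float32")
    # Lowers to the legacy set_value op with constant values.
    x = paddle.static.setitem(x, (0, 0), 6)

# The translator maps the legacy op onto the pd.set_value op defined above.
new_ir_program = ir.translate_to_new_ir(main_program.desc)
```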
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
namespace paddle {
namespace dialect {
......
......@@ -16,7 +16,6 @@
// #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
......
......@@ -37,6 +37,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "glog/logging.h"
......@@ -81,8 +82,8 @@ void BuildPhiContext(ir::Operation* op,
Context* ctx) {
paddle::framework::Scope* inner_scope =
local_scope != nullptr ? local_scope : scope;
VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
<< inner_scope << "]";
VLOG(6) << "Build " << get_type_name<Context>() << " in scope[" << scope
<< "] inner_scope[" << inner_scope << "]";
auto attr_map = op->attributes();
......
......@@ -113,7 +113,7 @@ class AttributeVisitor {
}
virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
VLOG(10) << "translating vector<int64>";
VLOG(10) << "translating vector<int64> size: " << i64s.size();
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
......@@ -135,8 +135,13 @@ class AttributeVisitor {
virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(dialect::ScalarAttribute::get(ctx, v));
}
VLOG(10) << "translating vector<scalar> Done";
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const paddle::blank& blank) {
......@@ -164,6 +169,11 @@ class Int64ArrayAttributeVisitor : public AttributeVisitor {
}
return ir::ArrayAttribute::get(ctx, attrs);
}
ir::Attribute operator()(const paddle::blank& blank) override {
VLOG(10) << "translating paddle::blank to int64[]";
return ir::ArrayAttribute::get(ctx, {});
}
};
class IntArrayAttributeVisitor : public AttributeVisitor {
......
......@@ -126,6 +126,9 @@ def OpNameNormalizerInitialization(
backward_op, op_compat_item["scalar"]
)
# special mapping list
op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD"
op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f:
op_compat_definition = op_name_normailzer_template.render(
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -75,42 +76,66 @@ class OpNameNormalizer {
return op_mutable_attribute_infos.at(op_type).at(arg_name);
}
std::optional<std::string> GetDirectMapping(const std::string& op_type,
const std::string& arg_name) {
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return {};
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return {};
}
return arg_mappings.at(arg_name);
}
std::optional<std::string> GetGradNameMapping(const std::string& op_type,
const std::string& arg_name) {
std::string target = kPhiGradSuffix;
std::string data = kFluidVarGradSuffix;
size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target);
std::string legacy_name = arg_name.substr(0, first_grad_pos);
std::optional<std::string> ret =
this->GetDirectMapping(op_type.substr(0, type_pos), legacy_name);
if (ret) {
legacy_name = ret.value();
}
legacy_name = legacy_name + arg_name.substr(first_grad_pos);
for (size_t pos = 0;
legacy_name.npos != (pos = legacy_name.find(target, pos));
pos += data.length()) {
legacy_name.replace(pos, target.length(), data);
}
return legacy_name;
}
std::string GetLegacyArgName(const std::string& op_type,
                             const std::string& arg_name) {
  if (auto ret = GetDirectMapping(op_type, arg_name)) {
    VLOG(10) << "[" << op_type << "] found " << ret.value();
    return ret.value();
  }
  bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
  bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);
  if (is_grad_op && is_grad_arg) {
    if (auto ret = GetGradNameMapping(op_type, arg_name)) {
      VLOG(10) << "[" << op_type << "] found " << ret.value();
      return ret.value();
    }
  } else if (is_grad_op && !is_grad_arg) {
    // backward op using forward args: like trace_grad using forward input
    size_t type_pos = op_type.find(kPhiGradSuffix);
    if (auto ret = GetDirectMapping(op_type.substr(0, type_pos), arg_name)) {
      VLOG(10) << "[" << op_type << "] found " << ret.value();
      return ret.value();
    }
  }
  VLOG(10) << "[" << op_type << "] not found mapping for " << arg_name;
  return arg_name;
}
std::string GetLegacyAttrName(const std::string& op_type,
......
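
The refactor above layers the lookup: GetDirectMapping consults the generated op_arg_name_mappings table first, and GetGradNameMapping derives a legacy gradient name by mapping the forward argument and rewriting the phi gradient suffix as the fluid one. A rough Python sketch of that order is shown below; the suffix values ("_grad", "@GRAD") and the sample table entries are illustrative assumptions, not the generated table.

```python
# Rough sketch of the lookup order, not the real implementation.
# Assumes kPhiGradSuffix == "_grad" and kFluidVarGradSuffix == "@GRAD".
PHI_GRAD, FLUID_GRAD = "_grad", "@GRAD"

# Hypothetical slice of the generated op_arg_name_mappings table
# (the set_value_grad entry matches the special mapping added above).
MAPPINGS = {
    "set_value_grad": {"values_grad": "ValueTensor@GRAD"},
    "set_value": {"x": "Input"},
}

def get_direct_mapping(op_type, arg_name):
    return MAPPINGS.get(op_type, {}).get(arg_name)

def get_grad_name_mapping(op_type, arg_name):
    # Map the forward argument, then rewrite the phi suffix as the fluid one.
    fwd_op = op_type[: op_type.find(PHI_GRAD)]
    fwd_arg = arg_name[: arg_name.find(PHI_GRAD)]
    legacy = get_direct_mapping(fwd_op, fwd_arg) or fwd_arg
    return legacy + arg_name[len(fwd_arg):].replace(PHI_GRAD, FLUID_GRAD)

def get_legacy_arg_name(op_type, arg_name):
    direct = get_direct_mapping(op_type, arg_name)
    if direct is not None:
        return direct
    if PHI_GRAD in op_type and PHI_GRAD in arg_name:
        return get_grad_name_mapping(op_type, arg_name)
    if PHI_GRAD in op_type:  # backward op reusing a forward argument
        fwd_op = op_type[: op_type.find(PHI_GRAD)]
        return get_direct_mapping(fwd_op, arg_name) or arg_name
    return arg_name

print(get_legacy_arg_name("set_value_grad", "values_grad"))  # ValueTensor@GRAD
print(get_legacy_arg_name("set_value_grad", "x_grad"))       # Input@GRAD
print(get_legacy_arg_name("set_value", "x"))                 # Input
```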
......@@ -186,9 +186,7 @@ inline ir::Operation* InsertFullArrayOperationForAttributeInput(
IR_ENFORCE(attr.isa<dialect::IntArrayAttribute>(),
"Encounter non IntArray type when trying to insert IntArray "
"mutable attribute");
phi::IntArray int_array = attr.dyn_cast<dialect::IntArrayAttribute>().data();
ir::Builder builder(ctx, program->block());
dialect::FullIntArrayOp full_int_array_op =
builder.Build<dialect::FullIntArrayOp>(
......@@ -210,40 +208,6 @@ inline ir::Operation* InsertStackOperationForTarget(
return stack_op.operation();
}
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
const OpInputInfo& input_info) {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);
if (!op_desc.HasAttr(legacy_attr_name)) {
IR_THROW("Op %s arg %s should not be zero size",
op_desc.Type(),
legacy_attr_name);
}
paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
VLOG(10) << "[" << op_desc.Type() << "][attribute]"
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
ir::Attribute new_attr =
attribute_translator(input_info.type_name, legacy_attr);
ir::Operation* defining_op = nullptr;
bool is_int_array = (input_info.type_name.find("IntArrayAttribute") !=
input_info.type_name.npos);
if (is_int_array) {
defining_op =
InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
} else {
defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
}
return defining_op->result(0);
}
} // namespace
ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
......@@ -301,6 +265,40 @@ void OpTranscriber::InsertSliceOperationForInput(
}
}
ir::OpResult OpTranscriber::GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
const OpInputInfo& input_info) {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);
if (!op_desc.HasAttr(legacy_attr_name)) {
IR_THROW("Op %s arg %s should not be zero size",
op_desc.Type(),
legacy_attr_name);
}
paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
VLOG(10) << "[" << op_desc.Type() << "][attribute]"
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
ir::Attribute new_attr =
attribute_translator(input_info.type_name, legacy_attr);
ir::Operation* defining_op = nullptr;
bool is_int_array = (input_info.type_name.find("IntArrayAttribute") !=
input_info.type_name.npos);
if (is_int_array) {
defining_op =
InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
} else {
defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
}
return defining_op->result(0);
}
std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
......@@ -1583,6 +1581,114 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
}
};
struct SetValueOpTranscriber : public OpTranscriber {
ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
const OpInputInfo& input_info) override {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);
if (!op_desc.HasAttr(legacy_attr_name)) {
IR_THROW("Op %s arg %s should not be zero size",
op_desc.Type(),
legacy_attr_name);
}
framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
VLOG(10) << "[" << op_desc.Type() << "][attribute]"
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
ir::Attribute new_attr =
attribute_translator("paddle::dialect::IntArrayAttribute", legacy_attr);
ir::Operation* defining_op =
InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
return defining_op->result(0);
}
};
struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
std::string target_op_name = dialect::SetValueWithTensorOp::name();
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW(
"Op set_value should have corresponding OpInfo "
"pd.set_value_with_tensor");
}
return op_info;
}
InputHandlerFn GetSpecialInputHandlers(
const std::string& input_name) override {
if (input_name != "values") {
return nullptr;
}
return [](ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string&,
const OpInputInfo& info,
ir::Program* program) -> ir::OpResult {
std::vector<std::string> legacy_input_vars;
IR_ENFORCE(op_desc.HasInput("ValueTensor"),
"[set_value] should have ValueTensor");
legacy_input_vars = op_desc.Input("ValueTensor", true);
IR_ENFORCE(
legacy_input_vars.size() == 1u,
"[set_value][ValueTensor] should only have 1 variable, but got %d",
legacy_input_vars.size());
auto var_name = legacy_input_vars[0];
auto defining_info = (*param_map)[var_name];
if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, program, defining_info, var_name);
defining_info = param_map->at(var_name).value;
}
return defining_info.value;
};
}
};
struct SetValueGradOpTranscriber : public SetValueWithTensorOpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
std::string target_op_name = dialect::SetValueGradOp::name();
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW(
"Op set_value_grad should have corresponding OpInfo "
"pd.set_value_grad");
}
return op_info;
}
};
struct LegacySetValueDispatcher : public OpTranscriber {
ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program) override {
std::vector<std::string> legacy_input_vars;
// if op has input with name "ValueTensor", then use that input as value
if (op_desc.HasInput("ValueTensor")) {
legacy_input_vars = op_desc.Input("ValueTensor", true);
if (legacy_input_vars.size() > 0) {
VLOG(10) << "legacy op:" << op_desc.Type()
<< " has ValueTensor and convert to set_value_with_tensor";
return SetValueWithTensorOpTranscriber()(
ctx, param_map, op_desc, program);
}
}
return SetValueOpTranscriber()(ctx, param_map, op_desc, program);
}
};
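
The dispatcher keys purely on whether the legacy OpDesc carries a non-empty ValueTensor input: if so, the op is transcribed as set_value_with_tensor; otherwise the values are packed into attributes and it becomes set_value. A tiny Python sketch of that rule, operating on a dict stand-in for the OpDesc (names illustrative only):

```python
# Sketch of the dispatch rule above; the dict is a stand-in for OpDesc.
def dispatch_set_value(op_desc):
    value_tensor = op_desc.get("inputs", {}).get("ValueTensor", [])
    if value_tensor:          # value comes from a tensor input
        return "pd.set_value_with_tensor"
    return "pd.set_value"     # value is carried as a Scalar[] attribute

assert dispatch_set_value({"inputs": {"ValueTensor": ["v0"]}}) == "pd.set_value_with_tensor"
assert dispatch_set_value({"inputs": {}}) == "pd.set_value"
```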
OpTranslator::OpTranslator() {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......@@ -1594,6 +1700,7 @@ OpTranslator::OpTranslator() {
special_handlers["feed"] = FeedOpTranscriber();
special_handlers["data"] = DataOpTranscriber();
special_handlers["fetch_v2"] = FetchOpTranscriber();
special_handlers["fill_constant"] = FillConstantTranscriber();
special_handlers["grad_add"] = GradAddOpTranscriber();
special_handlers["increment"] = IncrementOpTranscriber();
special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
......@@ -1603,10 +1710,11 @@ OpTranslator::OpTranslator() {
special_handlers["reduce_any"] = ReduceOpTranscriber();
special_handlers["rnn"] = RnnOpTranscriber();
special_handlers["shadow_output"] = ShadowOutputOpTranscriber();
special_handlers["set_value"] = LegacySetValueDispatcher();
special_handlers["set_value_grad"] = SetValueGradOpTranscriber();
special_handlers["split"] = SplitOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber();
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
special_handlers["fill_constant"] = FillConstantTranscriber();
// special handler for elementwise ops with axis != -1
// note(lyk): maybe we should do this by a pass, which seems more reasonable
......
......@@ -85,6 +85,10 @@ struct OpTranscriber {
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc);
virtual ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
const OpInputInfo& input_info);
virtual void RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
......
......@@ -3083,6 +3083,40 @@
outputs:
{out: Out, noise: Noise}
- op: set_value
backward: set_value_grad
inputs:
x : Input
outputs:
out: Out
int_array:
starts:
data_type : int64_t
tensors_name : StartsTensorList
ends:
data_type : int64_t
tensors_name : EndsTensorList
steps:
data_type : int64_t
tensors_name : StepsTensorList
- op: set_value_with_tensor
backward: set_value_grad
inputs:
x : Input
outputs:
out: Out
int_array:
starts:
data_type : int64_t
tensors_name : StartsTensorList
ends:
data_type : int64_t
tensors_name : EndsTensorList
steps:
data_type : int64_t
tensors_name : StepsTensorList
- op: sigmoid_cross_entropy_with_logits
backward: sigmoid_cross_entropy_with_logits_grad
inputs :
......
......@@ -1247,4 +1247,20 @@ void FusedRopeGradInferMeta(const MetaTensor& sin,
}
}
void SetValueGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& values,
MetaTensor* x_grad,
MetaTensor* value_grad) {
if (x_grad) {
x_grad->set_dims(out_grad.dims());
x_grad->set_dtype(out_grad.dtype());
x_grad->share_lod(out_grad);
}
if (value_grad) {
value_grad->set_dims(values.dims());
value_grad->set_dtype(values.dtype());
value_grad->share_lod(values);
}
}
} // namespace phi
......@@ -476,4 +476,9 @@ void IndexPutGradInferMeta(const MetaTensor& x,
bool accumulate,
MetaTensor* x_grad,
MetaTensor* value_grad);
void SetValueGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& values,
MetaTensor* x_grad,
MetaTensor* value_grad);
} // namespace phi
......@@ -325,6 +325,89 @@ class TestShadowOutputSlice(unittest.TestCase):
l = ir.translate_to_new_ir(main_program.desc)
class TestSetValueOp(unittest.TestCase):
def test_no_mutable_attribute(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.ones(shape=[2, 3, 4], dtype="float32")
x = paddle.static.setitem(x, (0, 0), 6)
ret = exe.run(main_program, fetch_list=x.name)
x_data = np.ones([2, 3, 4]).astype("float32")
x_data[0, 0] = 6
np.testing.assert_array_equal(ret[0], x_data)
def test_with_mutable_attribute(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.ones(shape=[2, 3, 4], dtype="float32")
zero = paddle.full([], 0, dtype="int32")
x = paddle.static.setitem(x, zero, 6)
ret = exe.run(main_program, fetch_list=x.name)
x_data = np.ones([2, 3, 4]).astype("float32")
x_data[0] = 6
np.testing.assert_array_equal(ret[0], x_data)
def test_grad(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
input_shape = [7, 6, 5, 4, 3, 2]
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.ones(shape=input_shape, dtype="float32")
value = paddle.tensor.fill_constant([1, 3, 2], "float32", 1)
# test stop_gradient
value.stop_gradient = False
x.stop_gradient = False
attrs = {
'axes': [0],
'starts': [6],
'ends': [0],
'steps': [-4],
'decrease_axes': [],
'none_axes': [],
'dtype': paddle.float32,
}
inputs = {'Input': x, 'ValueTensor': value}
helper = LayerHelper("set_value")
y = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': y},
attrs=attrs,
)
y2 = y + 1
loss = paddle.sum(y2)
opt = paddle.optimizer.Adam()
opt.minimize(loss)
x_data = np.arange(
0, np.prod(input_shape), dtype="float32"
).reshape(input_shape)
fetch_list = [x.grad_name, value.grad_name]
ret = exe.run(main_program, fetch_list=fetch_list)
self.assertTrue((ret[0][6:0:-4] == 0).all())
class TestCheckUnregisteredOp(unittest.TestCase):
def test_program(self):
main_program = paddle.static.Program()
......