未验证 提交 e1695388 编写于 作者: H hong 提交者: GitHub

fix kernel config bug in dygraph mode; test=develop (#19532)

上级 c2c5b1b9
...@@ -376,8 +376,8 @@ std::vector<VarBasePtrMap> OpBase::ApplyGrad( ...@@ -376,8 +376,8 @@ std::vector<VarBasePtrMap> OpBase::ApplyGrad(
framework::Scope scope; framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_); PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx); p.op.RuntimeInferShape(scope, place_, ctx);
p.func( p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx,
framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr)); p.kernel_configs));
} }
platform::RecordEvent record_event("merge_grads"); platform::RecordEvent record_event("merge_grads");
......
...@@ -98,6 +98,7 @@ class PreparedOp { ...@@ -98,6 +98,7 @@ class PreparedOp {
} }
std::vector<framework::KernelConfig>* kernel_configs = std::vector<framework::KernelConfig>* kernel_configs =
op.GetKernelConfig(expected_kernel_key); op.GetKernelConfig(expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs); return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册