提交 f2ddc060 编写于 作者: S superjomn

update

......@@ -72,9 +72,8 @@ class LightPredictor {
// Create the kernels of the target places, and filter out the specific
// kernel with the target alias.
for (auto& op : program.ops_) {
lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
for (auto& op : program.ops()) {
auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr);
std::string op_type, alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
......@@ -89,8 +88,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope_);
program_->set_exec_scope(program.exec_scope_);
CHECK(program.exec_scope());
program_->set_exec_scope(program.exec_scope());
}
private:
......
......@@ -45,9 +45,9 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
PADDLE_ENFORCE_EQ(grad->numel(), sz);
paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>();
const T *param_data = param->data<T>();
const T *grad_data = grad->data<T>();
const T *lr = learning_rate->template data<T>();
const T *param_data = param->template data<T>();
const T *grad_data = grad->template data<T>();
int64_t rows_idx = 0;
T *out_data =
param_out->mutable_data<T>(context.x86_device_context()->GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册