提交 f2ddc060 编写于 作者: S superjomn

update

...@@ -72,9 +72,8 @@ class LightPredictor { ...@@ -72,9 +72,8 @@ class LightPredictor {
// Create the kernels of the target places, and filter out the specific // Create the kernels of the target places, and filter out the specific
// kernel with the target alias. // kernel with the target alias.
for (auto& op : program.ops_) { for (auto& op : program.ops()) {
lite::pb::OpDesc desc(op->op_info()->desc()); auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr);
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias; std::string op_type, alias;
Place place; Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
...@@ -89,8 +88,8 @@ class LightPredictor { ...@@ -89,8 +88,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it)); insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope_); CHECK(program.exec_scope());
program_->set_exec_scope(program.exec_scope_); program_->set_exec_scope(program.exec_scope());
} }
private: private:
......
...@@ -45,9 +45,9 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -45,9 +45,9 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
PADDLE_ENFORCE_EQ(grad->numel(), sz); PADDLE_ENFORCE_EQ(grad->numel(), sz);
paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1); paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>(); const T *lr = learning_rate->template data<T>();
const T *param_data = param->data<T>(); const T *param_data = param->template data<T>();
const T *grad_data = grad->data<T>(); const T *grad_data = grad->template data<T>();
int64_t rows_idx = 0; int64_t rows_idx = 0;
T *out_data = T *out_data =
param_out->mutable_data<T>(context.x86_device_context()->GetPlace()); param_out->mutable_data<T>(context.x86_device_context()->GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册