提交 016445c4 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix compile (#18039)

test=develop
上级 b8572aa3
...@@ -72,9 +72,8 @@ class LightPredictor { ...@@ -72,9 +72,8 @@ class LightPredictor {
// Create the kernels of the target places, and filter out the specific // Create the kernels of the target places, and filter out the specific
// kernel with the target alias. // kernel with the target alias.
for (auto& op : program.ops_) { for (auto& op : program.ops()) {
lite::pb::OpDesc desc(op->op_info()->desc()); auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr);
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias; std::string op_type, alias;
Place place; Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
...@@ -89,8 +88,8 @@ class LightPredictor { ...@@ -89,8 +88,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it)); insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope_); CHECK(program.exec_scope());
program_->set_exec_scope(program.exec_scope_); program_->set_exec_scope(program.exec_scope());
} }
private: private:
......
...@@ -29,9 +29,9 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -29,9 +29,9 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
using param_t = operators::ActivationParam; using param_t = operators::ActivationParam;
void Run() override { void Run() override {
auto &context = context_->As<X86Context>(); auto &context = ctx_->As<X86Context>();
auto &sgd_param = *param_.get_mutable<operators::SGDParam>(); auto &sgd_param = *param_.get_mutable<operators::SGDParam>();
CHECK(context.x86_device_context); CHECK(context.x86_device_context());
// param.Out->template mutable_data<T>(); // param.Out->template mutable_data<T>();
...@@ -45,12 +45,12 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -45,12 +45,12 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
PADDLE_ENFORCE_EQ(grad->numel(), sz); PADDLE_ENFORCE_EQ(grad->numel(), sz);
paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1); paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>(); const T *lr = learning_rate->template data<T>();
const T *param_data = param->data<T>(); const T *param_data = param->template data<T>();
const T *grad_data = grad->data<T>(); const T *grad_data = grad->template data<T>();
int64_t rows_idx = 0; int64_t rows_idx = 0;
T *out_data = T *out_data = param_out->template mutable_data<T>(
param_out->mutable_data<T>(context.x86_device_context->GetPlace()); context.x86_device_context()->GetPlace());
auto sgd = auto sgd =
paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>, paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册