提交 c132c790 编写于 作者: X Xin Pan

address comments and resolve conflicts.

test=develop
上级 b91a7a9d
...@@ -100,26 +100,6 @@ class Autograd { ...@@ -100,26 +100,6 @@ class Autograd {
} }
}; };
void CreateVariable(const std::string& name, const framework::DDim& dim,
float val, bool random_name, framework::Variable* var) {
if (var->IsInitialized()) return;
std::string varname = name;
if (random_name) {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int>::max());
int id = dist6(rng);
varname = string::Sprintf("%s@%d", varname, id);
}
VLOG(3) << "creating var " << varname;
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
float* data = tensor->mutable_data<float>(dim, platform::CPUPlace());
std::fill(data, data + tensor->numel(), val);
}
framework::LoDTensor& VarBase::Grad() { framework::LoDTensor& VarBase::Grad() {
VLOG(3) << "get var grad " << var_desc_->Name(); VLOG(3) << "get var grad " << var_desc_->Name();
return *grads_->GetMutable<framework::LoDTensor>(); return *grads_->GetMutable<framework::LoDTensor>();
......
...@@ -36,7 +36,6 @@ class PreparedOp { ...@@ -36,7 +36,6 @@ class PreparedOp {
static PreparedOp Prepare(const framework::RuntimeContext& ctx, static PreparedOp Prepare(const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place) { const platform::Place& place) {
framework::Scope dummy_scope;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -52,7 +51,7 @@ class PreparedOp { ...@@ -52,7 +51,7 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second; framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = op.GetExpectedKernelType( auto expected_kernel_key = op.GetExpectedKernelType(
framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx)); framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册