提交 01c5ca73 编写于 作者: J JiayiFeng

fix bugs: add a `force_cpu` attribute to the compare ops so their boolean output can be pinned to CPU memory, enforce that WhileOp's condition tensor resides in CPU memory, and plumb `force_cpu` through the Python `less_than` layer and `DynamicRNN`'s internal `less_than` ops

上级 917b205c
...@@ -29,6 +29,11 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -29,6 +29,11 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Y", string::Sprintf( AddInput("Y", string::Sprintf(
"(LoDTensor) the right hand operand of %s operator", "(LoDTensor) the right hand operand of %s operator",
comment.type)); comment.type));
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddOutput("Out", string::Sprintf( AddOutput("Out", string::Sprintf(
"(LoDTensor) n-dim bool tensor. Each element is %s", "(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation)); comment.equation));
...@@ -75,7 +80,9 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -75,7 +80,9 @@ class CompareOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place // CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place(); bool force_cpu = ctx.Attr<bool>("force_cpu");
kt.place_ = force_cpu ? platform::CPUPlace()
: ctx.Input<framework::LoDTensor>("X")->place();
return kt; return kt;
} }
}; };
......
...@@ -54,6 +54,8 @@ class WhileOp : public framework::OperatorBase { ...@@ -54,6 +54,8 @@ class WhileOp : public framework::OperatorBase {
auto step_scopes = auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory.");
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
......
...@@ -18,6 +18,7 @@ from tensor import assign, fill_constant ...@@ -18,6 +18,7 @@ from tensor import assign, fill_constant
from .. import core from .. import core
from ..framework import Program, Variable, Operator from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name from ..layer_helper import LayerHelper, unique_name
from ..initializer import force_init_on_cpu
from ops import logical_and, logical_not, logical_or from ops import logical_and, logical_not, logical_or
__all__ = [ __all__ = [
...@@ -949,7 +950,7 @@ def create_array(dtype): ...@@ -949,7 +950,7 @@ def create_array(dtype):
dtype=dtype) dtype=dtype)
def less_than(x, y, cond=None, **ignored): def less_than(x, y, force_cpu=True, cond=None, **ignored):
""" """
**Less than** **Less than**
...@@ -958,6 +959,7 @@ def less_than(x, y, cond=None, **ignored): ...@@ -958,6 +959,7 @@ def less_than(x, y, cond=None, **ignored):
Args: Args:
x(Variable): First operand of *less_than* x(Variable): First operand of *less_than*
y(Variable): Second operand of *less_than* y(Variable): Second operand of *less_than*
force_cpu(Bool|True): The output data will be on CPU if set true.
cond(Variable|None): Optional output variable to store the result of *less_than* cond(Variable|None): Optional output variable to store the result of *less_than*
Returns: Returns:
...@@ -974,8 +976,11 @@ def less_than(x, y, cond=None, **ignored): ...@@ -974,8 +976,11 @@ def less_than(x, y, cond=None, **ignored):
cond.stop_gradient = True cond.stop_gradient = True
helper.append_op( helper.append_op(
type='less_than', inputs={'X': [x], type='less_than',
'Y': [y]}, outputs={'Out': [cond]}) inputs={'X': [x],
'Y': [y]},
outputs={'Out': [cond]},
attrs={'force_cpu': force_cpu or force_init_on_cpu()})
return cond return cond
...@@ -1395,7 +1400,8 @@ class DynamicRNN(object): ...@@ -1395,7 +1400,8 @@ class DynamicRNN(object):
type='less_than', type='less_than',
inputs={'X': self.step_idx, inputs={'X': self.step_idx,
'Y': self.max_seq_len}, 'Y': self.max_seq_len},
outputs={'Out': self.cond}) outputs={'Out': self.cond},
attrs={'force_cpu': True})
input_array = parent_block.create_var( input_array = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_input_array'), name=unique_name.generate('dynamic_rnn_input_array'),
...@@ -1443,7 +1449,11 @@ class DynamicRNN(object): ...@@ -1443,7 +1449,11 @@ class DynamicRNN(object):
for new_mem, mem_array in self.mem_link: for new_mem, mem_array in self.mem_link:
array_write(x=new_mem, i=self.step_idx, array=mem_array) array_write(x=new_mem, i=self.step_idx, array=mem_array)
less_than(x=self.step_idx, y=self.max_seq_len, cond=self.cond) less_than(
x=self.step_idx,
y=self.max_seq_len,
force_cpu=True,
cond=self.cond)
self.status = DynamicRNN.AFTER_RNN self.status = DynamicRNN.AFTER_RNN
for each_array in self.output_array: for each_array in self.output_array:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册