Unverified commit 95658767, authored by fengjiayi, committed by GitHub

Merge pull request #9428 from JiayiFeng/kernel_of_increment_op

kernels of IncrementOp
@@ -29,6 +29,11 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Y", string::Sprintf(
                       "(LoDTensor) the right hand operand of %s operator",
                       comment.type));
+    AddAttr<bool>("force_cpu",
+                  "(bool, default false) Force fill output variable to cpu "
+                  "memory. Otherwise, fill output variable to the running "
+                  "device")
+        .SetDefault(false);
     AddOutput("Out", string::Sprintf(
                          "(LoDTensor) n-dim bool tensor. Each element is %s",
                          comment.equation));
@@ -75,7 +80,9 @@ class CompareOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
     // CompareOp kernel's device type is decided by input tensor place
-    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
+    bool force_cpu = ctx.Attr<bool>("force_cpu");
+    kt.place_ = force_cpu ? platform::CPUPlace()
+                          : ctx.Input<framework::LoDTensor>("X")->place();
     return kt;
   }
 };
......
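As a quick illustration of what the new attribute buys (a hedged sketch, not part of this commit: it assumes the existing fluid.layers.fill_constant wrapper and the force_cpu argument that the Python change later in this diff adds to less_than): when force_cpu is true, GetExpectedKernelType above pins the compare kernel to platform::CPUPlace(), so the boolean output stays in host memory even if X and Y live on the GPU.

import paddle.fluid as fluid

x = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.0)
y = fluid.layers.fill_constant(shape=[1], dtype='float32', value=2.0)
# force_cpu=True is forwarded as the 'force_cpu' attribute of the less_than
# op, so its kernel runs on CPUPlace and 'cond' is stored in host memory.
cond = fluid.layers.less_than(x=x, y=y, force_cpu=True)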
@@ -54,7 +54,18 @@ class ConditionalOp : public framework::OperatorBase {
                         "numel should be 1, actual numel is %d",
                         ips[0]->numel());
     }
-    return ips[0]->data<bool>()[0];
+    bool res = false;
+    if (platform::is_gpu_place(ips[0]->place())) {
+#ifdef PADDLE_WITH_CUDA
+      framework::LoDTensor cpu_tensor;
+      framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
+      platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
+      res = cpu_tensor.data<bool>()[0];
+#endif
+    } else {
+      res = ips[0]->data<bool>()[0];
+    }
+    return res;
   }
 };
......
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/increment_op.h"
 
 namespace paddle {
 namespace operators {
 
-class IncrementInferShape : public framework::InferShapeBase {
+class IncrementOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *ctx) const override {
+  IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of IncrementOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of IncrementOp should not be null.");
     PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X")));
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", "Out");
   }
-};
-
-struct IncrementFunctor {
-  IncrementFunctor(const framework::LoDTensor &x, framework::LoDTensor *out,
-                   float value)
-      : x_(x), out_(out), value_(value) {}
-
-  template <typename T>
-  void operator()() const {
-    *out_->data<T>() = *x_.data<T>() + static_cast<T>(value_);
-  }
-
-  const framework::LoDTensor &x_;
-  framework::LoDTensor *out_;
-  float value_;
-};
-
-class IncrementOp : public framework::OperatorBase {
- public:
-  IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
-              const framework::VariableNameMap &outputs,
-              const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
-    auto &out =
-        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
-
-    PADDLE_ENFORCE(platform::is_cpu_place(x.place()));
-    out.Resize(x.dims());
-    out.mutable_data(x.place(), x.type());
-    float value = Attr<float>("step");
-    VLOG(10) << Output("Out") << " increase " << Input("X") << " with "
-             << value;
-    framework::VisitDataType(framework::ToDataType(out.type()),
-                             IncrementFunctor(x, &out, value));
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    // IncrementOp kernel's device type is decided by input tensor place
+    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
+    return kt;
   }
 };
@@ -108,5 +83,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
 } // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementInferShape,
-                  ops::IncrementOpMaker, ops::IncrementGradOpMaker);
+REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
+                  ops::IncrementGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    increment, ops::IncrementKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::IncrementKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::IncrementKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::IncrementKernel<paddle::platform::CPUDeviceContext, int64_t>)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/increment_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CUDADeviceContext, float>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, double>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int64_t>)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class IncrementKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_tensor = context.Input<framework::Tensor>("X");
auto* out_tensor = context.Output<framework::Tensor>("Out");
float step = context.Attr<float>("step");
out_tensor->mutable_data<T>(context.GetPlace());
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
framework::EigenScalar<T>::From(*out_tensor).device(dev) =
framework::EigenScalar<T>::From(*x_tensor) + static_cast<T>(step);
}
};
} // namespace operators
} // namespace paddle
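For context, a minimal Python-side sketch of the new kernel (hedged: it assumes the pre-existing fluid.layers.increment and fill_constant wrappers, which are not touched by this commit). InferShape above requires X to hold exactly one element, and the kernel is now registered for float, double, int and int64_t on both CPU and CUDA, so the counter no longer has to sit in CPU memory.

import paddle.fluid as fluid

# 'increment' only accepts single-element tensors; see the
# PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X"))) check above.
counter = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
counter = fluid.layers.increment(x=counter, value=1.0, in_place=True)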
@@ -18,6 +18,7 @@ from tensor import assign, fill_constant
 from .. import core
 from ..framework import Program, Variable, Operator
 from ..layer_helper import LayerHelper, unique_name
+from ..initializer import force_init_on_cpu
 from ops import logical_and, logical_not, logical_or
 
 __all__ = [
@@ -949,7 +950,7 @@ def create_array(dtype):
         dtype=dtype)
 
 
-def less_than(x, y, cond=None, **ignored):
+def less_than(x, y, force_cpu=True, cond=None, **ignored):
     """
     **Less than**
@@ -958,6 +959,7 @@ def less_than(x, y, cond=None, **ignored):
     Args:
         x(Variable): First operand of *less_than*
         y(Variable): Second operand of *less_than*
+        force_cpu(Bool|True): The output data will be on CPU if set true.
        cond(Variable|None): Optional output variable to store the result of *less_than*
 
     Returns:
@@ -974,8 +976,11 @@ def less_than(x, y, cond=None, **ignored):
         cond.stop_gradient = True
 
     helper.append_op(
-        type='less_than', inputs={'X': [x],
-                                  'Y': [y]}, outputs={'Out': [cond]})
+        type='less_than',
+        inputs={'X': [x],
+                'Y': [y]},
+        outputs={'Out': [cond]},
+        attrs={'force_cpu': force_cpu or force_init_on_cpu()})
     return cond
@@ -1396,7 +1401,8 @@ class DynamicRNN(object):
             type='less_than',
             inputs={'X': self.step_idx,
                     'Y': self.max_seq_len},
-            outputs={'Out': self.cond})
+            outputs={'Out': self.cond},
+            attrs={'force_cpu': True})
 
         input_array = parent_block.create_var(
             name=unique_name.generate('dynamic_rnn_input_array'),
@@ -1445,7 +1451,11 @@ class DynamicRNN(object):
         for new_mem, mem_array in self.mem_link:
             array_write(x=new_mem, i=self.step_idx, array=mem_array)
-        less_than(x=self.step_idx, y=self.max_seq_len, cond=self.cond)
+        less_than(
+            x=self.step_idx,
+            y=self.max_seq_len,
+            force_cpu=True,
+            cond=self.cond)
 
         self.status = DynamicRNN.AFTER_RNN
         for each_array in self.output_array:
......
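Taken together, the pieces above serve the usual host-driven loop: less_than(force_cpu=True) keeps the loop condition in CPU memory where the While and conditional operators read it each step, while increment now has real device kernels for the counter update. A rough end-to-end sketch, assuming the fluid.layers.While, fill_constant and increment APIs of this release:

import paddle.fluid as fluid

i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
n = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)

# The condition is forced onto the CPU so the While operator can read the
# boolean on the host every iteration, whatever place the program runs on.
cond = fluid.layers.less_than(x=i, y=n, force_cpu=True)
while_op = fluid.layers.While(cond=cond)
with while_op.block():
    i = fluid.layers.increment(x=i, value=1.0, in_place=True)
    fluid.layers.less_than(x=i, y=n, force_cpu=True, cond=cond)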