diff --git a/paddle/fluid/operators/increment_op.cc b/paddle/fluid/operators/increment_op.cc index 6b5c3db13c0929ae0dd2fb2c981867df0a36c1ce..2893ab7127601bc2e787e44859086dc354693e89 100644 --- a/paddle/fluid/operators/increment_op.cc +++ b/paddle/fluid/operators/increment_op.cc @@ -1,71 +1,37 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/increment_op.h" namespace paddle { namespace operators { -class IncrementInferShape : public framework::InferShapeBase { +class IncrementOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { + IncrementOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of IncrementOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of IncrementOp should not be null."); PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X"))); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - } -}; - -struct IncrementFunctor { - IncrementFunctor(const framework::LoDTensor &x, framework::LoDTensor *out, - float value) - : x_(x), out_(out), value_(value) {} - - template - void operator()() const { - *out_->data() = *x_.data() + static_cast(value_); - } - - const framework::LoDTensor &x_; - framework::LoDTensor *out_; - float value_; -}; - -class IncrementOp : public framework::OperatorBase { - public: - IncrementOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &x = scope.FindVar(Input("X"))->Get(); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - - PADDLE_ENFORCE(platform::is_cpu_place(x.place())); - out.Resize(x.dims()); - out.mutable_data(x.place(), x.type()); - float value = Attr("step"); - VLOG(10) << Output("Out") << " increase " << Input("X") << " with " - << value; - framework::VisitDataType(framework::ToDataType(out.type()), - IncrementFunctor(x, &out, value)); + ctx->ShareLoD("X", "Out"); } }; @@ -108,5 +74,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementInferShape, - ops::IncrementOpMaker, ops::IncrementGradOpMaker); +REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker, + ops::IncrementGradOpMaker); +REGISTER_OP_CPU_KERNEL( + increment, ops::IncrementKernel, + ops::IncrementKernel, + ops::IncrementKernel, + ops::IncrementKernel) diff --git a/paddle/fluid/operators/increment_op.cu b/paddle/fluid/operators/increment_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..0b6cb1fc853d00a8cb92d8a21ff687d019554be0 --- /dev/null +++ b/paddle/fluid/operators/increment_op.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/minus_op.h" + +REGISTER_OP_CUDA_KERNEL( + increment, ops::IncrementKernel, + ops::IncrementKernel, + ops::IncrementKernel, + ops::IncrementKernel) \ No newline at end of file diff --git a/paddle/fluid/operators/increment_op.h b/paddle/fluid/operators/increment_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d0e8c66255ef68b975701fb6b3c145be2590e271 --- /dev/null +++ b/paddle/fluid/operators/increment_op.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class IncrementKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x_tensor = context.Input("X"); + auto* out_tensor = context.Output("Out"); + float step = context.Attr("step"); + + out_tensor->mutable_data(context.GetPlace()); + auto& dev = + *context.template device_context().eigen_device(); + framework::EigenScalar::From(*out_tensor).device(dev) = + framework::EigenScalar::From(*x_tensor) + static_cast(step); + } +}; + +} // namespace operators +} // namespace paddle