diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index a885b301e77d4af1b8db37a76f8be33e07ab4437..4f7cfcf112a0595641b16447b417cbe86db31120 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,74 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/operators/fill_op.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { -struct FillOpVisitor { - FillOpVisitor(framework::LoDTensor *tensor, const std::vector &value) - : tensor_(tensor), value_(value) {} - - template - void apply() const { - platform::CPUPlace cpu; - auto *data = tensor_->mutable_data(cpu); - std::transform(value_.data(), value_.data() + tensor_->numel(), data, - [](float dat) { return static_cast(dat); }); - } - - framework::LoDTensor *tensor_; - const std::vector &value_; -}; - -class FillOp : public framework::OperatorBase { - public: - FillOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &out = - detail::Ref(detail::Ref(scope.FindVar(Output("Out")), - "Cannot find variable %s", Output("Out")) - .GetMutable()); - out.Resize(framework::make_ddim(Attr>("shape"))); - auto dtype = - static_cast(Attr("dtype")); - platform::CPUPlace cpu; - auto force_cpu = Attr("force_cpu"); - out.mutable_data(force_cpu ? cpu : place, dtype); - - framework::LoDTensor tensor; - - if (force_cpu || platform::is_cpu_place(place)) { - tensor.ShareDataWith(out); - } else { - // Always make tensor in CPU memory. - tensor.Resize(out.dims()); - tensor.mutable_data(cpu, dtype); - } - - framework::VisitDataType( - dtype, FillOpVisitor(&tensor, Attr>("value"))); - - if (!force_cpu && platform::is_gpu_place(place)) { - // Copy tensor to out - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - framework::TensorCopy(tensor, place, dev_ctx, &out); - } - } -}; - class FillOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -101,16 +39,42 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by } }; -class FillOpInferShape : public framework::InferShapeBase { +class FillOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* context) const override { + PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true, + "Output(Out) of FillOp should not be null."); + auto& shape = context->Attrs().Get>("shape"); + context->SetOutputDim("Out", framework::make_ddim(shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } +}; + +class FillOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(framework::InferShapeContext *context) const override { - context->SetOutputDim( - "Out", - framework::make_ddim(context->Attrs().Get>("shape"))); + void operator()(framework::InferVarTypeContext* ctx) const override { + auto data_type = static_cast( + boost::get(ctx->GetAttr("dtype"))); + auto& out_var_name = ctx->Output("Out").front(); + ctx->SetDataType(out_var_name, data_type); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fill, ops::FillOp, ops::FillOpInferShape, ops::FillOpMaker); +REGISTER_OPERATOR(fill, ops::FillOp, ops::FillOpMaker, + ops::FillOpVarTypeInference, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(fill, ops::FillKernel, ops::FillKernel, + ops::FillKernel, ops::FillKernel, + ops::FillKernel); diff --git a/paddle/fluid/operators/fill_op.cu.cc b/paddle/fluid/operators/fill_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..fdef8ab2a17080bdd204e3ab5ae83d4107957fc5 --- /dev/null +++ b/paddle/fluid/operators/fill_op.cu.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fill_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(fill, ops::FillKernel, ops::FillKernel, + ops::FillKernel, ops::FillKernel, + ops::FillKernel); diff --git a/paddle/fluid/operators/fill_op.h b/paddle/fluid/operators/fill_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fa2d5b858d95bcafbdcbf975dea1e183444bf118 --- /dev/null +++ b/paddle/fluid/operators/fill_op.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" + +namespace paddle { +namespace operators { + +struct FillOpVisitor { + FillOpVisitor(framework::LoDTensor *tensor, const std::vector &value) + : tensor_(tensor), value_(value) {} + + template + void apply() const { + platform::CPUPlace cpu; + auto *data = tensor_->mutable_data(cpu); + std::transform(value_.data(), value_.data() + tensor_->numel(), data, + [](float dat) { return static_cast(dat); }); + } + + framework::LoDTensor *tensor_; + const std::vector &value_; +}; + +template +class FillKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto &out = + detail::Ref(ctx.Output("Out"), + "Cannot get output lod tensor Out, variable name = %s", + ctx.op().Output("Out")); + out.Resize(framework::make_ddim(ctx.Attr>("shape"))); + auto dtype = + static_cast(ctx.Attr("dtype")); + platform::CPUPlace cpu; + auto force_cpu = ctx.Attr("force_cpu"); + out.mutable_data(force_cpu ? cpu : ctx.GetPlace(), dtype); + + framework::LoDTensor tensor; + + if (force_cpu || platform::is_cpu_place(ctx.GetPlace())) { + tensor.ShareDataWith(out); + } else { + // Always make tensor in CPU memory. + tensor.Resize(out.dims()); + tensor.mutable_data(cpu, dtype); + } + + framework::VisitDataType( + dtype, FillOpVisitor(&tensor, ctx.Attr>("value"))); + + if (!force_cpu && platform::is_gpu_place(ctx.GetPlace())) { + // Copy tensor to out + framework::TensorCopy( + tensor, ctx.GetPlace(), + ctx.template device_context(), &out); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_fill_op.py b/python/paddle/fluid/tests/unittests/test_fill_op.py index b734ee05b3f2291d7a79f1550946bf6546ada6e0..0dd1b0d869ae7f21f1d64c374010e2175e70ee33 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +18,10 @@ import unittest import numpy as np from op_test import OpTest import paddle.fluid.core as core +from paddle.fluid.op import Operator -class TestFillOp(OpTest): +class TestFillOp1(OpTest): def setUp(self): self.op_type = "fill" val = np.random.random(size=[100, 200]) @@ -28,7 +29,8 @@ class TestFillOp(OpTest): self.attrs = { 'value': val.flatten().tolist(), 'shape': [100, 200], - 'dtype': int(core.VarDesc.VarType.FP64) + 'dtype': int(core.VarDesc.VarType.FP64), + 'force_cpu': False } self.outputs = {'Out': val.astype('float64')} @@ -36,5 +38,55 @@ class TestFillOp(OpTest): self.check_output() +class TestFillOp2(OpTest): + def setUp(self): + self.op_type = "fill" + val = np.random.random(size=[100, 200]) + self.inputs = {} + self.attrs = { + 'value': val.flatten().tolist(), + 'shape': [100, 200], + 'dtype': int(core.VarDesc.VarType.FP64), + 'force_cpu': True + } + self.outputs = {'Out': val.astype('float64')} + + def test_check_output(self): + self.check_output() + + +class TestFillOp3(OpTest): + def check_with_place(self, place, f_cpu): + scope = core.Scope() + # create Out Variable + out = scope.var('Out').get_tensor() + + # create and run fill_op operator + val = np.random.random(size=[300, 200]) + fill_op = Operator( + "fill", + value=val.flatten(), + shape=[300, 200], + dtype=int(core.VarDesc.VarType.FP32), + force_cpu=f_cpu, + Out='Out') + fill_op.run(scope, place) + + # get result from Out + result_array = np.array(out) + full_array = np.array(val, 'float32') + + self.assertTrue(np.array_equal(result_array, full_array)) + + def test_fill_op(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + + for place in places: + self.check_with_place(place, True) + self.check_with_place(place, False) + + if __name__ == '__main__': unittest.main()