diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 4e8d630c2634682ff63b38182108eadebb5c7ff9..d485cdf6109274377ad0057223bdd8401e964aa7 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -21,7 +21,7 @@ #include "paddle/framework/var_desc.h" #include "paddle/operators/net_op.h" -USE_OP(fill_constant); +USE_NO_KERNEL_OP(fill_constant); namespace paddle { namespace framework { diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index c5ae7b185460c8b0d68ba38bb9db9bd3d3fb14ea..3ec88d7a72c3339bf5e7d0ca3957a3f608f039b7 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -34,6 +34,21 @@ inline DataType ToDataType(std::type_index type) { } } +inline std::type_index ToTypeIndex(DataType type) { + switch (type) { + case DataType::FP32: + return typeid(float); + case DataType::FP64: + return typeid(double); + case DataType::INT32: + return typeid(int); + case DataType::INT64: + return typeid(int64_t); + default: + PADDLE_THROW("Not support type %d", type); + } +} + template inline void VisitDataType(DataType type, Visitor visitor) { switch (type) { diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 10c785e04c4fa2192f9c95513009cf7d8c123868..53b899a23997b71e723a298ec360a4e018d89878 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -79,6 +79,13 @@ DDim make_ddim(const std::vector& dims) { return result; } +DDim make_ddim(const std::vector& dims) { + std::vector res(dims.size()); + std::transform(dims.begin(), dims.end(), res.begin(), + [](int d) { return static_cast(d); }); + return make_ddim(res); +} + /// @cond HIDDEN // XXX For some reason, putting this in an anonymous namespace causes errors class DynamicMutableIndexer : public boost::static_visitor { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index aa773868ab4b68acbc46dfa2cd2569d8b8b7789d..4ca5e49566b7ec006eba80f3f9808bacb1ff2615 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -81,6 +81,8 @@ struct DDim { */ DDim make_ddim(const std::vector& dims); +DDim make_ddim(const std::vector& dims); + /** * \brief Make a DDim from an initializer list * diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index 5a1cba51f83bb8577bc94ae23d1a44bb801ae4c7..818f113b90a4c239a857791fb9957e51d3287b97 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -12,33 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/fill_constant_op.h" +#include "paddle/framework/data_type.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { -class FillConstantOp : public framework::OperatorWithKernel { +class FillConstantInferShape : public framework::InferShapeBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { + void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FillConstantOp should not be null."); auto &shape = ctx->Attrs().Get>("shape"); - std::vector shape_int64(shape.size(), 0); - std::transform(shape.begin(), shape.end(), shape_int64.begin(), - [](int a) { return static_cast(a); }); - auto dims = framework::make_ddim(shape_int64); - ctx->SetOutputDim("Out", dims); + ctx->SetOutputDim("Out", framework::make_ddim(shape)); } +}; - protected: - framework::OpKernelType GetKernelType( - const framework::ExecutionContext &ctx) const override { - int data_type = ctx.Attr("data_type"); - VLOG(10) << " FillConstant data_type = " << data_type; - return framework::OpKernelType(static_cast(data_type), - ctx.device_context()); +class FillConstantOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto data_type = static_cast(Attr("data_type")); + auto value = Attr("value"); + auto force_cpu = Attr("force_cpu"); + auto &out = + *scope.FindVar(Output("Out"))->GetMutable(); + out.Resize(framework::make_ddim(Attr>("shape"))); + if (force_cpu) { + auto cpu = platform::CPUPlace(); + out.mutable_data(cpu, framework::ToTypeIndex(data_type)); + } else { + out.mutable_data(dev_ctx.GetPlace(), framework::ToTypeIndex(data_type)); + } + math::set_constant(dev_ctx, &out, value); } }; @@ -54,6 +62,11 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("shape", "(vector) The shape of the output"); AddAttr("value", "(float, default 0) The value to be filled") .SetDefault(0.0f); + AddAttr("force_cpu", + "(bool, default false) Force fill output variable to cpu " + "memory. Otherwise, fill output variable to the running " + "device") + .SetDefault(false); AddOutput("Out", "(Tensor) Tensor of specified shape will be filled " "with the specified value"); @@ -69,10 +82,6 @@ Fill up a variable with specified constant value. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp, - ops::FillConstantOpMaker); -REGISTER_OP_CPU_KERNEL( - fill_constant, ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel); +REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, + ops::FillConstantInferShape, ops::FillConstantOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/fill_constant_op.cu deleted file mode 100644 index bca402a8b988b570a083e9ce253342304f4b8946..0000000000000000000000000000000000000000 --- a/paddle/operators/fill_constant_op.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#define EIGEN_USE_GPU -#include "paddle/framework/op_registry.h" -#include "paddle/operators/fill_constant_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - fill_constant, ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.h b/paddle/operators/fill_constant_op.h deleted file mode 100644 index 3668f42f1c29541e29463ff3969064e80703fa04..0000000000000000000000000000000000000000 --- a/paddle/operators/fill_constant_op.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class FillConstantOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - auto value = ctx.Attr("value"); - - auto out_eigen = framework::EigenVector::Flatten(*out); - auto place = ctx.GetEigenDevice(); - out_eigen.device(place) = out_eigen.constant(static_cast(value)); - } -}; - -} // namespace operators -} // namespace paddle