提交 b1bb2384 编写于 作者: Z zhongpu 提交者: Jiabin Yang

add kernel for fill_op, test=develop (#19719)

* add kernel for fill_op, test=develop

* modify PADDLE_ENFORCE to PADDLE_ENFORCE_EQ, test=develop

* add op test for fill_op, test=develop

* REGISTER COP CUDA KERNEL, test=develop

* update test_fill_op.py, test=develop

* change FillConstantOpVarTypeInference to FillOpVarTypeInference, test=develop

* fix op test, test=develop

* add head file, test=develop
上级 382d099d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,74 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/fill_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
struct FillOpVisitor {
FillOpVisitor(framework::LoDTensor *tensor, const std::vector<float> &value)
: tensor_(tensor), value_(value) {}
template <typename T>
void apply() const {
platform::CPUPlace cpu;
auto *data = tensor_->mutable_data<T>(cpu);
std::transform(value_.data(), value_.data() + tensor_->numel(), data,
[](float dat) { return static_cast<T>(dat); });
}
framework::LoDTensor *tensor_;
const std::vector<float> &value_;
};
class FillOp : public framework::OperatorBase {
public:
FillOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &out =
detail::Ref(detail::Ref(scope.FindVar(Output("Out")),
"Cannot find variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>());
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
auto dtype =
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, dtype);
framework::LoDTensor tensor;
if (force_cpu || platform::is_cpu_place(place)) {
tensor.ShareDataWith(out);
} else {
// Always make tensor in CPU memory.
tensor.Resize(out.dims());
tensor.mutable_data(cpu, dtype);
}
framework::VisitDataType(
dtype, FillOpVisitor(&tensor, Attr<std::vector<float>>("value")));
if (!force_cpu && platform::is_gpu_place(place)) {
// Copy tensor to out
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::TensorCopy(tensor, place, dev_ctx, &out);
}
}
};
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -101,16 +39,42 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
}
};
class FillOpInferShape : public framework::InferShapeBase {
class FillOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
"Output(Out) of FillOp should not be null.");
auto& shape = context->Attrs().Get<std::vector<int>>("shape");
context->SetOutputDim("Out", framework::make_ddim(shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class FillOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim(
"Out",
framework::make_ddim(context->Attrs().Get<std::vector<int>>("shape")));
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fill, ops::FillOp, ops::FillOpInferShape, ops::FillOpMaker);
REGISTER_OPERATOR(fill, ops::FillOp, ops::FillOpMaker,
ops::FillOpVarTypeInference,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fill, ops::FillKernel<float>, ops::FillKernel<double>,
ops::FillKernel<int64_t>, ops::FillKernel<int>,
ops::FillKernel<paddle::platform::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fill, ops::FillKernel<float>, ops::FillKernel<double>,
ops::FillKernel<int64_t>, ops::FillKernel<int>,
ops::FillKernel<paddle::platform::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
struct FillOpVisitor {
FillOpVisitor(framework::LoDTensor *tensor, const std::vector<float> &value)
: tensor_(tensor), value_(value) {}
template <typename T>
void apply() const {
platform::CPUPlace cpu;
auto *data = tensor_->mutable_data<T>(cpu);
std::transform(value_.data(), value_.data() + tensor_->numel(), data,
[](float dat) { return static_cast<T>(dat); });
}
framework::LoDTensor *tensor_;
const std::vector<float> &value_;
};
template <typename T>
class FillKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto &out =
detail::Ref(ctx.Output<framework::LoDTensor>("Out"),
"Cannot get output lod tensor Out, variable name = %s",
ctx.op().Output("Out"));
out.Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape")));
auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = ctx.Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : ctx.GetPlace(), dtype);
framework::LoDTensor tensor;
if (force_cpu || platform::is_cpu_place(ctx.GetPlace())) {
tensor.ShareDataWith(out);
} else {
// Always make tensor in CPU memory.
tensor.Resize(out.dims());
tensor.mutable_data(cpu, dtype);
}
framework::VisitDataType(
dtype, FillOpVisitor(&tensor, ctx.Attr<std::vector<float>>("value")));
if (!force_cpu && platform::is_gpu_place(ctx.GetPlace())) {
// Copy tensor to out
framework::TensorCopy(
tensor, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &out);
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,9 +18,10 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestFillOp(OpTest):
class TestFillOp1(OpTest):
def setUp(self):
self.op_type = "fill"
val = np.random.random(size=[100, 200])
......@@ -28,7 +29,8 @@ class TestFillOp(OpTest):
self.attrs = {
'value': val.flatten().tolist(),
'shape': [100, 200],
'dtype': int(core.VarDesc.VarType.FP64)
'dtype': int(core.VarDesc.VarType.FP64),
'force_cpu': False
}
self.outputs = {'Out': val.astype('float64')}
......@@ -36,5 +38,55 @@ class TestFillOp(OpTest):
self.check_output()
class TestFillOp2(OpTest):
def setUp(self):
self.op_type = "fill"
val = np.random.random(size=[100, 200])
self.inputs = {}
self.attrs = {
'value': val.flatten().tolist(),
'shape': [100, 200],
'dtype': int(core.VarDesc.VarType.FP64),
'force_cpu': True
}
self.outputs = {'Out': val.astype('float64')}
def test_check_output(self):
self.check_output()
class TestFillOp3(OpTest):
def check_with_place(self, place, f_cpu):
scope = core.Scope()
# create Out Variable
out = scope.var('Out').get_tensor()
# create and run fill_op operator
val = np.random.random(size=[300, 200])
fill_op = Operator(
"fill",
value=val.flatten(),
shape=[300, 200],
dtype=int(core.VarDesc.VarType.FP32),
force_cpu=f_cpu,
Out='Out')
fill_op.run(scope, place)
# get result from Out
result_array = np.array(out)
full_array = np.array(val, 'float32')
self.assertTrue(np.array_equal(result_array, full_array))
def test_fill_op(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place, True)
self.check_with_place(place, False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册