未验证 提交 b84e8226 编写于 作者: Y Yu Yang 提交者: GitHub

Cast Operator (#5149)

* Cast Operator

Cast input variable to other data type

* Fix compile error

* Add cast op

* Follow comments
上级 46a13e37
......@@ -34,5 +34,25 @@ inline DataType ToDataType(std::type_index type) {
}
}
template <typename Visitor>
inline void VisitDataType(DataType type, Visitor visitor) {
switch (type) {
case DataType::FP32:
visitor.template operator()<float>();
break;
case DataType::FP64:
visitor.template operator()<double>();
break;
case DataType::INT32:
visitor.template operator()<int>();
break;
case DataType::INT64:
visitor.template operator()<int64_t>();
break;
default:
PADDLE_THROW("Not supported");
}
}
} // namespace framework
} // namespace paddle
......@@ -162,6 +162,10 @@ class OpKernelRegistrar : public Registrar {
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class);
#define REGISTER_OP_WITH_KERNEL(op_type, ...) \
REGISTER_OPERATOR(op_type, ::paddle::framework::OperatorWithKernel, \
##__VA_ARGS__)
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cast_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CastOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensor of cast op");
AddOutput("Out", "the output tensor of cast op");
AddComment(R"DOC(Cast operator.
cast the input tensor to other data type.
)DOC");
AddAttr<int>("out_data_type", "output data type");
AddAttr<int>("in_data_type", "input data type");
}
};
class CastOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "The input of cast op must be set");
PADDLE_ENFORCE(context->HasOutput("Out"),
"The output of cast op must be set");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
class CastOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto grad = new framework::OpDescBind();
grad->SetType("cast");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("X"));
grad->SetAttr("out_data_type", GetAttr("in_data_type"));
grad->SetAttr("in_data_type", GetAttr("out_data_type"));
return std::unique_ptr<framework::OpDescBind>(grad);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUPlace;
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cast_op.h"
template <typename T>
using CastOpKernel =
paddle::operators::CastOpKernel<paddle::platform::GPUPlace, T>;
REGISTER_OP_GPU_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
template <typename InT, typename OutT>
struct CastOpTransformFunctor {
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename Place, typename InT>
struct CastOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
const platform::DeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::DeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
auto* in_begin = in_->data<InT>();
auto numel = in_->numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
platform::Transform<Place> trans;
trans(ctx_, in_begin, in_end, out_begin,
CastOpTransformFunctor<InT, OutT>());
}
};
template <typename Place, typename InT>
class CastOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::DataType>(context.Attr<int>("out_data_type")),
CastOpFunctor<Place, InT>(in, out, context.device_context()));
}
};
} // namespace operators
} // namespace paddle
......@@ -5,7 +5,7 @@ import re
__all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN'
'StaticRNN', 'cast'
]
......@@ -163,6 +163,18 @@ _create_op_func_('mul')
_create_op_func_('dropout')
def cast(x, data_type, program=None):
helper = LayerHelper('cast', **locals())
out = helper.create_tmp_variable(dtype=data_type)
helper.append_op(
type='cast',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'in_data_type': x.data_type,
'out_data_type': out.data_type})
return out
def concat(input, axis, program=None, init_program=None):
helper = LayerHelper('concat', **locals())
if not isinstance(input, list) and not isinstance(input, tuple):
......
import op_test
import unittest
import numpy as np
import paddle.v2.framework.core as core
class TestCastOp(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float64')}
self.attrs = {
'in_data_type': int(core.DataType.FP32),
'out_data_type': int(core.DataType.FP64)
}
self.op_type = 'cast'
def test_check_output(self):
self.check_output()
def test_grad(self):
self.check_grad(['X'], ['Out'])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册