未验证 提交 b341bac7 编写于 作者: Q QI JUN 提交者: GitHub

Refine cast op (#8923)

* fix mac build error

* override GetExpectedKernelType for cast op

* fix typo

* add cuda unittest
上级 84680379
...@@ -63,13 +63,27 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -63,13 +63,27 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
class CastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CastOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape, REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker,
ops::CastOpProtoMaker); ops::CastOpInferShape, ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>, REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>, ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int>,
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
import paddle.fluid.core as core
def exponential_decay(learning_rate, def exponential_decay(learning_rate,
...@@ -81,6 +82,16 @@ def piecewise_decay(global_step, boundaries, values): ...@@ -81,6 +82,16 @@ def piecewise_decay(global_step, boundaries, values):
class TestLearningRateDecay(unittest.TestCase): class TestLearningRateDecay(unittest.TestCase):
def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs): def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_decay_with_place(place, python_decay_fn, fluid_decay_fn,
kwargs)
def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
kwargs):
decayed_lr = fluid_decay_fn(**kwargs) decayed_lr = fluid_decay_fn(**kwargs)
place = fluid.CPUPlace() place = fluid.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册