未验证 提交 b341bac7 编写于 作者: Q QI JUN 提交者: GitHub

Refine cast op (#8923)

* fix mac build error

* override GetExpectedKernelType for cast op

* fix typo

* add cuda unittest
上级 84680379
......@@ -63,13 +63,27 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
}
};
class CastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CastOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
ops::CastOpProtoMaker);
REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker,
ops::CastOpInferShape, ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
......
......@@ -19,6 +19,7 @@ import unittest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
import paddle.fluid.core as core
def exponential_decay(learning_rate,
......@@ -81,6 +82,16 @@ def piecewise_decay(global_step, boundaries, values):
class TestLearningRateDecay(unittest.TestCase):
def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_decay_with_place(place, python_decay_fn, fluid_decay_fn,
kwargs)
def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
kwargs):
decayed_lr = fluid_decay_fn(**kwargs)
place = fluid.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册