未验证 提交 f449180b 编写于 作者: Q qingqing01 提交者: GitHub

Register more data type for reshape operator. (#8617)

上级 a67cebaf
......@@ -121,10 +121,15 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad,
ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL(reshape,
ops::ReshapeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
reshape_grad, ops::ReshapeGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
ops::ReshapeKernel<CPU, double>,
ops::ReshapeKernel<CPU, int>,
ops::ReshapeKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
ops::ReshapeGradKernel<CPU, double>,
ops::ReshapeGradKernel<CPU, int>,
ops::ReshapeGradKernel<CPU, int64_t>);
......@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(
reshape,
paddle::operators::ReshapeKernel<paddle::platform::CUDAPlace, float>);
REGISTER_OP_CUDA_KERNEL(
reshape_grad,
paddle::operators::ReshapeGradKernel<paddle::platform::CUDAPlace, float>);
REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
paddle::operators::ReshapeKernel<CUDA, double>,
paddle::operators::ReshapeKernel<CUDA, int>,
paddle::operators::ReshapeKernel<CUDA, int64_t>);
REGISTER_OP_CUDA_KERNEL(reshape_grad,
paddle::operators::ReshapeGradKernel<CUDA, float>,
paddle::operators::ReshapeGradKernel<CUDA, double>,
paddle::operators::ReshapeGradKernel<CUDA, int>,
paddle::operators::ReshapeGradKernel<CUDA, int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册