未验证 提交 f449180b 编写于 作者: Q qingqing01 提交者: GitHub

Register more data type for reshape operator. (#8617)

上级 a67cebaf
...@@ -121,10 +121,15 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -121,10 +121,15 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad, REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad,
ops::ReshapeGradOp); ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL(reshape, REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
ops::ReshapeKernel<paddle::platform::CPUPlace, float>); ops::ReshapeKernel<CPU, double>,
REGISTER_OP_CPU_KERNEL( ops::ReshapeKernel<CPU, int>,
reshape_grad, ops::ReshapeGradKernel<paddle::platform::CPUPlace, float>); ops::ReshapeKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
ops::ReshapeGradKernel<CPU, double>,
ops::ReshapeGradKernel<CPU, int>,
ops::ReshapeGradKernel<CPU, int64_t>);
...@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h" #include "paddle/fluid/operators/reshape_op.h"
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
reshape, paddle::operators::ReshapeKernel<CUDA, double>,
paddle::operators::ReshapeKernel<paddle::platform::CUDAPlace, float>); paddle::operators::ReshapeKernel<CUDA, int>,
REGISTER_OP_CUDA_KERNEL( paddle::operators::ReshapeKernel<CUDA, int64_t>);
reshape_grad, REGISTER_OP_CUDA_KERNEL(reshape_grad,
paddle::operators::ReshapeGradKernel<paddle::platform::CUDAPlace, float>); paddle::operators::ReshapeGradKernel<CUDA, float>,
paddle::operators::ReshapeGradKernel<CUDA, double>,
paddle::operators::ReshapeGradKernel<CUDA, int>,
paddle::operators::ReshapeGradKernel<CUDA, int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册