提交 09c65b6d 编写于 作者: H hedaoyuan

Follow comments.

上级 7bf1e76f
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/gemm_conv_op.h" #include "paddle/operators/gemm_conv2d_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -116,7 +116,7 @@ namespace ops = paddle::operators; ...@@ -116,7 +116,7 @@ namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
ops::Conv2DOpGrad); ops::Conv2DOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>); conv2d, ops::GemmConv2dKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::CPUPlace, float>);
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/gemm_conv_op.h" #include "paddle/operators/gemm_conv2d_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>); conv2d, ops::GemmConv2dKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::GPUPlace, float>);
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class GemmConvKernel : public framework::OpKernel { class GemmConv2dKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
...@@ -101,7 +101,7 @@ class GemmConvKernel : public framework::OpKernel { ...@@ -101,7 +101,7 @@ class GemmConvKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GemmConvGradKernel : public framework::OpKernel { class GemmConvGrad2dKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册