提交 91afa0d8 编写于 作者: H hedaoyuan

Some bug fix.

上级 09c65b6d
...@@ -30,7 +30,7 @@ class Conv2DOp : public framework::OperatorWithKernel { ...@@ -30,7 +30,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto in = ctx.Input<Tensor>("Input"); auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter"); auto filter = ctx.Input<Tensor>("Filter");
auto out = ctx.Output<Tensor>("Output"); auto out = ctx.Output<framework::LoDTensor>("Output");
std::vector<int> strides = Attr<std::vector<int>>("strides"); std::vector<int> strides = Attr<std::vector<int>>("strides");
std::vector<int> paddings = Attr<std::vector<int>>("paddings"); std::vector<int> paddings = Attr<std::vector<int>>("paddings");
int groups = Attr<int>("groups"); int groups = Attr<int>("groups");
...@@ -102,8 +102,10 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { ...@@ -102,8 +102,10 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto in = ctx.Input<Tensor>("Input"); auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter"); auto filter = ctx.Input<Tensor>("Filter");
auto d_in = ctx.Output<Tensor>(framework::GradVarName("Input")); auto d_in =
auto d_filter = ctx.Output<Tensor>(framework::GradVarName("Filter")); ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
auto d_filter =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter"));
d_in->Resize(in->dims()); d_in->Resize(in->dims());
d_filter->Resize(filter->dims()); d_filter->Resize(filter->dims());
} }
...@@ -117,6 +119,6 @@ REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, ...@@ -117,6 +119,6 @@ REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
ops::Conv2DOpGrad); ops::Conv2DOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d, ops::GemmConv2dKernel<paddle::platform::CPUPlace, float>); conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::CPUPlace, float>); conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
...@@ -17,6 +17,6 @@ ...@@ -17,6 +17,6 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d, ops::GemmConv2dKernel<paddle::platform::GPUPlace, float>); conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::GPUPlace, float>); conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class GemmConv2dKernel : public framework::OpKernel { class GemmConv2DKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
...@@ -101,7 +101,7 @@ class GemmConv2dKernel : public framework::OpKernel { ...@@ -101,7 +101,7 @@ class GemmConv2dKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GemmConvGrad2dKernel : public framework::OpKernel { class GemmConvGrad2DKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册