提交 65f953a6 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #3142 from Canpio/dev_add_FillZerosLikeOp_test

Add unittest for FillZerosLikeOp
......@@ -43,4 +43,5 @@ cc_library(paddle_pybind SHARED
add_op
mean_op
cross_entropy_op
fill_zeros_like_op
recurrent_op)
......@@ -40,6 +40,7 @@ USE_OP(mean);
USE_OP(sigmoid);
USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op);
namespace paddle {
namespace framework {
......
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_zeros_like_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
......
......@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
......
......@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
......@@ -26,7 +24,8 @@ class FillZerosLikeKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>(0);
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).setZero();
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(T(0));
}
};
......
......@@ -6,4 +6,5 @@ cc_library(paddle_pybind SHARED
add_op
mean_op
cross_entropy_op
recurrent_op)
recurrent_op
fill_zeros_like_op)
......@@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py)
py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
......
import unittest
from op_test_util import OpTestMeta
import numpy
class TestFillZerosLikeOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "fill_zeros_like"
self.inputs = {'Src': numpy.random.random((219, 232)).astype("float32")}
self.outputs = {'Dst': numpy.zeros_like(self.inputs['Src'])}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册