提交 736d078c 编写于 作者: Q qijun

replace Tensor::tensor to EigenTensor::From

上级 8ad9006d
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/op_registry.h> #include "paddle/operators/mul_op.h"
#include <paddle/framework/tensor.h> #include "paddle/framework/op_registry.h"
#include <paddle/operators/mul_op.h> #include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/operators/mul_op.h> #include "paddle/operators/mul_op.h"
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(mul, REGISTER_OP_GPU_KERNEL(mul,
paddle::operators::MulKernel<paddle::platform paddle::operators::MulKernel<paddle::platform
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
#pragma once #pragma once
#include <glog/logging.h> #include "glog/logging.h"
#include <paddle/framework/operator.h> #include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -34,8 +35,10 @@ public: ...@@ -34,8 +35,10 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
output->matrix<T>().device(*(context.GetEigenDevice<Place>())) = framework::EigenMatrix<T>::From(*output).device(
input0.matrix<T>().contract(input1.matrix<T>(), dim_pair); *(context.GetEigenDevice<Place>())) =
framework::EigenMatrix<T>::From(input0).contract(
framework::EigenMatrix<T>::From(input1), dim_pair);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/op_registry.h> #include "paddle/operators/rowwise_add_op.h"
#include <paddle/operators/rowwise_add_op.h> #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
#include <paddle/operators/rowwise_add_op.h> #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
rowwise_add, rowwise_add,
......
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h> #include "glog/logging.h"
#include <paddle/framework/operator.h> #include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,9 +28,9 @@ public: ...@@ -27,9 +28,9 @@ public:
auto in1 = context.Input(1)->Get<framework::Tensor>(); auto in1 = context.Input(1)->Get<framework::Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>(); auto* out = context.Output(0)->GetMutable<framework::Tensor>();
auto input = in0.matrix<T>(); auto input = framework::EigenMatrix<T>::From(in0);
auto bias = in1.vec<T>(); auto bias = framework::EigenVector<T>::From(in1);
auto output = out->matrix<T>(); auto output = framework::EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0); const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size; const int rest_size = input.size() / bias_size;
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/op_registry.h> #include "paddle/operators/sigmoid_op.h"
#include <paddle/operators/sigmoid_op.h> #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
#include <paddle/operators/sigmoid_op.h> #include "paddle/operators/sigmoid_op.h"
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>); sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>);
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
#pragma once #pragma once
#include <glog/logging.h> #include "glog/logging.h"
#include <paddle/framework/operator.h> #include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,8 +30,9 @@ public: ...@@ -29,8 +30,9 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
output->flat<T>().device(*(context.GetEigenDevice<Place>())) = framework::EigenVector<T>::Flatten(*output).device(
1.0 / (1.0 + (-1.0 * input.flat<T>()).exp()); *(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * framework::EigenVector<T>::Flatten(input)).exp());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/op_registry.h> #include "paddle/operators/softmax_op.h"
#include <paddle/operators/softmax_op.h> #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
#include <paddle/operators/softmax_op.h> #include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>); softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>);
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
#pragma once #pragma once
#include <glog/logging.h> #include "glog/logging.h"
#include <paddle/framework/operator.h> #include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,8 +28,8 @@ public: ...@@ -27,8 +28,8 @@ public:
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto logits = input.matrix<T>(); auto logits = framework::EigenMatrix<T>::From(input);
auto softmax = output->matrix<T>(); auto softmax = framework::EigenMatrix<T>::From(*output);
const int kBatchDim = 0; const int kBatchDim = 0;
const int kClassDim = 1; const int kClassDim = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册