Apollo v5.5.0
Open source self driving car software
paddle_net.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2019 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <chrono>
20 #include <iostream>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <unordered_map>
26 #include <utility>
27 #include <vector>
28 
29 #include "paddle/paddle_inference_api.h"
30 
32 
33 namespace apollo {
34 namespace perception {
35 namespace inference {
36 
37 typedef std::shared_ptr<apollo::perception::base::Blob<float>> BlobPtr;
38 
// Initial memory-pool size, in megabytes judging by the "Mb" suffix.
// NOTE(review): presumably handed to the Paddle predictor config when GPU
// inference is enabled — confirm against paddle_net.cc.
constexpr uint64_t MemoryPoolInitSizeMb{100};
40 
41 class PaddleNet : public Inference {
42  public:
43  PaddleNet(const std::string &model_file, const std::string &param_file,
44  const std::vector<std::string> &outputs);
45 
46  PaddleNet(const std::string &model_file, const std::string &param_file,
47  const std::vector<std::string> &outputs,
48  const std::vector<std::string> &inputs);
49 
50  virtual ~PaddleNet() {}
51 
52  bool Init(const std::map<std::string, std::vector<int>> &shapes) override;
53 
54  void Infer() override;
55  BlobPtr get_blob(const std::string &name) override;
56 
57  protected:
58  bool reshape();
59  bool shape(const std::string &name, std::vector<int> *res);
60  std::shared_ptr<paddle::PaddlePredictor> predictor_ = nullptr;
61 
62  private:
63  std::string model_file_;
64  std::string param_file_;
65  std::vector<std::string> output_names_;
66  std::vector<std::string> input_names_;
67  BlobMap blobs_;
68 
69  std::unordered_map<std::string, std::string> name_map_ = {
70  // object detection
71  {"data", "input"},
72  {"obj_pred", "save_infer_model/scale_0"},
73  {"cls_pred", "save_infer_model/scale_1"},
74  {"ori_pred", "save_infer_model/scale_2"},
75  {"dim_pred", "save_infer_model/scale_3"},
76  {"brvis_pred", "save_infer_model/scale_4"},
77  {"ltvis_pred", "save_infer_model/scale_5"},
78  {"rtvis_pred", "save_infer_model/scale_6"},
79  {"brswt_pred", "save_infer_model/scale_7"},
80  {"ltswt_pred", "save_infer_model/scale_8"},
81  {"rtswt_pred", "save_infer_model/scale_9"},
82  {"loc_pred", "save_infer_model/scale_13"},
83  {"conv3_3", "save_infer_model/scale_14"},
84  // lane line
85  {"softmax", "save_infer_model/scale_0"},
86  // lidar cnn_seg
87  {"confidence_score", "save_infer_model/scale_0"},
88  {"class_score", "save_infer_model/scale_1"},
89  {"category_score", "save_infer_model/scale_2"},
90  {"instance_pt", "save_infer_model/scale_3"},
91  {"heading_pt", "save_infer_model/scale_4"},
92  {"height_pt", "save_infer_model/scale_5"}};
93 };
94 
95 } // namespace inference
96 } // namespace perception
97 } // namespace apollo
Definition: blob.h:72
BlobPtr get_blob(const std::string &name) override
std::map< std::string, std::shared_ptr< apollo::perception::base::Blob< float > > > BlobMap
Definition: inference.h:34
bool Init(const std::map< std::string, std::vector< int >> &shapes) override
bool shape(const std::string &name, std::vector< int > *res)
std::shared_ptr< paddle::PaddlePredictor > predictor_
Definition: paddle_net.h:60
constexpr uint64_t MemoryPoolInitSizeMb
Definition: paddle_net.h:39
Definition: paddle_net.h:41
std::shared_ptr< apollo::perception::base::Blob< float > > BlobPtr
Definition: caffe_net.h:33
PaddleNet(const std::string &model_file, const std::string &param_file, const std::vector< std::string > &outputs)
virtual ~PaddleNet()
Definition: paddle_net.h:50