Apollo  v5.5.0
Open source self driving car software
async_sequence_data_loader.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2019 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 #pragma once
17 #include <algorithm>
18 #include <future>
19 #include <map>
20 #include <memory>
21 #include <mutex>
22 #include <string>
23 #include <vector>
26 
27 namespace apollo {
28 namespace perception {
29 namespace benchmark {
30 
// @brief One cache slot of the asynchronous loader: the payload together
// with the bookkeeping that describes how far its loading has progressed.
template <class Dtype>
struct Cache {
  // Payload object; null until a loader assigns one.
  std::shared_ptr<Dtype> data;
  // Set by the loader once the load of this slot is known to have finished.
  bool loaded{false};
  // Result reported by the payload's load call.
  bool load_success{false};
  // Handle of a queued asynchronous load; not valid unless one was assigned.
  std::future<void> status;
};
38 
// @brief General asynchronous data loader. DataType must be default
// constructible and implement the member functions
// "bool load(const std::vector<std::string>& filenames)" and "release()".
41 template <class DataType>
42 class AsyncSequenceDataLoader : public SequenceDataLoader<DataType> {
43  public:
44  AsyncSequenceDataLoader() = default;
45  ~AsyncSequenceDataLoader() = default;
46  void set(std::size_t cache_size, std::size_t prefetch_size,
47  std::size_t thread_num) {
48  std::lock_guard<std::mutex> lock(_mutex);
49  _fixed_cache_size = cache_size;
50  _prefetch_data_size = prefetch_size;
51  if (_thread_pool == nullptr) {
52  _thread_pool.reset(new ctpl::thread_pool(static_cast<int>(thread_num)));
53  } else {
54  _thread_pool->stop(true);
55  _thread_pool->start(static_cast<int>(thread_num));
56  }
57  _cached_data.clear();
58  }
59  bool query_next(std::shared_ptr<DataType>& data) override; // NOLINT
60  bool query_last(std::shared_ptr<DataType>& data) override; // NOLINT
61 
62  protected:
63  using CachePtr = std::shared_ptr<Cache<DataType>>;
67 
68  protected:
69  std::size_t _fixed_cache_size = 50;
70  std::size_t _prefetch_data_size = 5;
71 
72  private:
73  std::map<int, CachePtr> _cached_data;
74  std::unique_ptr<ctpl::thread_pool> _thread_pool;
75  std::mutex _mutex;
76 };
77 
// @brief Pipeline shared by query_next and query_last:
// case 1. frame not cached yet: load it synchronously
// case 2. frame is being prefetched: wait for the load to finish
// case 3. frame already loaded: nothing to do
// finally, trigger prefetching of the neighbouring frames
83 
84 template <class DataType>
86  std::shared_ptr<DataType>& data) { // NOLINT
87  if (_thread_pool == nullptr) {
88  return false;
89  }
90  if (!_initialized) {
91  return false;
92  }
93  std::lock_guard<std::mutex> lock(_mutex);
94  ++_idx;
95  if (_idx >= static_cast<int>(_filenames[0].size())) {
96  return false;
97  } else if (_idx < 0) {
98  _idx = 0;
99  }
100 
101  bool load_success = false;
102  // load current idx
103  auto iter = _cached_data.find(_idx);
104  // not load yet
105  if (iter == _cached_data.end()) {
106  std::cerr << "Fail to prefetch, start loading..." << std::endl;
107  CachePtr cache_ptr(new Cache<DataType>);
108  cache_ptr->data.reset(new DataType);
109  std::vector<std::string> files;
110  for (auto& names : _filenames) {
111  files.push_back(names[_idx]);
112  }
113  cache_ptr->loaded = true;
114  cache_ptr->load_success = cache_ptr->data->load(files);
115  _cached_data.emplace(_idx, cache_ptr);
116  data = cache_ptr->data;
117  load_success = cache_ptr->load_success;
118  } else {
119  // loaded
120  if (!iter->second->loaded) {
121  iter->second->status.wait();
122  iter->second->loaded = true;
123  }
124  data = iter->second->data;
125  load_success = iter->second->load_success;
126  }
127  // prefetch next data
128  for (int i = _idx + 1;
129  i <= std::min(static_cast<std::size_t>(_idx) + _prefetch_data_size,
130  _filenames[0].size() - 1);
131  ++i) {
132  auto prefetch_iter = _cached_data.find(i);
133  if (prefetch_iter == _cached_data.end()) {
134  CachePtr cache_ptr(new Cache<DataType>);
135  cache_ptr->data.reset(new DataType);
136  std::vector<std::string> files;
137  for (auto& names : _filenames) {
138  files.push_back(names[i]);
139  }
140  cache_ptr->status = _thread_pool->push([cache_ptr, files](int id) {
141  cache_ptr->load_success = cache_ptr->data->load(files);
142  });
143  _cached_data.emplace(i, cache_ptr);
144  }
145  }
146  // clean stale data on the other end
147  for (auto citer = _cached_data.begin();
148  citer != _cached_data.end() &&
149  _cached_data.size() > _fixed_cache_size;) {
150  if (citer->second->loaded && _idx != citer->first) {
151  citer->second->data->release();
152  _cached_data.erase(citer++);
153  } else {
154  break;
155  }
156  }
157  return load_success;
158 }
159 
160 template <class DataType>
162  std::shared_ptr<DataType>& data) { // NOLINT
163  if (_thread_pool == nullptr) {
164  return false;
165  }
166  if (!_initialized) {
167  return false;
168  }
169  std::lock_guard<std::mutex> lock(_mutex);
170  if (data == nullptr) {
171  data.reset(new DataType);
172  }
173  --_idx;
174  if (_idx < 0) {
175  return false;
176  } else if (_idx >= static_cast<int>(_filenames[0].size())) {
177  _idx = static_cast<int>(_filenames[0].size() - 1);
178  }
179  bool load_success = false;
180  // load current idx
181  auto iter = _cached_data.find(_idx);
182  // not load yet
183  if (iter == _cached_data.end()) {
184  std::cerr << "Fail to prefetch, start loading..." << std::endl;
185  CachePtr cache_ptr(new Cache<DataType>);
186  cache_ptr->data.reset(new DataType);
187  cache_ptr->loaded = true;
188  std::vector<std::string> files;
189  for (auto& names : _filenames) {
190  files.push_back(names[_idx]);
191  }
192  cache_ptr->load_success = cache_ptr->data->load(files);
193  _cached_data.emplace(_idx, cache_ptr);
194  data = cache_ptr->data;
195  load_success = cache_ptr->load_success;
196  } else {
197  // loaded
198  if (!iter->second->loaded) {
199  iter->second->status.wait();
200  iter->second->loaded = true;
201  }
202  data = iter->second->data;
203  load_success = iter->second->load_success;
204  }
205  // prefetch last data
206  for (int i = _idx - 1;
207  i >= std::max(_idx - static_cast<int>(_prefetch_data_size), 0); --i) {
208  auto prefetch_iter = _cached_data.find(i);
209  if (prefetch_iter == _cached_data.end()) {
210  CachePtr cache_ptr(new Cache<DataType>);
211  cache_ptr->data.reset(new DataType);
212  std::vector<std::string> files;
213  for (auto& names : _filenames) {
214  files.push_back(names[i]);
215  }
216  cache_ptr->status = _thread_pool->push([cache_ptr, files](int id) {
217  cache_ptr->load_success = cache_ptr->data->load(files);
218  });
219  _cached_data.emplace(i, cache_ptr);
220  }
221  }
222  // clean stale data on the other end
223  for (auto citer = _cached_data.rbegin();
224  citer != _cached_data.rend() &&
225  _cached_data.size() > _fixed_cache_size;) {
226  if (citer->second->loaded && citer->first != _idx) {
227  citer->second->data->release();
228  _cached_data.erase((++citer).base());
229  } else {
230  break;
231  }
232  }
233  return load_success;
234 }
235 
236 } // namespace benchmark
237 } // namespace perception
238 } // namespace apollo
bool load_success
Definition: async_sequence_data_loader.h:35
bool loaded
Definition: async_sequence_data_loader.h:34
bool query_last(std::shared_ptr< DataType > &data) override
Definition: async_sequence_data_loader.h:161
Definition: blob.h:72
std::future< void > status
Definition: async_sequence_data_loader.h:36
std::shared_ptr< Cache< apollo::perception::benchmark::FrameStatistics > > CachePtr
Definition: async_sequence_data_loader.h:63
bool query_next(std::shared_ptr< DataType > &data) override
Definition: async_sequence_data_loader.h:85
std::shared_ptr< Dtype > data
Definition: async_sequence_data_loader.h:33
Definition: sequence_data_loader.h:37
Definition: async_sequence_data_loader.h:42
Definition: async_sequence_data_loader.h:32
Definition: ctpl.h:61