Apollo  v5.5.0
Open source self driving car software
argmax_plugin.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2018 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
20 #include "modules/perception/proto/rt.pb.h"
21 
22 namespace apollo {
23 namespace perception {
24 namespace inference {
25 
26 class ArgMax1Plugin : public nvinfer1::IPlugin {
27  public:
28  ArgMax1Plugin(const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims) {
29  input_dims_.nbDims = in_dims.nbDims;
30  CHECK_GT(input_dims_.nbDims, 0);
31  for (int i = 0; i < in_dims.nbDims; i++) {
32  input_dims_.d[i] = in_dims.d[i];
33  input_dims_.type[i] = in_dims.type[i];
34  }
35  axis_ = argmax_param.axis();
36  out_max_val_ = argmax_param.out_max_val();
37  top_k_ = argmax_param.top_k();
38  CHECK_GE(top_k_, 1) << "top k must not be less than 1.";
39  output_dims_ = input_dims_;
40  output_dims_.d[0] = 1;
41  if (out_max_val_) {
42  // Produces max_ind and max_val
43  output_dims_.d[0] = 2;
44  }
45  }
46 
55  virtual int initialize() { return 0; }
56  virtual void terminate() {}
57  int getNbOutputs() const override { return 1; }
58  virtual nvinfer1::Dims getOutputDimensions(int index,
59  const nvinfer1::Dims *inputs,
60  int nbInputDims) {
61  input_dims_ = inputs[0];
62  for (int i = 1; i < input_dims_.nbDims; i++) {
63  output_dims_.d[i] = input_dims_.d[i];
64  }
65  return output_dims_;
66  }
67 
68  void configure(const nvinfer1::Dims *inputDims, int nbInputs,
69  const nvinfer1::Dims *outputDims, int nbOutputs,
70  int maxBatchSize) override {
71  input_dims_ = inputDims[0];
72  for (int i = 1; i < input_dims_.nbDims; i++) {
73  output_dims_.d[i] = input_dims_.d[i];
74  }
75  }
76 
77  size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
78 
79  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
80  void *workspace, cudaStream_t stream);
81 
82  size_t getSerializationSize() override { return 0; }
83 
84  void serialize(void *buffer) override {
85  char *d = reinterpret_cast<char *>(buffer), *a = d;
86  size_t size = getSerializationSize();
87  CHECK_EQ(d, a + size);
88  }
89 
90  virtual ~ArgMax1Plugin() {}
91 
92  private:
93  bool out_max_val_;
94  size_t top_k_;
95  int axis_;
96  nvinfer1::Dims input_dims_;
97  nvinfer1::Dims output_dims_;
98 };
99 
100 } // namespace inference
101 } // namespace perception
102 } // namespace apollo
virtual int initialize()
Initializes the layer for execution; this implementation does nothing and returns 0.
Definition: argmax_plugin.h:55
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
size_t getSerializationSize() override
Definition: argmax_plugin.h:82
Definition: blob.h:72
int getNbOutputs() const override
Definition: argmax_plugin.h:57
virtual void terminate()
Definition: argmax_plugin.h:56
virtual ~ArgMax1Plugin()
Definition: argmax_plugin.h:90
Definition: argmax_plugin.h:26
ArgMax1Plugin(const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims)
Definition: argmax_plugin.h:28
void configure(const nvinfer1::Dims *inputDims, int nbInputs, const nvinfer1::Dims *outputDims, int nbOutputs, int maxBatchSize) override
Definition: argmax_plugin.h:68
void serialize(void *buffer) override
Definition: argmax_plugin.h:84
size_t getWorkspaceSize(int maxBatchSize) const override
Definition: argmax_plugin.h:77
virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims)
Definition: argmax_plugin.h:58