20 #include "modules/perception/proto/rt.pb.h" 23 namespace perception {
28 ArgMax1Plugin(
const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims) {
29 input_dims_.nbDims = in_dims.nbDims;
30 CHECK_GT(input_dims_.nbDims, 0);
31 for (
int i = 0; i < in_dims.nbDims; i++) {
32 input_dims_.d[i] = in_dims.d[i];
33 input_dims_.type[i] = in_dims.type[i];
35 axis_ = argmax_param.axis();
36 out_max_val_ = argmax_param.out_max_val();
37 top_k_ = argmax_param.top_k();
38 CHECK_GE(top_k_, 1) <<
"top k must not be less than 1.";
39 output_dims_ = input_dims_;
40 output_dims_.d[0] = 1;
43 output_dims_.d[0] = 2;
59 const nvinfer1::Dims *inputs,
61 input_dims_ = inputs[0];
62 for (
int i = 1; i < input_dims_.nbDims; i++) {
63 output_dims_.d[i] = input_dims_.d[i];
68 void configure(
const nvinfer1::Dims *inputDims,
int nbInputs,
69 const nvinfer1::Dims *outputDims,
int nbOutputs,
70 int maxBatchSize)
override {
71 input_dims_ = inputDims[0];
72 for (
int i = 1; i < input_dims_.nbDims; i++) {
73 output_dims_.d[i] = input_dims_.d[i];
79 virtual int enqueue(
int batchSize,
const void *
const *inputs,
void **outputs,
80 void *workspace, cudaStream_t stream);
85 char *d =
reinterpret_cast<char *
>(buffer), *a = d;
87 CHECK_EQ(d, a + size);
96 nvinfer1::Dims input_dims_;
97 nvinfer1::Dims output_dims_;
virtual int initialize()
get the number of outputs from the layer
Definition: argmax_plugin.h:55
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
size_t getSerializationSize() override
Definition: argmax_plugin.h:82
int getNbOutputs() const override
Definition: argmax_plugin.h:57
virtual void terminate()
Definition: argmax_plugin.h:56
virtual ~ArgMax1Plugin()
Definition: argmax_plugin.h:90
Definition: argmax_plugin.h:26
ArgMax1Plugin(const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims)
Definition: argmax_plugin.h:28
void configure(const nvinfer1::Dims *inputDims, int nbInputs, const nvinfer1::Dims *outputDims, int nbOutputs, int maxBatchSize) override
Definition: argmax_plugin.h:68
void serialize(void *buffer) override
Definition: argmax_plugin.h:84
size_t getWorkspaceSize(int maxBatchSize) const override
Definition: argmax_plugin.h:77
virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims)
Definition: argmax_plugin.h:58