Keras Post Training Quantization
- model_compression_toolkit.keras_post_training_quantization(in_model, representative_data_gen, n_iter=500, quant_config=DEFAULTCONFIG, fw_info=DEFAULT_KERAS_INFO, network_editor=[], knowledge_distillation_config=None, analyze_similarity=False)
Quantize a trained Keras model using post-training quantization. The model is quantized using symmetric quantization with power-of-two thresholds. The model is first optimized using several transformations (e.g., folding BatchNormalization into preceding layers). Then, using the given representative dataset, statistics (e.g., min/max, histograms) are collected for each layer's output (and input, depending on the quantization configuration). Thresholds are then computed from the collected statistics, and the model is quantized (both weights and activations, by default). If a knowledge distillation configuration is passed, the quantized weights are further optimized using knowledge distillation: corresponding points in the float and quantized models are compared, and the observed loss is minimized.
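The symmetric power-of-two thresholding described above can be illustrated with a minimal NumPy sketch. This is not MCT's actual implementation, and the helper names (`power_of_two_threshold`, `symmetric_quantize`) are hypothetical; it only demonstrates the idea of picking the smallest power-of-two threshold covering the observed range and quantizing symmetrically within it:

```python
import numpy as np

def power_of_two_threshold(x):
    # Smallest power of two covering the maximal absolute value
    # observed in the tensor (illustrative sketch, not MCT code).
    max_abs = np.max(np.abs(x))
    return 2.0 ** np.ceil(np.log2(max_abs))

def symmetric_quantize(x, n_bits=8):
    # Symmetric uniform quantization to n_bits within [-t, t),
    # where t is a power-of-two threshold.
    t = power_of_two_threshold(x)
    levels = 2 ** (n_bits - 1)          # e.g. 128 for 8 bits
    scale = t / levels
    q = np.clip(np.round(x / scale), -levels, levels - 1)
    return q * scale                     # de-quantized (fake-quant) values

x = np.array([-0.9, -0.3, 0.0, 0.4, 0.75])
print(symmetric_quantize(x))
```

In MCT itself, the threshold is selected per layer from the collected statistics (e.g., a histogram) rather than from the raw tensor maximum, but the power-of-two constraint on the resulting threshold is the same.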
- Parameters
in_model (Model) – Keras model to quantize.
representative_data_gen (Callable) – Dataset used for calibration.
n_iter (int) – Number of calibration iterations to run.
quant_config (QuantizationConfig) – QuantizationConfig containing parameters of how the model should be quantized. Default configuration.
fw_info (FrameworkInfo) – Information needed for quantization about the specific framework (e.g., kernel channel indices, groups of layers by how they should be quantized, etc.). Default Keras info.
network_editor (List[EditRule]) – List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
knowledge_distillation_config (KnowledgeDistillationConfig) – Configuration for using knowledge distillation (e.g. optimizer).
analyze_similarity (bool) – Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
- Returns
A quantized model.
Examples
Import a Keras model:
>>> from tensorflow.keras.applications.mobilenet import MobileNet
>>> model = MobileNet()
Create a random dataset generator:
>>> import numpy as np
>>> def repr_datagen(): return [np.random.random((1, 224, 224, 3))]
Import mct and pass the model with the representative dataset generator to get a quantized model:
>>> import model_compression_toolkit as mct
>>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)