# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import Dict, Optional

import torch
from tensorrt_llm import Mapping
from tensorrt_llm.models import PretrainedConfig, PretrainedModel
from tensorrt_llm.models.gemma.model import GemmaForCausalLM as TrtGemmaForCausalLM
from tensorrt_llm.models.gemma.weight import load_from_hf_gemma
from tensorrt_llm.plugin import PluginConfig
from transformers import GemmaForCausalLM as TransformersGemmaForCausalLM
from transformers import PretrainedConfig as TransformersPretrainedConfig
from transformers import PreTrainedModel as TransformersPretrainedModel

from optimum.nvidia import TensorRTConfig
from optimum.nvidia.config import dtype_to_str
from optimum.nvidia.hub import HuggingFaceHubModel
from optimum.nvidia.runtime import CausalLM


LOGGER = getLogger(__name__)


class GemmaConfig(TensorRTConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaConfig`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to that of Gemma-7B.

    Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the
    documentation from [`TensorRTConfig`] for more information.
    """

    @staticmethod
    def from_config(
        config: TransformersPretrainedConfig, mapping: Optional[Mapping] = None
    ) -> "TensorRTConfig":
        mapping = mapping or Mapping()

        # Retrieve the quantization config from the transformers config (if provided)
        _, qconfig = TensorRTConfig.get_quantization_config(config)

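        # Map the Hugging Face hyper-parameters onto the TensorRT-LLM config;
        # Gemma uses GPT-NeoX-style rotary position embeddings.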
        trt_config = GemmaConfig(
            architecture=config.architectures[0],
            dtype=dtype_to_str(config.torch_dtype),
            logits_dtype="float32",
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            head_size=config.head_dim,
            hidden_act=config.hidden_act,
            intermediate_size=config.intermediate_size,
            norm_epsilon=config.rms_norm_eps,
            position_embedding_type="rope_gpt_neox",
            rotary_base=getattr(config, "rope_theta", 10000.0),
            rotary_scaling=getattr(config, "rope_scaling", None),
            world_size=mapping.world_size,
            tp_size=mapping.tp_size,
            pp_size=mapping.pp_size,
            use_prompt_tuning=False,
            use_parallel_embedding=mapping.tp_size > 1,
            embedding_sharding_dim=0,
            share_embedding_table=False,
            max_lora_rank=64,
            quantization=qconfig,
        )
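        # A node holds at most 8 GPUs here; for smaller world sizes, report the
        # world size itself as the per-node GPU count.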
        trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8)

        return trt_config

    def get_plugins_config(self) -> PluginConfig:
        config = super().get_plugins_config()
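        # Gemma is a dense, decoder-only model: disable the MoE and BERT
        # attention plugins, and run attention and GEMM in the model dtype.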
        config.moe_plugin = "disable"
        config.bert_attention_plugin = "disable"
        config.gpt_attention_plugin = self.dtype
        config.gemm_plugin = self.dtype

        return config

    @staticmethod
    def supports_strong_typing() -> bool:
        return False
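

# Minimal usage sketch (illustrative, not part of the module's API): building a
# GemmaConfig from a Hugging Face checkpoint's config. The checkpoint id is an
# assumption for demonstration purposes.
#
#     from transformers import AutoConfig
#     hf_config = AutoConfig.from_pretrained("google/gemma-7b")  # assumed id
#     trt_config = GemmaConfig.from_config(hf_config)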


class GemmaForCausalLM(CausalLM, HuggingFaceHubModel):
    MODEL_CONFIG = GemmaConfig
    HF_LIBRARY_TARGET_MODEL_CLASS = TransformersGemmaForCausalLM
    TRT_LLM_TARGET_MODEL_CLASS = TrtGemmaForCausalLM

    @staticmethod
    def convert_weights(
        target: PretrainedModel,
        source: TransformersPretrainedModel,
        config: PretrainedConfig,
    ) -> Dict[str, torch.Tensor]:
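        # Conversion is delegated to TensorRT-LLM's Gemma weight loader;
        # quantized checkpoints are not handled by this path.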
        if config.quant_mode.has_any_quant():
            raise NotImplementedError("Quantization is not supported yet.")

        return load_from_hf_gemma(target, source, config.mapping, config.dtype)
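

# End-to-end sketch (illustrative; assumptions: the checkpoint id, that the
# TensorRT-LLM model can be constructed directly from this config, and a
# default single-GPU Mapping): converting Hugging Face weights through the
# hooks defined above.
#
#     hf_model = TransformersGemmaForCausalLM.from_pretrained("google/gemma-2b")
#     trt_config = GemmaConfig.from_config(hf_model.config)
#     trt_model = TrtGemmaForCausalLM(trt_config)
#     weights = GemmaForCausalLM.convert_weights(trt_model, hf_model, trt_config)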