网站图片上怎么做弹幕效果,网站佣金怎么做分录,北京招聘网,一建报名资格条件x01 背景MemGen 提出动态生成式记忆框架#xff0c;由记忆触发器与记忆编织器两个轻量模块协同构成#xff0c;旨在突破现有智能体记忆范式的局限。当前主流的记忆实现路径为#xff1a;参数化记忆通过微调将经验编码进模型参数#xff0c;虽能深度内化知识却易引发灾难性遗…x01 背景MemGen 提出动态生成式记忆框架由记忆触发器与记忆编织器两个轻量模块协同构成旨在突破现有智能体记忆范式的局限。当前主流的记忆实现路径为参数化记忆通过微调将经验编码进模型参数虽能深度内化知识却易引发灾难性遗忘基于检索的记忆将经验外化存储虽规避了遗忘问题但静态的一次性检索机制无法体现记忆与推理动态交互的认知特性。这一现状引出两大核心问题如何实现记忆与推理在每一步思考中的无缝耦合以及如何让记忆从提取式升级为满足当前需求的生成式重构而动态生成式隐式记忆正是应对这些挑战的第三种探索路径。0x02 源码解析MemGen项目旨在创建一个动态且自生成的记忆框架该框架由两个协同工作的轻量级模块组成一个基于强化学习训练的记忆触发器和一个记忆编织器。这一框架的核心思想是解决大型语言模型LLM智能体能力涌现时对“自进化”机制的探索需求其中记忆扮演关键角色。2.1 模型LatentMemoryModel 是 MemGen 框架的核心实现旨在构建动态生成式隐式记忆系统解决传统记忆范式的局限性。通过整合推理器Reasoner、记忆编织器Weaver和记忆触发器Trigger实现记忆与推理过程的无缝耦合让智能体在任务执行中动态生成、使用记忆而非依赖静态检索或参数化存储。2.1.1 核心特色模型的核心特色如下模块化协同设计由推理器核心推理、编织器生成潜在记忆、触发器控制记忆触发三大模块构成模块间通过投影层实现嵌入空间映射结构清晰且解耦。动态记忆增强在推理过程中自动识别分隔符位置作为记忆增强点动态插入编织器生成的潜在记忆突破静态记忆注入的局限贴合人类认知中记忆与推理的动态交互特性。精度与效率优化默认使用 bfloat16 精度推理器采用 Flash Attention 2 提升计算效率冻结推理器参数仅训练编织器和触发器实现参数高效学习。灵活配置与兼容性支持自定义触发器模型、PEFT 微调配置、记忆增强次数等参数自动处理 Tokenizer 缺失 pad token 的问题标准化对话模板提升跨场景兼容性。损失计算精准过滤通过潜在记忆掩码排除记忆嵌入对应的位置仅对原始输入位置计算损失确保训练目标聚焦于核心任务性能避免记忆生成过程干扰主任务学习。2.1.2 网络结构关键说明核心设计亮点三大模块协同逻辑推理器Reasoner核心推理组件权重冻结以保留基础能力仅通过潜在记忆调整解码路径。触发器MemGenTrigger动态判断记忆插入时机输出二分类触发概率决定是否调用编织器。编织器MemGenWeaver生成针对性潜在记忆分提示词 / 推理两阶段设计支持 PEFT 高效微调。核心流程闭环输入 → 推理器生成原始嵌入 → 触发器 增强点选择模块确定插入位置 → 编织器生成潜在记忆 → 投影层适配维度 → 重组增强序列 → 推理器完成最终推理 → 过滤无效位置输出。关键技术细节跨模块投影通过 reasoner_to_weaver 和 weaver_to_reasoner 解决推理器与编织器嵌入维度不匹配问题。动态记忆增强按分隔符拆分序列逐段插入记忆避免长序列冗余贴合人类 “思考 - 记忆” 交互模式。精度与效率全流程采用 bfloat16 精度推理器 / 编织器启用 Flash Attention 2平衡性能与速度。训练与推理适配训练时通过 labels 和 valid_logits 计算损失仅优化编织器、触发器及投影层参数。推理时无需 labels自动完成 “触发判断 - 记忆生成 - 推理增强” 全流程实现动态自进化。具体网络结构如下MemGen-12.1.3 代码LatentMemoryModel 的代码如下registry.register_model(latmem)class LatentMemoryModel(BaseModel): # 定义了一个名为 LatentMemoryModel 的类继承自 BaseModeldef __init__(self,reasoner_model_name: str, # 推理模型名称weaver_model_name: str, # 记忆编织器模型名称prompt_latents_len: int, # 提示长度inference_latents_len: int, # 推理长度weaver_peft_config: Optional[PeftConfig] None, # 记忆编织器配置可选trigger_model_name: str None, # 触发模型名称可选trigger_peft_config: Optional[PeftConfig] None, # 触发器配置可选max_prompt_aug_num: int 1, # 最大提示增强数量max_inference_aug_num: int 5, # 最大推理增强数量):super().__init__() # 调用父类构造函数# 构建推理模型self.model AutoModelForCausalLM.from_pretrained( # 从预训练模型加载推理模型reasoner_model_name, torch_dtypetorch.bfloat16, attn_implementationflash_attention_2)self.tokenizer AutoTokenizer.from_pretrained(reasoner_model_name) # 加载入分词器self.config self.model.config # 获取模型配置# 构建记忆编织器self.weaver MemGenWeaver( # 初始化记忆编织器weaver_model_name, prompt_latents_len, inference_latents_len, weaver_peft_config)# 构建触发器self.trigger NanoTrigger() # 默认触发器始终返回 trueif trigger_model_name is not None:self.trigger MemGenTrigger( # 如果指定了触发模型则加载相应的触发器trigger_model_name, trigger_peft_config)logging.info(fUse Trigger: {trigger_model_name}) # 记录日志# 投影层用于在推理模型和记忆编织器之间映射嵌入# 将推理模型输入嵌入映射到记忆编织器输入嵌入self.reasoner_to_weaver nn.Linear( # 线性层从推理模型隐藏层到记忆编织器隐藏层self.model.config.hidden_size, self.weaver.config.hidden_size, dtypetorch.bfloat16)# 将记忆编织器隐藏状态映射回推理模型输入嵌入self.weaver_to_reasoner nn.Linear( # 线性层从记忆编织器隐藏层到推理模型隐藏层self.weaver.config.hidden_size, self.model.config.hidden_size, dtypetorch.bfloat16)self.delimiters: List[str] [,, ., \n] # 用于检测增强点的分隔符self.max_prompt_aug_num max_prompt_aug_num # 提示后提示中插入潜在数量self.max_inference_aug_num max_inference_aug_num # 指定分隔符后插入潜在数量# 后处理self._postprocess_models() # 后处理模型self.warnings_issued {} # 存储发出的警告self.model_tags None # 存储模型标签log_trainable_params(self) # 记录可训练参数def add_model_tags(self, tags: Union[list[str], str]) - None: # 添加模型标签r向模型添加自定义标签这些标签将被推送到 Hugging Face Hub。不会覆盖模型中现有的标签。参数tags (Union[list[str], str])要添加到模型的标签例子pythonfrom transformers import AutoModelmodel AutoModel.from_pretrained(google-bert/bert-base-cased)model.add_model_tags([custom, custom-bert])# 将模型推送到您的命名空间名称为 my-custom-bert。model.push_to_hub(my-custom-bert)if isinstance(tags, str):tags [tags]if self.model_tags is None:self.model_tags []for tag in tags:if tag not in self.model_tags:self.model_tags.append(tag)def _postprocess_models(self):后处理记忆模型的组件推理模型、记忆编织器、触发器和分词器。步骤1. 冻结推理模型的所有参数不更新梯度。2. 将所有模型转换为 bfloat16 以提高内存和计算效率。3. 确保分词器有一个有效的填充符- 如果缺少填充符使用 EOS 符作为填充符。- 设置 padding_side 为 left 以兼容生成任务。4. 标准化分词器的模板为 CONVERSATION_TEMPLATE。# 默认冻结推理模型的所有参数fix_model_parameters(self.model)# 将所有子模型转换为 bfloat16self.model self.model.bfloat16()self.weaver self.weaver.bfloat16()self.trigger self.trigger.bfloat16()# 确保分词器有一个填充符if self.tokenizer.pad_token is None:self.tokenizer.pad_token self.tokenizer.eos_tokenself.tokenizer.pad_token_id self.tokenizer.eos_token_idself.tokenizer.padding_side leftlogging.info(fTokenizer has no pad token. Using EOS token ({self.tokenizer.eos_token}) as pad token.)# 标准化分词器的模板self.tokenizer.chat_template CONVERSATION_TEMPLATE2.1.4 插入阶段LatentMemoryModel 的两个关键函数 forward 和 generate 区别如下forward 函数训练时候计算损失由训练循环自动调用。generate 函数推理时候生成文本由代码显式调用。forwardforward 函数的主体如下def _forward(self,input_ids: torch.Tensor,attention_mask: torch.Tensor,labels: torch.Tensor,**kwargs) - torch.Tensor:# 预处理输入assert input_ids.shape attention_mask.shape labels.shapetokenizer self.tokenizerreasoner self.modelweaver self.weaverdelimiters self.delimitersmax_augment_num self.max_inference_aug_num # 限制推理增强点的数量以避免过度增强device self.deviceembeds_dtype reasoner.get_input_embeddings().weight.dtypeB, _ input_ids.shapehidden_size reasoner.config.hidden_size# 选择增强索引augmentation_indices self._select_augment_points_after_delimiter(input_ids, labels, delimiters, tokenizer, max_augment_num)# 输入嵌入inputs_embeds reasoner.get_input_embeddings()(input_ids)# 初始化开始索引和空张量以累积处理的段current_start_idx 0current_inputs_embeds torch.empty(B, 0, hidden_size).to(device, dtypeembeds_dtype)current_attention_mask torch.empty(B, 0).to(device, dtypeattention_mask.dtype)current_latents_mask torch.empty(B, 0).to(device, dtypetorch.bool)# 遍历所选增强点for aug_idx in augmentation_indices:# 切片原始嵌入和注意力掩码segment_inputs_embeds inputs_embeds[:, current_start:aug_idx]segment_attention_mask attention_mask[:, current_start:aug_idx]segment_latents_mask torch.zeros(B, segment_inputs_embeds.size(1).to(device, dtypetorch.bool)# 连接当前段到累积嵌入和掩码current_inputs_embeds torch.cat([current_inputs_embeds, segment_inputs_embeds], dim1)current_mask torch.cat([current_mask, segment_attention_mask], dim1)current_position_ids generate_position_ids(current_mask)current_latents torch.cat([current_latents, segment_latents], dim1)# 将推理模型嵌入映射到记忆编织器嵌入weaver_inputs_embeds self.reasoner_to_weaver(current_inputs_embeds)# 确定此点是否为提示增强的结束is_prompt_end_aug (labels[:, aug_idx] ! -100).all() and (labels[:, aug_idx-1] -100).all().item()# 根据类型使用记忆编织器增强提示或推理if is_prompt_end_aug:weaver_hidden_states, attn_mask, pos_ids weaver.augment_prompt(weaver_inputs, current_attention_mask, current_position_ids)else:weaver_hidden_states, attn_mask, pos_ids weaver.augment_inference(weaver_inputs, current_attention_mask, current_position_ids)# 将记忆编织器隐藏状态映射回推理模型嵌入latent_inputs_embeds self.weaver_to_reasoner(weaver_hidden_states)# 更新累积嵌入和掩码与新增强段current_inputs_embeds torch.catgenerate核心作用该 generate 方法是 MemGen 模型的推理核心实现了动态记忆增强与序列生成的无缝融合。通过迭代生成新 token每步自适应判断是否插入编织器生成的潜在记忆让推理器在生成过程中实时利用动态记忆调整解码路径最终输出增强后的序列可选返回记忆增强位置掩码。核心特色双阶段记忆增强先执行提示词阶段记忆增强初始化全局记忆再在迭代生成中动态触发推理阶段增强补充实时记忆适配不同生成阶段的记忆需求。自适应触发机制通过 _should_augment 结合触发器决策仅对需要记忆支持的序列执行增强避免无意义的计算开销。维度对齐优化非增强序列采用左填充_left_pad方式对齐增强序列维度确保批次内所有序列格式统一不影响批量生成效率。高效推理设计禁用梯度计算torch.no_grad()节省内存并加速推理启用推理器缓存use_cacheTrue减少重复计算仅在必要时输出隐藏状态降低计算成本。灵活配置与可解释性支持控制最大生成 token 数、采样策略等参数可选返回 augmentation_pos 掩码标记记忆插入位置提升模型可解释性。鲁棒性保障提前终止机制所有序列生成 EOS 或达最大增强次数时终止避免无效迭代重构生成配置固定关键参数确保生成稳定性。推理生成流程图潜在记忆插入的完整流程初始化阶段对输入提示进行增强插入初始潜在记忆。生成循环逐个生成token。条件检查在每个步骤检查是否满足插入条件。决策判断使用trigger模型决定是否插入潜在记忆。潜在记忆生成通过weaver模型生成潜在记忆表示。嵌入连接将潜在记忆嵌入连接到当前输入序列。继续生成使用增强后的序列继续生成下一个token。具体流程如下图所示MemGen-2代码如下torch.no_grad() # 禁用梯度计算适用于推理阶段提升效率并节省内存def generate(self,input_ids: torch.Tensor, # 输入token ID序列形状[batch_size, prompt_len]attention_mask: torch.Tensor, # 注意力掩码形状与input_ids一致generation_config: GenerationConfig None, # 生成配置如最大新token数、采样策略等return_augmentation_mask: bool False, # 是否返回记忆增强位置掩码**kwargs) - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:执行MemGen模型的推理生成流程动态融合潜在记忆与推理器生成增强后的输出序列。核心逻辑1. 初始化提示词阶段的记忆增强2. 迭代生成新token每步判断是否触发推理阶段记忆增强3. 对需增强的序列插入编织器生成的潜在记忆非增强序列左填充对齐维度4. 生成完成后返回结果可选返回增强位置掩码tokenizer self.tokenizerreasoner self.modelweaver self.weavertrigger self.triggerdelimiters self.delimitersmax_augment_num self.max_inference_aug_num # 单序列最大推理阶段增强次数invalid_token_id -100 # 无效位置标记用于增强位置掩码# 预处理输入转移到模型所在设备input_ids input_ids.to(self.device)attention_mask attention_mask.to(self.device)# 提取生成配置关键参数max_new_tokens generation_config.max_new_tokens # 最大生成新token数do_sample generation_config.do_sample # 是否启用采样生成temperature generation_config.temperature # 采样温度控制随机性pad_token_id tokenizer.pad_token_id # pad token IDeos_token_id tokenizer.eos_token_id # 结束token IDprompt_len input_ids.size(1) # 提示词长度# 重构生成配置固定必要参数确保生成稳定性generation_config GenerationConfig(do_sampledo_sample,temperaturetemperature,pad_token_idpad_token_id,eos_token_ideos_token_id,use_cacheTrue # 启用缓存加速生成)# 将输入token ID转换为嵌入向量inputs_embeds reasoner.get_input_embeddings()(input_ids)B, _, hidden_size inputs_embeds.shape # Bbatch_sizehidden_size推理器隐藏层维度device inputs_embeds.device # 模型所在设备CPU/GPU# 初始化生成过程中的关键张量current_inputs_embeds inputs_embeds # 当前输入嵌入含原始提示词潜在记忆current_attention_mask attention_mask # 当前注意力掩码current_position_ids generate_position_ids(current_attention_mask) # 当前位置IDcurrent_input_ids input_ids # 当前已生成的token ID序列# 提示词阶段记忆增强生成并插入提示词专用潜在记忆weaver_inputs_embeds self.reasoner_to_weaver(current_inputs_embeds) # 映射到编织器嵌入空间weaver_hidden_states, attn_mask, pos_ids weaver.augment_prompt(weaver_inputs_embeds, current_attention_mask, current_position_ids)latent_inputs_embeds self.weaver_to_reasoner(weaver_hidden_states) # 映射回推理器嵌入空间# 拼接提示词与增强记忆current_inputs_embeds torch.cat([current_inputs_embeds, latent_inputs_embeds], dim1)current_attention_mask torch.cat([current_attention_mask, attn_mask], dim1)current_position_ids torch.cat([current_position_ids, pos_ids], dim1)# 生成循环初始化sentence_augment_count torch.zeros(B, dtypetorch.int, devicedevice) # 各序列已增强次数augmentation_pos torch.full((B, max_new_tokens), fill_valueinvalid_token_id, devicedevice) # 增强位置掩码inserted_embeds: List[List[torch.Tensor]] [[] for _ in range(B)] # 记录插入的潜在记忆用于后处理for i in range(max_new_tokens):# 若所有序列均已生成EOS token提前终止if (current_input_ids[:, -1] eos_token_id).all():break# 若所有序列均已达到最大增强次数一次性生成剩余tokenif (sentence_augment_count max_augment_num).all():# 调整剩余生成长度generation_config.max_new_tokens max_new_tokens - i# 推理器生成剩余tokengenerated reasoner.generate(inputs_embedscurrent_inputs_embeds,attention_maskcurrent_attention_mask,generation_configgeneration_config,)current_input_ids torch.cat([current_input_ids, generated], dim1)break# 推理器前向传播获取当前步输出outputs reasoner(inputs_embedscurrent_inputs_embeds,attention_maskcurrent_attention_mask,position_idscurrent_position_ids,output_hidden_statesFalse, # 推理阶段无需输出隐藏状态提升效率)# 生成并追加一个新token更新关键张量current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids self._append_one_step(outputs, current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids, do_sample, temperature)# 若为最后一步生成终止循环if i max_new_tokens - 1:break# 判断当前批次中哪些序列需要进行推理阶段记忆增强augment_decision self._should_augment(current_input_ids, current_attention_mask, sentence_augment_countsentence_augment_count,do_sampledo_sample, temperaturetemperature)augmentation_pos[:, i 1] augment_decision # 记录增强位置1增强0不增强-100无效augment_indices torch.where(augment_decision 1)[0] # 需增强的序列索引# 对需增强的序列执行记忆增强非增强序列左填充对齐维度if len(augment_indices) 0:# 递增需增强序列的增强次数计数sentence_augment_count[augment_indices] 1# 提取需增强序列的嵌入、掩码和位置IDcandidate_inputs_embeds current_inputs_embeds[augment_indices]candidate_attention_mask current_attention_mask[augment_indices]candidate_position_ids current_position_ids[augment_indices]# 编织器生成推理阶段潜在记忆weaver_inputs_embeds self.reasoner_to_weaver(candidate_inputs_embeds)weaver_hidden_states, attn_mask, _ weaver.augment_inference(weaver_inputs_embeds, candidate_attention_mask, candidate_position_ids)latent_inputs_embeds self.weaver_to_reasoner(weaver_hidden_states) # 映射回推理器空间# 拼接原始嵌入与潜在记忆candidate_inputs_embeds torch.cat([candidate_inputs_embeds, latent_inputs_embeds], dim1)candidate_attention_mask torch.cat([candidate_attention_mask, attn_mask], dim1)# 构建合并张量适配所有序列包括增强和非增强new_len candidate_inputs_embeds.size(1) # 增强后序列长度merged_inputs_embeds torch.zeros((B, new_len, hidden_size), devicedevice, dtypecurrent_inputs_embeds.dtype)merged_attention_mask torch.zeros((B, new_len), devicedevice, dtypecurrent_attention_mask.dtype)# 填充增强序列merged_inputs_embeds[augment_indices] candidate_inputs_embedsmerged_attention_mask[augment_indices] candidate_attention_mask# 填充非增强序列左填充对齐长度non_augment_indices torch.where(augment_decision ! 1)[0]if len(non_augment_indices) 0:non_aug_inputs_embeds current_inputs_embeds[non_augment_indices]non_aug_attention_mask current_attention_mask[non_augment_indices]non_aug_inputs_embeds, non_aug_attention_mask, _ self._left_pad(non_aug_inputs_embeds, non_aug_attention_mask, None, weaver.inference_latents_num)merged_inputs_embeds[non_augment_indices] non_aug_inputs_embedsmerged_attention_mask[non_augment_indices] non_aug_attention_mask# 更新当前关键张量current_inputs_embeds merged_inputs_embedscurrent_attention_mask merged_attention_maskcurrent_position_ids generate_position_ids(current_attention_mask) # 重新生成位置ID# 记录插入的潜在记忆用于后处理或可解释性分析for idx, embed in zip(augment_indices, latent_inputs_embeds):inserted_embeds[idx].append(embed.clone().detach().cpu())# 后处理调整增强位置掩码长度与生成结果一致new_generated_len current_input_ids.size(1) - prompt_lenaugmentation_pos augmentation_pos[:, :new_generated_len]# 根据配置返回结果仅生成序列 或 序列增强位置掩码if not return_augmentation_mask:return current_input_idselse:return current_input_ids, augmentation_pos2.2 Trigger2.2.1. 核心作用该模块定义了 MemGen 框架中记忆触发器的核心接口与两种具体实现核心作用是动态决策记忆增强的时机—— 即在推理过程中判断何时插入编织器生成的潜在记忆实现记忆与推理的动态耦合突破传统静态记忆注入的局限。2.2.2. 核心特色抽象接口统一规范Trigger抽象基类定义了触发器的核心接口确保后续扩展新触发器时遵循统一标准提升代码可扩展性。双实现适配不同场景NanoTrigger极简实现始终触发记忆增强无需训练适用于快速测试、基线对比或无需动态控制的简单场景。MemGenTrigger基于预训练 LLM 的智能触发器通过二分类头适配决策任务支持 PEFT 参数高效微调能根据输入序列动态判断是否触发适配复杂真实场景。高效适配与灵活扩展采用 bfloat16 精度和 Flash Attention 2 优化计算效率支持 PEFT 微调在不冻结基础模型的前提下实现参数高效学习替换 LLM 原始输出头为二分类头精准适配 是否插入记忆 的决策需求。模块解耦设计触发器决策独立于编织器模块仅基于输入序列和数据分布做出判断保证了模块间的低耦合和高内聚。2.2.3 网络架构网络架构图如下。说明如下模型支持PEFT参数高效微调如LoRA适配于Transformer Blocks层整体精度采用bfloat16平衡计算效率与数值稳定性注意力计算通过Flash Attention 2优化提升长序列处理速度MemGen-32.2.4 代码class Trigger(torch.nn.Module, ABC):记忆触发器的抽象基类Trigger。定义了触发器的核心接口用于决定在推理过程中何时触发记忆增强插入潜在记忆。所有具体触发器实现都需继承此类并实现forward方法。def __init__(self):super().__init__() # 调用父类Module的初始化方法abstractmethoddef forward(self, **kwargs) - bool:抽象前向传播方法接收输入数据返回是否触发记忆增强的决策。子类必须实现此方法定义具体的触发逻辑。Args:**kwargs: 可变关键字参数包含输入序列、注意力掩码等模型所需数据Returns:bool: 触发决策True表示触发记忆增强False表示不触发...class NanoTrigger(torch.nn.Module):极简触发器NanoTrigger始终触发记忆增强的基础实现。无需复杂逻辑固定返回触发决策适用于基础测试或无需动态控制的场景。def __init__(self):super().__init__()# 注册一个缓冲区张量用于获取模型所在设备无实际计算意义self.register_buffer(_device, torch.tensor(0.0))propertydef device(self):获取模型所在设备CPU/GPUreturn self._device.devicedef forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) - bool:# 该极简触发器始终预测需要插入记忆# 输出logits张量其中插入决策索引1的概率被设为1.0# 适用于批次中的每个token位置batch_size, seq_len input_ids.shape# 初始化logits张量形状为[batch_size, seq_len, 2]2表示不插入0和插入1两类logits torch.zeros(batch_size, seq_len, 2, deviceinput_ids.device)logits[..., 1] 1.0 # 将所有位置的插入决策概率设为1.0return logitsclass MemGenTrigger(torch.nn.Module):MemGen框架的专用触发器模块MemGenTrigger。- 输入接收推理器模型当前解码序列的inputs_embeds或input_ids- 输出生成形状为[batch_size, seq_len, 2]的logits张量表示每个位置不插入0和插入1记忆的概率用于动态决策记忆增强时机。def __init__(self,pretrained_model_name_or_path: str, # 预训练模型名称或路径用于初始化触发器LLMpeft_config: Optional[PeftConfig] None # PEFT配置可选用于参数高效微调):super().__init__()# 构建基础LLM模型作为触发器的核心推理组件self.model AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,torch_dtypetorch.bfloat16, # 使用bfloat16精度提升效率attn_implementationflash_attention_2 # 启用Flash Attention 2优化注意力计算)self.tokenizer AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # 对应的Tokenizer# 对基础模型进行后处理设置可训练、替换输出头self.model self._postprocess(self.model)# 若提供PEFT配置应用参数高效微调if peft_config is not None:self.model get_peft_model(self.model, peft_config)self.config self.model.config # 保存模型配置propertydef device(self):获取模型所在设备CPU/GPUreturn self.model.devicedef _postprocess(self, model: PreTrainedModel):对基础模型进行后处理适配触发器的二分类任务需求。Args:model: 原始预训练LLM模型Returns:处理后的模型可训练、替换为二分类输出头# 设置所有模型参数为可训练for parameter in model.parameters():parameter.requires_grad True# 将原始语言模型的输出头lm_head替换为二分类头hidden_size model.config.hidden_size # 模型隐藏层维度classification_head nn.Linear(hidden_size, 2) # 输出维度为2不插入/插入model.lm_head classification_head# 确保新的二分类头参数可训练for param in model.lm_head.parameters():param.requires_grad Truereturn modeldef forward(self,input_ids: Optional[torch.LongTensor] None, # 生成序列的token ID形状[batch_size, seq_len]attention_mask: Optional[torch.Tensor] None, # 注意力掩码避免关注填充token**kwargs: Unpack[TransformersKwargs], # 传递给底层模型的额外参数) - torch.Tensor:序列生成的触发决策机制。触发器基于已生成的input_ids做出决策受数据分布影响但独立于编织器模块。Args:input_ids (Optional[torch.LongTensor]): 生成序列的token ID张量attention_mask (Optional[torch.Tensor]): 注意力掩码默认None**kwargs: 传递给底层模型的额外关键字参数Returns:torch.Tensor: Logits张量形状为(batch_size, seq_len, num_classes)num_classes2分别对应不插入索引0和插入索引1的概率# 调用基础模型前向传播返回二分类logitsreturn self.model(input_idsinput_ids,attention_maskattention_mask,**kwargs).logits2.3 MemGenWeaver2.3.1 核心作用MemGenWeaver 是 MemGen 框架的核心组件之一负责生成动态潜在记忆并将其与推理器的输入序列融合从而实现记忆与推理过程的无缝交织。它通过可学习的潜在记忆查询向量在提示词阶段和推理阶段分别生成针对性的记忆表示引导推理器调整解码路径提升智能体的动态决策能力。2.3.2 核心特色双阶段记忆生成区分提示词阶段augment_prompt和推理阶段augment_inference使用各自独立的可学习潜在记忆查询向量适配不同阶段的记忆需求增强记忆生成的针对性。灵活的潜在记忆融合通过_augment方法统一实现潜在记忆与输入序列的融合包括嵌入拼接、注意力掩码扩展和位置 ID 计算确保记忆与原始输入在语义空间和时序上的一致性。高效的模型设计基于预训练 LLM 构建支持 PEFT 参数高效微调在保留基础能力的同时降低训练成本采用 bfloat16 精度和 Flash Attention 2 优化提升计算效率和内存利用率。动态记忆编织机制生成的潜在记忆并非静态检索结果而是基于当前输入序列动态生成的隐藏状态能够捕捉实时上下文信息实现 “生成式记忆” 的核心特性。模块化与可扩展性与推理器、触发器解耦通过标准化接口交互潜在记忆的数量可通过参数灵活配置适配不同任务对记忆容量的需求。2.3.3 网络架构网络架构图如下。说明如下核心组件可学习潜在记忆向量分阶段设计P提示词阶段数量I推理阶段数量支持动态生成记忆预训练LLM作为记忆生成核心默认启用bfloat16精度和Flash Attention 2优化序列融合层确保输入与记忆在语义、掩码、时序上的一致性核心流程输入 → 选择对应阶段的潜在记忆 → 融合序列 → LLM生成隐藏状态 → 提取潜在记忆输出支持PEFT参数高效微调如LoRA适配于Transformer Blocks层输出用途生成的潜在记忆将通过投影层映射到推理器的嵌入空间与原始输入融合以引导解码MemGen-4