2026年7月2日 周四晚上19:30,报名腾讯会议了解“如何构建自进化的动态知识库(Brain)”(限30人)
免费POC, 零成本试错
FDE知识库

FDE知识库

学习大模型的前沿技术与行业落地应用


收藏

【AI大模型应用开发】LATS:比ToT和ReAct更强的大模型思维框架(LangGraph代码实现+拆解)

发布日期:2024-07-14 11:22:32 浏览次数: 3732
作者:同学小张

微信搜一搜,关注“同学小张”

0. 原理回顾

这是LATS论文中的步骤图:

LATS的实现需要:选择、扩展、评估、模拟、反向传播和反思的过程。

LangChain中的代码实现步骤如下:

将步骤简化为:

(1)选择

(2)扩展

(3)评估 + 反思 + 模拟 + 打分

(4)反向传播

1. 代码详解

下面我们一起来看下它的源码实现。

完整代码参考:https://github.com/langchain-ai/langgraph/blob/main/examples/lats/lats.ipynb

1.1 LangGraph的创建

LangGraph的创建是一个非常标准的套路:

(1)创建LangGraph对象

builder = StateGraph(TreeState)

(2) 添加节点

builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)

(3)设置初始进入节点

builder.set_entry_point("start")

(4)添加边,都是条件边

builder.add_conditional_edges(
    "start",
    # Either expand/rollout or finish
    should_loop,
)
builder.add_conditional_edges(
    "expand",
    # Either continue to rollout or finish
    should_loop,
)

(5)编译图

graph = builder.compile()

这就定义好了一个LangGraph,图的执行路径是这样的:

其中 should_loop 函数的定义:

def should_loop(state: TreeState):
    """Determine whether to continue the tree search."""
    root = state["root"]
    if root.is_solved:
        return END
    if root.height > 5:
        return END
    return "expand"

如果判定已经得到了最终答案(is_solved),或者搜索的最大层数超过了5层,则停止搜索答案,不再继续搜索。否则继续执行 expand 节点。

有了这个LangGraph的框架,我总结的LangGraph创建需要的三要素:节点node、边edge和状态state:边edge上面已经定义了,下面看下节点node和状态state的实现。

1.2 状态state的实现

from typing_extensions import TypedDict
class TreeState(TypedDict):
    # The full tree
    root: Node
    # The original input
    inputstr

自定义的状态state为TreeState,里面有一个Node类型的root字段,和一个str类型的input字段。

1.3 节点node的实现

1.3.1 start节点

start节点是执行 generate_initial_response函数,构造根节点:

def generate_initial_response(state: TreeState) -> dict:
    """Generate the initial candidate response."""
    res = initial_answer_chain.invoke({"input": state["input"]})
    parsed = parser.invoke(res)
    tool_responses = tool_executor.batch(
        [ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in parsed]
    )
    output_messages = [res] + [
        ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        for resp, tool_call in zip(tool_responses, parsed)
    ]
    reflection = reflection_chain.invoke(
        {"input": state["input"], "candidate": output_messages}
    )
    root = Node(output_messages, reflection=reflection)
    return {
        **state,
        "root": root,
    }

1.3.1.1 初始化信息 initial_answer_chain

initial_answer_chain 代码如下:实现的功能是将用户提问输入给大模型,大模型给出回复。从Prompt看,基本就是个直来直去的问答。

prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an AI assistant.",
        ),
        ("user""{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)

initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(
    run_name="GenerateInitialCandidate"
)

1.3.1.2 解析结果

回复之后,使用Json解析器解析一下:

parser = JsonOutputToolsParser(return_id=True)

这两步的运行结果大体如下:

1.3.1.3 执行工具获取执行工具结果 tool_executor

tool_executor 将上一步解析出来的工具进行并行执行,并获取结果。下面是工具的定义:

search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]
tool_executor = ToolExecutor(tools=tools)

执行完工具后结果类似如下:

1.3.1.4 评估反思 reflection_chain

reflection_chain 用来对工具执行结果进行打分评估。

@as_runnable
def reflection_chain(inputs) -> Reflection:
    tool_choices = reflection_llm_chain.invoke(inputs)
    reflection = tool_choices[0]
    if not isinstance(inputs["candidate"][-1], AIMessage):
        reflection.found_solution = False
    return reflection

其中的 reflection_llm_chain 定义如下:

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Reflect and grade the assistant response to the user question below.",
        ),
        ("user""{input}"),
        MessagesPlaceholder(variable_name="candidate"),
    ]
)

reflection_llm_chain = (
    prompt
    | llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
        run_name="Reflection"
    )
    | PydanticToolsParser(tools=[Reflection])
)

从Prompt就大体能看出来,是利用大模型进行反思和打分。输入是用户的原始问题和候选的节点candidatetool_choice="Reflection"强制让大模型使用 Reflection工具,最后将大模型返回结果使用 Reflection进行解析。

最终,该chain返回的结果是一个 Reflection实例。

Reflection 执行结果示例:包括一个说明、一个评分和是否是最终答案。参考 Reflection类的实现。

1.3.1.5 Reflection类

class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        gte=0,
        lte=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0

问题:这个Reflection类没有定义为tools,为什么能直接这样写: tools=[Reflection]

1.3.2 expand节点

这个节点执行的就是整个LATS的流程:选择、扩展、评估、模拟、反向传播和反思。

def expand(state: TreeState, config: RunnableConfig) -> dict:
    """Starting from the "best" node in the tree, generate N candidates for the next step."""
    root = state["root"]
    best_candidate: Node = root.best_child if root.children else root
    messages = best_candidate.get_trajectory()
    # Generate N candidates from the single child candidate
    new_candidates = expansion_chain.invoke(
        {"input": state["input"], "messages": messages}, config
    )
    parsed = parser.batch(new_candidates)
    flattened = [
        (i, tool_call)
        for i, tool_calls in enumerate(parsed)
        for tool_call in tool_calls
    ]
    tool_responses = tool_executor.batch(
        [
            ToolInvocation(tool=tool_call["type"], tool_input=tool_call["args"])
            for _, tool_call in flattened
        ]
    )
    collected_responses = defaultdict(list)
    for (i, tool_call), resp in zip(flattened, tool_responses):
        collected_responses[i].append(
            ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        )
    output_messages = []
    for i, candidate in enumerate(new_candidates):
        output_messages.append([candidate] + collected_responses[i])

    # Reflect on each candidate
    # For tasks with external validation, you'd add that here.
    reflections = reflection_chain.batch(
        [{"input": state["input"], "candidate": msges} for msges in output_messages],
        config,
    )
    # Grow tree
    child_nodes = [
        Node(cand, parent=best_candidate, reflection=reflection)
        for cand, reflection in zip(output_messages, reflections)
    ]
    best_candidate.children.extend(child_nodes)
    # We have already extended the tree directly, so we just return the state
    return state

1.3.2.1 选择 best_candidate

选择当前最优的节点。

以下代码从 best_child开始看,首先是获取树的全部节点和子孙节点。取分数最高的节点。

怎么取分数最高的节点?这里调用了 upper_confidence_bound函数。这个函数用来计算UCT分数,UCT 是一种常用于多臂赌博机问题(Multi-Armed Bandit Problem)和蒙特卡洛树搜索(Monte Carlo Tree Search, MCTS)的算法,它有助于在探索(exploration)和利用(exploitation)之间找到一个平衡。

def upper_confidence_bound(self, exploration_weight=1.0):
    """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
    if self.parent is None:
        raise ValueError("Cannot obtain UCT from root node")
    if self.visits == 0:
        return self.value
    # Encourages exploitation of high-value trajectories
    average_reward = self.value / self.visits
    # Encourages exploration of less-visited trajectories
    exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
    return average_reward + exploration_weight * exploration_term

@property
def best_child(self):
    """Select the child with the highest UCT to search next."""
    if not self.children:
        return None
    all_nodes = self._get_all_children()
    return max(all_nodes, key=lambda child: child.upper_confidence_bound())

1.3.2.2 扩展 expansion_chain

这一步是利用大模型,对于单个输入,生成N个不同的输出。prompt_template与前面初始化节点中的一致,基本是直来直去的问答,只是一次生成多个结果。

# This generates N candidate values for a single input to sample actions from the environment
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
    n = config["configurable"].get("N"5)
    bound_kwargs = llm.bind_tools(tools=tools).kwargs
    chat_result = llm.generate(
        [messages.to_messages()],
        n=n,
        callbacks=config["callbacks"],
        run_name="GenerateCandidates",
        **bound_kwargs
    )
    return [gen.message for gen in chat_result.generations[0]]

expansion_chain = prompt_template | generate_candidates

候选节点生成结果示例:

[AIMessage(content='', additional_kwargs={'tool_calls': [{'id''call_5DMq9O6BIden7lLraFH0NuYZ''function': {'arguments''{"query":"lithium pollution research report"}''name''tavily_search_results_json'}, 'type''function'}]}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id''call_5DMq9O6BIden7lLraFH0NuYZ''function': {'arguments''{"query":"lithium pollution research report"}''name''tavily_search_results_json'}, 'type''function'}]}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id''call_5DMq9O6BIden7lLraFH0NuYZ''function': {'arguments''{"query":"lithium pollution research report"}''name''tavily_search_results_json'}, 'type''function'}]}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id''call_5DMq9O6BIden7lLraFH0NuYZ''function': {'arguments''{"query":"lithium pollution research report"}''name''tavily_search_results_json'}, 'type''function'}]}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id''call_5DMq9O6BIden7lLraFH0NuYZ''function': {'arguments''{"query":"lithium pollution research report"}''name''tavily_search_results_json'}, 'type''function'}]})]

生成完N个候选节点之后,通过 解析结果、并行执行工具得到每个节点执行的结果。

1.3.2.3 评估+反思

对每个候选节点进行评估反思 reflection_chain

1.3.2.4 扩展树和反向传播

将新增的节点添加到树中

# Grow tree
child_nodes = [
    Node(cand, parent=best_candidate, reflection=reflection)
    for cand, reflection in zip(output_messages, reflections)
]
best_candidate.children.extend(child_nodes)

添加之后,树中已经有了这些子节点:

值得注意的是,这里面也包含了反向传播步骤。当新创建一个Node实例时,会调用反向传播:

class Node:
    def __init__(
        self,
        messages: List[BaseMessage],
        reflection: Reflection,
        parent: Optional[Node] = None,
    ):
        ......

        self.backpropagate(reflection.normalized_score)

反向传播的实际作用,就是更新这条路径上各个节点的分数:

def backpropagate(self, reward: float):
    """Update the score of this node and its parents."""
    node = self
    while node:
        node.visits += 1
        node.value = (node.value * (node.visits - 1) + reward) / node.visits
        node = node.parent

1.4 执行

question = "Write a research report on lithium pollution."
for step in graph.stream({"input": question}):
    step_name, step_state = next(iter(step.items()))
    print(step_name)
    print("rolled out: ", step_state["root"].height)
    print("---")
    
    # solution_node = step["__end__"]["root"].get_best_solution() ## 这一句我没运行成功,暂且不管吧
    solution_node = step["start"]["root"].get_best_solution()
    best_trajectory = solution_node.get_trajectory(include_reflections=False)
    print(best_trajectory[-1].content)

执行完之后最后输出结果是:最终节点的content。

best_trajectory = solution_node.get_trajectory(include_reflections=False)
    print(best_trajectory[-1].content)

最终输出结果示例:

2. 总结

本文我们对LangChain中实现LATS的源码进行了详细的学习和拆解,希望能够帮助大家更好地理解LATS,给大家做一个参考。

LATS的六步:选择、扩展、评估、模拟、反向传播和反思。其中评估、模拟和反思可以合并在一起执行,模拟其实就是执行工具获取工具的执行结果,反向传播其实就是更新这条路径上各个节点的得分。

53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费POC验证,效果达标后再合作。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

添加专属顾问

回到顶部

加载中...

扫码咨询

扫码登录
登录即表示您同意《53AI网站服务协议》
服务协议

欢迎您使用【53AI 官方网站】(以下简称“本网站”或“我们”)。本《会员服务协议》(以下简称“本协议”)是您(以下简称“会员”或“用户”)与【深圳市博思协创网络科技有限公司】之间关于注册、登录及使用本网站会员服务所订立的法律协议。

在您注册或登录前,请务必审慎阅读、充分理解各条款内容,特别是免除或限制责任的条款、知识产权条款、争议解决条款等。此类条款将以加粗形式提示您注意。 当您通过微信公众号授权、手机验证码验证或其他方式成功登录本网站时,即视为您已完全理解并同意接受本协议的全部内容。

一、 定义

本网站:指由【深圳市博思协创网络科技有限公司】运营的,域名为【53ai.com】的网站及相关移动端页面。

会员服务:指本网站向注册会员提供的知识库文章查阅、内容检索及其他相关增值服务。

知识库内容:指本网站发布的包括但不限于文字、图表、数据、研究报告、行业分析等数字化内容资源。

二、 账号注册与登录

登录方式:本网站支持以下登录方式,您可根据实际情况选择:

微信公众号授权登录:您同意将您的微信OpenID信息授权给本网站,用于创建或关联会员账号。

手机验证码登录:您需提供真实有效的手机号码,并通过短信验证码完成身份验证与登录/注册。

账号安全:您的账号仅限您本人使用,禁止赠与、借用、租用、转让或售卖。因您保管不善导致的账号被盗、密码泄露等损失,由您自行承担。

实名认证:根据相关法律法规要求,我们可能要求您在特定功能下完成实名认证。如您拒绝提供,可能无法使用部分或全部服务。

未成年人保护:若您未满18周岁,请在法定监护人的陪同下阅读本协议,并在征得监护人同意后使用本服务。

三、 服务内容与规范

知识库查阅权限:会员登录后,有权按照其会员等级对应的权限范围,在线浏览、检索本网站知识库中的相关文章及内容。

服务变更:我们有权根据业务发展需要,调整、变更或终止部分服务内容,并将以网站公告、公众号消息等方式提前通知。

禁止行为:您在使用服务时不得实施以下行为:

利用技术手段批量爬取、下载、转存知识库内容;

将知识库内容用于商业目的或未经授权地向第三方传播;

干扰本网站正常运行或侵犯其他用户合法权益;

发布违法违规信息或从事违反公序良俗的活动。

四、 知识产权声明

权利归属:本网站知识库中的排版设计、软件代码等内容的知识产权均归【公司全称】或原权利人所有,受《中华人民共和国著作权法》等法律保护。

有限许可:本网站授予会员一项非独占、不可转让、不可转授权的普通许可,仅限于个人学习、研究之目的在线查阅知识库内容。

侵权追责:未经书面许可,任何单位或个人不得以任何形式复制、转载、摘编、镜像、汇编或以其他方式使用上述内容。一经发现,我们保留追究其法律责任的权利。

五、 个人信息保护

我们重视对您个人信息的保护。关于我们如何收集、使用、存储和保护您的个人信息,请单独阅读 《隐私政策》。

您通过微信公众号授权或手机号验证所提供的信息,我们将严格按照《个人信息保护法》的规定处理,仅用于身份识别、服务提供及安全验证等必要用途。

您可以随时通过网站设置或联系客服行使查阅、更正、删除个人信息及撤回授权同意的权利。

六、 免责声明

内容准确性:知识库内容仅供参考,不构成专业建议。我们不对其完整性、准确性、时效性作任何明示或暗示的保证,您应自行判断并承担使用风险。

不可抗力:因自然灾害、政策法规变化、网络故障、第三方平台接口异常(如微信接口维护、运营商短信通道故障)等不可抗力导致的服务中断或延迟,我们不承担违约责任。

第三方链接:本网站可能包含指向第三方网站的链接,该等网站的内容和服务不受我们控制,请您自行甄别风险。

七、 违约责任

如您违反本协议约定,我们有权视情节采取警告、限制功能、暂停服务、注销账号等措施,并保留要求赔偿损失的权利。

如因您的违约行为导致我们遭受行政处罚、第三方索赔或商誉损失,您应承担全部赔偿责任(包括但不限于罚款、赔偿金、律师费、公证费等)。

八、 法律适用与争议解决

本协议的订立、执行和解释均适用中华人民共和国大陆地区法律。

因本协议产生的或与本协议有关的任何争议,双方应友好协商解决;协商不成的,任何一方均可向【公司所在地】有管辖权的人民法院提起诉讼。

九、 其他

本协议构成双方就本服务达成的完整协议,取代此前任何口头或书面约定。

本协议任一条款被认定为无效或不可执行的,不影响其他条款的效力。

我们对本协议享有最终解释权,并在法律允许的范围内保留随时修改的权利。修改后的协议一经公布即生效,继续使用服务即视为同意修订内容。


已查阅