我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人

时间:2023-01-12 21:58:18

概述

这个工作尝试重现这个论文的结果 A Neural Conversational Model (aka the Google chatbot).

它使用了循环神经网络(seq2seq 模型)来进行句子预测。它是用 python 和 TensorFlow 开发。

程序的加载主体部分是参考 Torch的 neuralconvo from macournoyer.

现在, DeepQA 支持一下对话语料:

To speedup the training, it's also possible to use pre-trained word embeddings (thanks to Eschnou). More info here.

安装

这个程序需要一下依赖(easy to install using pip: pip3 install -r requirements.txt):

  • python 3.5
  • tensorflow (tested with v1.0)
  • numpy
  • CUDA (for using GPU)
  • nltk (natural language toolkit for tokenized the sentences)
  • tqdm (for the nice progression bars)

你可能需要下载附带的数据让 nltk 正常工作。

python3 -m nltk.downloader punkt

Cornell 数据集已经包括了。其他的数据集查看 readme 文件到他们所在的文件夹。 (在 data/).

网站接口需要一些附加的包:

  • django (tested with 1.10)
  • channels
  • Redis (see here)
  • asgi_redis (at least 1.0)

Docker 安装也是支持的,更多详细的教程参考 here.

运行

聊天机器人

训练这个模型,直接运行 main.py 。一旦训练完成,你可以测试结果用 main.py --test

(结果生成在 'save/model/samples_predictions.txt') 或者用 main.py --test interactive (更有趣).

Here are some flags which could be useful. For more help and options, use python main.py -h:

  • --modelTag <name>: allow to give a name to the current model to differentiate between them when testing/training.
  • --keepAll: use this flag when training if when testing, you want to see the predictions at different steps (it can be interesting to see the program changes its name and age as the training progress). Warning: It can quickly take a lot of storage space if you don't increase the --saveEvery option.
  • --filterVocab 20 or --vocabularySize 30000: Limit the vocabulary size to and optimize the performances and memory usage. Replace the words used less than 20 times by the <unknown> token and set a maximum vocabulary size.
  • --verbose: when testing, will print the sentences as they are computed.
  • --playDataset: show some dialogue samples from the dataset (can be use conjointly with --createDataset if this is the only action you want to perform).

To visualize the computational graph and the cost with TensorBoard, just run tensorboard --logdir save/.

默认的这个网络架构是一个标准的 encoder/decoder 有两个 LSTM layers (隐藏层大小 256) ,然后 vocabulary 的 embedding size 是 32.这个网络用 ADAM 训练。最大的句子长度设置为 10个单词,但是可以增加。

结果

当然,这个网络并不会很擅长聊天:

Q: Hi
A: Hi. Q: What is your name ?
A: Laura. Q: What does that mean ?
A: I dunno. Q: How old are you ?
A: thirty-five. Q: Will Google hire me ?
A: No. Q: Tell me the alphabet
A: Fuck you. Q: That's not nice
A: Yeah.

这里有一些情况它并不能正确回答:

Q: Who is Laura ?
A: My brother. Q: Say 'goodbye'
A: Alright. Q: What is cooking ?
A: A channel. Q: Can you say no ?
A: No. Q: Two plus two
A: Manny...

预训练模型

项目截图:

我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人

实测截图:

我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人

一步一步教程:

1.下载这个项目:

https://github.com/Conchylicultor/DeepQA

2.下载训练好的模型:

https://drive.google.com/file/d/0Bw-phsNSkq23OXRFTkNqN0JGUU0/view

(如果网址不能打开的话,今晚我会上传到百度网盘,分享到:http://www.tensorflownews.com/)

3.解压之后放在 项目 save 目录下

如图所示

我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人

4.复制 save/model-pretrainedv2/dataset-cornell-old-lenght10-filter0-vocabSize0.pkl 这个文件到 data/samples/

如图所示:

我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人

5.在项目目录执行一下命令:

python3 main.py --modelTag pretrainedv2 --test interactive

程序读取了预训练的模型之后,如图:

我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人

聊天机器人资源合集

项目,语聊,论文,教程

https://github.com/fendouai/Awesome-Chatbot

更多教程:

http://www.tensorflownews.com/

DeepQA

https://github.com/Conchylicultor/DeepQA

备注:为了更加容易了解这个项目,说明部分翻译了项目的部分 readme ,主要是介绍使用预处理数据来运行这个项目。

随机推荐

  1. android源码修改,实现长按电源键直接关机

    版本:android 4.4.2 源文件路径:frameworks\base\policy\src\com\android\internal\policy\impl\PhoneWindowManage ...

  2. &lbrack;BZOJ3156&rsqb;防御准备(斜率优化DP)

    题目:http://www.lydsy.com:808/JudgeOnline/problem.php?id=3156 分析: 简单的斜率优化DP

  3. request、response 中文乱码问题与解决方式

    request乱码指的是:浏览器向服务器发送的请求参数中包含中文字符,服务器获取到的请求参数的值是乱码:   response乱码指的是:服务器向浏览器发送的数据包含中文字符,浏览器中显示的是乱码: ...

  4. SQL Server调优系列基础篇 - 联合运算符总

    前言 上两篇文章我们介绍了查看查询计划的方式,以及一些常用的连接运算符的优化技巧,本篇我们总结联合运算符的使用方式和优化技巧. 废话少说,直接进入本篇的主题. 技术准备 基于SQL Server200 ...

  5. js字符串比较

    1,大写字母小于小写字母 a='ang',b='Zh' 那么a>b 2,可以使用字符串的toUpperCase()/toLowerCase()方法不区分字母的大小写. a.toUpperCase ...

  6. SQL学习之组合查询&lpar;UNION&rpar;

    1.大多数的SQL查询只包含从一个或多个表中返回数据的单条SELECT语句,但是,SQL也允许执行多个查询(多条SELECT语句),并将结果作为一个查询结果集返回.这些组合查询通常称为并或复合查询. ...

  7. VC6&period;0 调试&period;dll文件

    对于自己制作的.DLL文件,一直没有比较好的调试方法,其实是知道的太少. 下面就说说VC6.0下面 怎么调试DLL文件: 首先得有一个调用DLL文件的可执行程序,然后调用这个可执行程序. 在工程上 右 ...

  8. 关于tomcat配置MyEclipse项目的配置代码

    例如:<Context path="/shis" docBase="E:\Genuitec\Workspaces\MyEclipse 8.6\zwfw_platfo ...

  9. Linux驱动技术&lpar;四&rpar; &lowbar;异步通知技术

    异步通知的全称是"信号驱动的异步IO",通过"信号"的方式,放期望获取的资源可用时,驱动会主动通知指定的应用程序,和应用层的"信号"相对应, ...

  10. SpringMVC---CookieValue

    配置文件承接一二章 @CookieValue的作用 用来获取Cookie中的值 1.value:参数名称 2.required:是否必须 3.defaultValue:默认值 原网址:https:// ...