Driverless AI NLP Demo – Airline Sentiment Dataset
In this notebook, we will see how to use the Twitter US Airline Sentiment dataset to build text classification models with the Driverless AI Python client.
Import the required Python modules to get started, and launch the Driverless AI client. If it is not installed yet, download and install the Python client from the Driverless AI GUI.
Click here for the Python client documentation.
[1]:
import pandas as pd
from sklearn import model_selection
import driverlessai
First, establish a connection to Driverless AI using Client. Enter your credentials and the URL.
[2]:
address = 'http://ip_where_driverless_is_running:12345'
username = 'username'
password = 'password'
dai = driverlessai.Client(address=address, username=username, password=password)
# make sure to use the same username and password as when signing in through the GUI
Upload the airline file to Driverless AI with the datasets.create command. Then split the data into train and test sets.
[3]:
airlines = dai.datasets.create(data='https://h2o-public-test-data.s3.amazonaws.com/dai_release_testing/datasets/airline_sentiment_tweets.csv',
data_source='s3')
ds_split = airlines.split_to_train_test(train_size=0.7,
                                        train_name='train',
                                        test_name='test')
# split_to_train_test returns a dict holding the two new Dataset objects
train = ds_split['train_dataset']
test = ds_split['test_dataset']
Complete 100.00% - [4/4] Computing column statistics
Complete
Now let's look at some basic information about the datasets.
[9]:
print('Train Dataset: ', train.shape)
print('Test Dataset: ', test.shape)
ids = [c for c in train.columns]
print(ids)
Train Dataset: (11712, 15)
Test Dataset: (2928, 15)
We only need two columns for this experiment: the text column, which contains the tweet text, and the airline_sentiment column, which contains the sentiment of those tweets (the target column). We can drop the remaining columns for this experiment.
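As a quick local sanity check outside of Driverless AI, this column selection can be sketched with pandas. The rows below are made-up stand-ins for the real tweets; only the column names (text, airline_sentiment, retweet_count) come from the dataset:

```python
import pandas as pd

# Made-up rows standing in for the airline tweets file;
# the column names match the real dataset.
df = pd.DataFrame({
    "text": ["@VirginAmerica great flight!", "@united delayed again"],
    "airline_sentiment": ["positive", "negative"],
    "retweet_count": [0, 3],  # one of the columns the experiment drops
})

# Keep only the feature column and the target column.
df_small = df[["text", "airline_sentiment"]]
print(df_small.columns.tolist())  # ['text', 'airline_sentiment']
```

In the Driverless AI experiment itself, the same effect is achieved by listing the unwanted columns in drop_columns rather than pre-filtering the file.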
We will enable TensorFlow models and transformations to take advantage of CNN-based text features.
[12]:
exp_preview = dai.experiments.preview(train_dataset=train,
target_column='airline_sentiment',
task='classification',
drop_columns=["_unit_id", "_golden", "_unit_state", "_trusted_judgments", "_last_judgment_at",
"airline_sentiment:confidence", "negativereason", "negativereason:confidence", "airline",
"airline_sentiment_gold", "name", "negativereason_gold", "retweet_count",
"tweet_coord", "tweet_created", "tweet_id", "tweet_location", "user_timezone"],
config_overrides="""
enable_tensorflow='on'
enable_tensorflow_charcnn='on'
enable_tensorflow_textcnn='on'
enable_tensorflow_textbigru='on'
""")
ACCURACY [7/10]:
- Training data size: *11,712 rows, 4 cols*
- Feature evolution: *[Constant, DecisionTree, LightGBM, TensorFlow, XGBoostGBM]*, *3-fold CV**, 2 reps*
- Final pipeline: *Ensemble (6 models), 3-fold CV*
TIME [2/10]:
- Feature evolution: *8 individuals*, up to *42 iterations*
- Early stopping: After *5* iterations of no improvement
INTERPRETABILITY [8/10]:
- Feature pre-pruning strategy: Permutation Importance FS
- Monotonicity constraints: enabled
- Feature engineering search space: [CVCatNumEncode, CVTargetEncode, CatOriginal, Cat, Frequent, Interactions, NumCatTE, Original, TextBiGRU, TextCNN, TextCharCNN, Text]
[Constant, DecisionTree, LightGBM, TensorFlow, XGBoostGBM] models to train:
- Model and feature tuning: *192*
- Feature evolution: *288*
- Final pipeline: *6*
Estimated runtime: *minutes*
Auto-click Finish/Abort if not done in: *1 day*/*7 days*
Note that the Text and TextCNN features are enabled for this experiment.
Now we can launch the experiment.
[13]:
model = dai.experiments.create(train_dataset=train,
target_column='airline_sentiment',
task='classification',
name="nlp_airline_sentiment_beta",
scorer='F1',
drop_columns=["tweet_id", "airline_sentiment_confidence", "negativereason", "negativereason_confidence", "airline", "airline_sentiment_gold", "name", "negativereason_gold", "retweet_count", "tweet_coord", "tweet_created", "tweet_location", "user_timezone", "airline_sentiment.negative", "airline_sentiment.neutral", "airline_sentiment.positive"],
accuracy=6,
time=2,
interpretability=5)
Experiment launched at: http://localhost:12345/#experiment?key=b971fe8a-e317-11ea-9088-0242ac110002
Complete 100.00% - Status: Complete
[14]:
print('Modeling completed for model ' + model.key)
Modeling completed for model b971fe8a-e317-11ea-9088-0242ac110002
[15]:
logs = model.log.download(dst_dir='.', overwrite=True)
#logs = dai.datasets.download(model.log_file_path, '.')
print('Logs available at', logs)
Downloaded './h2oai_experiment_logs_b971fe8a-e317-11ea-9088-0242ac110002.zip'
Logs available at ./h2oai_experiment_logs_b971fe8a-e317-11ea-9088-0242ac110002.zip
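The downloaded archive is a regular zip file, so it can be unpacked with Python's standard zipfile module. A minimal, self-contained sketch (the demo builds a tiny stand-in archive rather than using the real log file):

```python
import os
import tempfile
import zipfile

def extract_logs(zip_path, dst_dir):
    """Unpack an experiment log archive and return the extracted member names."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dst_dir)
        return zf.namelist()

# Self-contained demo: create a tiny stand-in archive, then extract it.
with tempfile.TemporaryDirectory() as tmp:
    zip_path = os.path.join(tmp, "experiment_logs.zip")
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("experiment.log", "log contents")
    members = extract_logs(zip_path, tmp)
    print(members)  # ['experiment.log']
```

In practice, you would pass the path returned by model.log.download to extract_logs.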
We can download the predictions to the current folder.
[16]:
test_preds = model.predict(dataset=test, include_columns=ids).download(dst_dir='.', overwrite=True)
print('Test set predictions available at', test_preds)
Complete
Downloaded './b971fe8a-e317-11ea-9088-0242ac110002_preds_9f438fac.csv'
Test set predictions available at ./b971fe8a-e317-11ea-9088-0242ac110002_preds_9f438fac.csv
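Once the predictions CSV is loaded, it can be scored against the held-out labels. The experiment used the F1 scorer, so here is a sketch with scikit-learn's f1_score; the label vectors below are made up, and real ones would come from the test set and the downloaded predictions file:

```python
from sklearn.metrics import f1_score

# Made-up labels standing in for the test set and the downloaded predictions.
actual    = ["negative", "neutral", "positive", "negative"]
predicted = ["negative", "neutral", "negative", "negative"]

# F1 is computed per class; macro-averaging is one common way
# to summarize it for a 3-class target like airline_sentiment.
macro_f1 = f1_score(actual, predicted, average="macro")
print(round(macro_f1, 2))  # 0.6
```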