我一直认为 AI 不能完全替代人编程,但是AI 会是人类有史以来最好的编程助教,下面这个案例就是实证。在此之前,其实我没有写过特别长的程序,顶多不超过 20 行吧,Python 书籍通常只看完第一章节就没有然后了,但是今天我写出了 50+ 行的 Python代码,成功解决了问题。

提出需求:提取A股半导体上市公司股票数据

最近想研究下A股半导体上市公司都有哪些,市值、毛利率、营收等情况是怎样的。简单拉了一下半导体股票列表,去重后有将近170家,人工统计肯定费时费力(刚毕业那会儿真这么干过)。

凡是重复的流程,必定有办法提升效率。

于是打算写一个 Python 脚本,主要功能是提取A股半导体上市公司的一些数据,如收入、净利润、总市值、收入、收入占比、毛利率。

在这之前,好几次我都是直接提需求给 ChatGPT,让它帮我写脚本,这次我想尝试自己写脚本。

选数据源,探索API数据结构:雪球API

首先是要选好 API 数据源,上一次美股上市公司数据提取,用的是雅虎的API-yfinance,这次经过搜索,用的是一个 GitHub 上的雪球 API

选好API之后,需要看看里面有哪些字段,怎么取出自己想要的一些数据,细看数据结构,上面的收入,主营业务等数据,都分布在不同的 json 文件的字段里。我基于之前对于Python程序的列表(list ),字典(ditc)这些数据类型的基本理解,尝试手动提取数据,并以寒武纪公司作为例子。

1
2
3
4
5
6
7
8
9
10
# 返回当日收盘价计算的总市值,或者实时价格计算的市值
ball.quotec('SH688256')['data'][0]['market_capital']
# 2023年收入
ball.income(symbol='SH688256',is_annals=1,count=5)['data']['list'][0]['total_revenue'][0]
# gross_selling_rate就是毛利率
ball.indicator(symbol='SH688256',is_annals=1,count=5)['data']['list'][0]['gross_selling_rate'][0]

# 提取2022年的,参数换成['data']['list'][1]
ball.indicator(symbol='SH688256',is_annals=1,count=5)['data']['list'][1]['gross_selling_rate'][0]

通过上面代码可以看出,提取数据的方法非常简单粗暴,好在之前折腾过,对json数据结构有一定的理解,简单说就是列表和dict交叉嵌套。

写提取函数

今年几次AI编程的经验,我知道模块化思维很重要,就是通过函数封装一个个细分的功能,这样就像金字塔原理:

  • 总论点:对应python程序main函数
  • 分论点: 对应Python中用于处理某一块特定功能的函数模块

两三个函数对应金字塔的两三个要点,在main()函数调用这两三个函数就可以了。

提取成功后,我写了三个函数来封装对应的提取过程,然后尝试写入 csv 文件。

写循环,批量提取数据

单个股票数据提取完成后,接下来要尝试提取170家半导体公司的数据。我参考了之前 ChatGPT 写的提取美股公司数据的代码,其中有 process_csv 函数,直接拿来改了一下。

过程中还是有很多细节要注意,特别要注意的是:

csv 文件,一定要有至少两列表头:stock_code, mkt_cap,除了 stock_code 对应列,其他列数据可以留空。

以下是所有的源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# 探索完毕,开始封装成函数
# 通过pysnowball获取
# API参考:https://github.com/uname-yang/pysnowball
import pysnowball as ball
import csv
ball.set_token("xq_a_token=c6d3a522****;u=71173****")

def main():

input_csv = 'china-semi-stock-list.csv' # 这里 csv 文件最好有至少两列,stock_code,mkt_cap,不能只有 stock_code 一列,不然下方process_csv()无法识别
company_data = process_csv(input_csv)

with open('china_semi_stocks_info.csv', mode='w', newline='', encoding='utf-8') as file:
fieldnames = ['stock_code','mkt_cap','income','gross_profit_ratio','main_busi_name','main_busi_revenue','main_busi_percentage','main_busi_ratio']
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer.writeheader()
for data in company_data:
writer.writerow(data)

def get_stock_data(stock_code):
stock_basic = {}
mkt_cap_figure = ball.quotec(stock_code)['data'][0]['market_capital']
income_figure = ball.income(symbol= stock_code,is_annals=1,count=5)['data']['list'][0]['total_revenue'][0]
gross_profit_figure = ball.indicator(symbol= stock_code,is_annals=1,count=5)['data']['list'][0]['gross_selling_rate'][0]
stock_basic['stock_code'] = stock_code
stock_basic['mkt_cap'] = mkt_cap_figure
stock_basic['income'] = income_figure
stock_basic['gross_profit_ratio'] = gross_profit_figure
return stock_basic

def get_main_busi(stock_code):
stock_main_busi = {}
busi_seg = ball.business(symbol=stock_code,count=5)['data']['list'][1]['class_list'][0]
main_busi_name_content = busi_seg['business_list'][0]['project_announced_name']
main_busi_revenue_figure = busi_seg['business_list'][0]['prime_operating_income']
main_busi_percentage_figure = busi_seg['business_list'][0]['income_ratio']
main_busi_ratio_figure = busi_seg['business_list'][0]['gross_profit_rate']
stock_main_busi['main_busi_name'] = main_busi_name_content
stock_main_busi['main_busi_revenue'] = main_busi_revenue_figure
stock_main_busi['main_busi_percentage'] = main_busi_percentage_figure
stock_main_busi['main_busi_ratio'] = main_busi_ratio_figure
return stock_main_busi

# Read stock codes from CSV and process them in batches of 200
# Fetch company data with a delay between each request
def process_csv(input_csv):
with open(input_csv, newline='', encoding='utf-8') as csvfile:
reader = csv.DictReader(csvfile)
stock_codes = [row['stock_code'] for row in reader]

batch_size = 10
results = []
for i in range(0, len(stock_codes), batch_size):
batch = stock_codes[i:i + batch_size]
for stock_code in batch:
try:
stock_basic = get_stock_data(stock_code)
stock_main_busi = get_main_busi(stock_code)
stock_data = {**stock_basic, **stock_main_busi} # merge two dict types data into one dict
results.append(stock_data)
print(f"Retrieved data for {stock_data['stock_code']}")
time.sleep(1) # Add a 1-second delay between requests
except Exception as e:
print(f"Failed to retrieve data for {stock_data['stock_code']}: {e}")

return results

main()

总结以及提升

这次编程实践说明了一是干中学很重要,比一直学 tutorial 要快。第二,对异常的处理也很重要,能提升程序的鲁棒性。第三,持续不断的想着优化自己的程序,譬如我用 AI 搜索了一下,对于复杂的 json 文件,除了像我用粗暴的方式提取数据,其实还有很多更好的方法来提取字段对应数据。

譬如 jmespath,以后碰到同样的 json 数据,就能举一反三了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 方法二,用 jmespath抓取
# jmespath tutorial:https://jmespath.org/tutorial.html
import json
import jmespath
import pysnowball as ball
ball.set_token("xq_a_token=c6d3a522****;u=71173****")
json_data = ball.income(symbol='SH688256',is_annals=1,count=5)

# 比纯 json 导出的易读性更强,可以通过关键词查找,输入对应一年的数据。
report_2022 = jmespath.search("data.list[?report_name == '2022年报']", json_data)
revenue = report_2022[0]['total_revenue'][0]

# tutorial中支持list和dict嵌套查找
revenue_alternative = jmespath.search("data.list[0].report_name", json_data)
net_profits_alternative = jmespath.search("data.list[0].total_revenue[0]", json_data)

print(revenue_alternative)
print(net_profits_alternative)