import pandas as pd
import tushare as ts
import baostock as bs
import smtplib
from email.mime.text import MIMEText
import time
import calendar
from datetime import datetime
import titan_data_settings as settings
from urllib import parse
import logging
from sqlalchemy import create_engine
from titan_table_structure import Base
from dateutil.relativedelta import relativedelta
from sqlalchemy import text
import re

'''
author: jscat 2021/02/02
'''

'''
初始化函数
:param token:tushare pro的token
'''

class Data:
    """Fetch A-share reference data (baostock / tushare) and write it into the
    database configured by ``settings.DB_CONN_URL``.

    Every successful table write also appends a row to ``tbl_update_log``.
    A failed write is reported through logging, ``tbl_error_log`` and e-mail,
    and is not re-raised.
    """

    def __init__(self, token):
        """
        :param token: tushare pro API token (used by get_all_stockdata)
        """
        self.token = token

    # ------------------------------------------------------------------ #
    # Database writes
    # ------------------------------------------------------------------ #
    def delete_append(self, data, table_name, info):
        """Delete every row of ``table_name``, then append ``data`` via to_sql.

        The delete, the insert and the tbl_update_log row share one
        transaction, so on failure the table keeps its previous content.

        :param data: DataFrame whose columns match the target table
        :param table_name: target table; must be a trusted identifier
        :param info: free-text description stored in the log tables
        """
        self._write(data, table_name, info, 'delete_append', clear_first=True)

    def keep_append(self, data, table_name, info):
        """Keep the existing rows of ``table_name`` and append ``data`` via to_sql.

        Transaction and failure-reporting behaviour are the same as in
        delete_append.
        """
        self._write(data, table_name, info, 'keep_append', clear_first=False)

    def _write(self, data, table_name, info, op_name, clear_first):
        """Shared implementation of delete_append / keep_append."""
        # ``encoding=`` is not passed: utf-8 was already the SQLAlchemy 1.x
        # default and the argument no longer exists in SQLAlchemy 2.0.
        engine = create_engine(settings.DB_CONN_URL, echo=False)
        try:
            # engine.begin() commits on success and rolls back on any exception.
            with engine.begin() as conn:
                if clear_first:
                    # Raw SQL must be wrapped in text() on SQLAlchemy 1.4+/2.0.
                    # Identifiers cannot be bound parameters, so table_name is
                    # concatenated; callers must pass a trusted, fixed name.
                    conn.execute(text('DELETE FROM ' + table_name))
                data.to_sql(table_name, conn, if_exists='append', index=False)
                self.update_log(table_name, conn, info)
        except Exception as ee:
            self._report_failure(op_name, table_name, engine, info, ee)
        finally:
            # Release the connection pool.
            engine.dispose()

    def _report_failure(self, op_name, table_name, engine, info, ee):
        """Report a failed write via logging, tbl_error_log and e-mail.

        Each channel is tried independently, so one broken channel (e.g.
        SMTP) does not prevent the others from recording the failure.
        """
        message = op_name + ' failed: ' + info
        # Attach the exception as exc_info; as a bare %-argument it had no
        # placeholder and broke the logging call itself.
        logging.error('%s', message, exc_info=ee)
        try:
            self.error_log(table_name, engine, info, ee.args)
        except Exception:
            logging.exception('writing tbl_error_log failed for %s', table_name)
        try:
            # MIMEText needs a str body; ee.args is a tuple.
            self.sendMail(message, str(ee.args))
        except Exception:
            logging.exception('sending failure mail failed for %s', table_name)

    def update_log(self, table_name, conn, info):
        """Append one (TARGET_TABLE, UPDATE_INFO, CREATE_DT) row to tbl_update_log.

        :param conn: anything DataFrame.to_sql accepts as ``con``
        """
        row = {'TARGET_TABLE': table_name, 'UPDATE_INFO': info, 'CREATE_DT': str(datetime.now())}
        pd.DataFrame(row, index=[0]).to_sql("tbl_update_log", conn, if_exists='append', index=False)

    def error_log(self, table_name, conn, info, log):
        """Append one row to tbl_error_log; ``log`` is stored as ``str(log)``.

        :param conn: anything DataFrame.to_sql accepts as ``con``
        """
        row = {'TARGET_TABLE': table_name, 'UPDATE_INFO': info, 'ERROR_INFO': str(log),
               'CREATE_DT': str(datetime.now())}
        pd.DataFrame(row, index=[0]).to_sql("tbl_error_log", conn, if_exists='append', index=False)

    # ------------------------------------------------------------------ #
    # baostock sources
    # ------------------------------------------------------------------ #
    @staticmethod
    def _result_to_frame(rs):
        """Drain a baostock result set into a DataFrame with ``rs.fields`` as columns.

        No rows are read when ``rs.error_code`` is not '0'.
        """
        rows = []
        while rs.error_code == '0' and rs.next():
            rows.append(rs.get_row_data())
        return pd.DataFrame(rows, columns=rs.fields)

    def get_AShareCalendar(self, start_date, end_date, type):
        """Sync the trading days between start_date and end_date into tbl_AShareCalendar.

        Target table (Wind-compatible layout):
            OBJECT_ID          varchar(100) NULL      object id (written as "")
            TRADE_DAYS         varchar(8)   NOT NULL  trading day, e.g. 20210201
            S_INFO_EXCHMARKET  varchar(40)  NOT NULL  SSE: Shanghai | SZSE: Shenzhen
            SOURCE_TYPE        varchar(10)  NULL      BS: baostock | WD: wind

        :param start_date: first day, passed unchanged to bs.query_trade_dates
            (baostock expects 'YYYY-MM-DD')
        :param end_date: last day, passed unchanged to bs.query_trade_dates
        :param type: 0 -> replace the whole table (delete_append);
            any other value -> append (keep_append)
        """
        lg = bs.login()
        print('login respond error_code:' + lg.error_code)
        print('login respond  error_msg:' + lg.error_msg)
        if lg.error_code == '0':
            print('bs连接成功')
        try:
            rs = bs.query_trade_dates(start_date=start_date, end_date=end_date)
            print('query_trade_dates respond error_code:' + rs.error_code)
            print('query_trade_dates respond error_msg:' + rs.error_msg)
            df_all = self._result_to_frame(rs)
        finally:
            # Always release the baostock session, even if the query fails.
            bs.logout()

        # Keep trading days only.
        df_all = df_all[df_all['is_trading_day'] == '1'].drop(columns='is_trading_day')
        # Wind stores dates as YYYYMMDD, so strip the dashes before saving.
        df_all['calendar_date'] = df_all['calendar_date'].str.replace('-', '', regex=False)
        df_all = df_all.rename(columns={'calendar_date': 'TRADE_DAYS'})
        df_all['OBJECT_ID'] = ""
        df_all['SOURCE_TYPE'] = "BS"
        # A single calendar is fetched; it is stored under SSE.
        df_all['S_INFO_EXCHMARKET'] = "SSE"

        info = "update record: " + str(start_date) + "_" + str(end_date)
        if type == 0:
            self.delete_append(df_all, 'tbl_AShareCalendar', info)
        else:
            self.keep_append(df_all, 'tbl_AShareCalendar', info)

    def get_AShareDescription(self, date, type):
        """Sync basic security info from bs.query_stock_basic into tbl_AShareDescription.

        Target table (Wind-compatible layout); only S_INFO_CODE, S_INFO_NAME
        and S_INFO_EXCHMARKET are filled here:
            OBJECT_ID            varchar(100) NULL      object id
            S_INFO_CODE          varchar(40)  NOT NULL  trading code
            S_INFO_NAME          varchar(50)  NULL      short name
            S_INFO_COMPNAME      varchar(100) NULL      company name (Chinese)
            S_INFO_COMPNAMEENG   varchar(100) NULL      company name (English)
            S_INFO_ISINCODE      varchar(40)  NULL      ISIN code
            S_INFO_EXCHMARKET    varchar(40)  NULL      SSE: Shanghai | SZSE: Shenzhen
            S_INFO_LISTBOARD     varchar(10)  NULL      434004000: main board;
                                                        434003000: SME board;
                                                        434001000: ChiNext
            S_INFO_LISTBOARDNAME varchar(10)  NULL      board name

        :param date: only used in the log text ("update record: <date>")
        :param type: 0 -> replace the whole table (delete_append);
            any other value -> append (keep_append)
        """
        lg = bs.login()
        print('login respond error_code:' + lg.error_code)
        print('login respond  error_msg:' + lg.error_msg)
        try:
            rs = bs.query_stock_basic()
            print('query_stock_basic respond error_code:' + rs.error_code)
            print('query_stock_basic respond  error_msg:' + rs.error_msg)
            df = self._result_to_frame(rs)
        finally:
            # Always release the baostock session, even if the query fails.
            bs.logout()

        print(len(df))
        # Keep rows with type == '1' (stock) and status == '1' (active).
        # copy() so the column assignments below act on an independent frame.
        df = df[(df['type'] == '1') & (df['status'] == '1')].copy()
        print(len(df))
        # 'code' is split as '<exchange prefix>.<trading code>'.
        code_parts = df['code'].str.split('.', n=1)
        df['S_INFO_CODE'] = code_parts.str[1]
        df['S_INFO_EXCHMARKET'] = code_parts.str[0].replace({'sh': 'SSE', 'sz': 'SZSE'})
        df['S_INFO_NAME'] = df['code_name']
        df = df.drop(columns=['code', 'code_name', 'ipoDate', 'outDate', 'type', 'status'])

        info = "update record: " + str(date)
        if type == 0:
            self.delete_append(df, 'tbl_AShareDescription', info)
        else:
            self.keep_append(df, 'tbl_AShareDescription', info)

    # ------------------------------------------------------------------ #
    # tushare source
    # ------------------------------------------------------------------ #
    # Deprecated
    def get_all_stockdata(self, start_date, end_date):
        """Deprecated: write one tushare snapshot table ``<trade_date>_ts`` per SSE open day.

        For each open day between start_date and end_date, the listed-stock
        basics, daily quotes and daily indicators are outer-joined on ts_code.
        The columns are then renamed to Chinese labels, and the result
        replaces table ``<trade_date>_ts``.
        """
        engine = create_engine(settings.DB_CONN_URL)
        print('数据库连接成功')
        try:
            ts.set_token(self.token)
            pro = ts.pro_api()
            trade_d = pro.trade_cal(exchange='SSE', is_open='1', start_date=start_date,
                                    end_date=end_date, fields='cal_date')
            # Listed-stock basics do not depend on the trade date: fetch once.
            df_basic = pro.stock_basic(exchange='', list_status='L')
            for date in trade_d['cal_date'].values:
                # Daily quotes: amount in thousand CNY, volume in lots (手).
                df_daily = pro.daily(trade_date=date)
                # Daily indicators: shares in 10k shares, market values in 10k CNY.
                df_daily_basic = pro.daily_basic(ts_code='', trade_date=date,
                                                 fields='ts_code, turnover_rate, turnover_rate_f,'
                                                        ' volume_ratio, pe, pe_ttm, pb, ps, ps_ttm,'
                                                        ' dv_ratio, dv_ttm, total_share, float_share,'
                                                        ' free_share, total_mv, circ_mv ')
                # Outer joins keep stocks that are missing from either side.
                df_all = (df_basic.merge(df_daily, on='ts_code', how='outer')
                                  .merge(df_daily_basic, on='ts_code', how='outer'))
                # 'symbol' duplicates ts_code.
                df_all = df_all.drop('symbol', axis=1)
                # Missing labels (e.g. stocks absent from stock_basic) get a placeholder.
                label_cols = ['name', 'area', 'industry', 'market']
                df_all[label_cols] = df_all[label_cols].fillna('问题股')

                df_all['ts_code'] = df_all['ts_code'].astype(str)
                # NOTE(review): tushare's stock_basic appears to call this column
                # 'list_date'; 'listart_date' may raise KeyError -- confirm.
                df_all['listart_date'] = pd.to_datetime(df_all['listart_date'])
                df_all['trade_date'] = pd.to_datetime(df_all['trade_date'])

                # Rename to readable Chinese labels; loss-making stocks have NaN ratios.
                df_all = df_all.rename(columns={'ts_code': '股票代码', 'name': '股票名称', 'area': '所在地域', 'industry': '行业'
                                                , 'market': '市场类型', 'listart_date': '上市日期', 'trade_date': '交易日期', 'change': '涨跌额'
                                                , 'pct_chg': '涨跌幅', 'vol': '成交量（手）', 'amount': '成交额（千元）', 'turnover_rate': '换手率（%）'
                                                , 'turnover_rate_f': '流通换手率', 'volume_ratio': '量比', 'pe': '市盈率', 'pe_ttm': '滚动市盈率'
                                                , 'pb': '市净率', 'ps': '市销率', 'ps_ttm': '滚动市销率', 'dv_ratio': '股息率'
                                                , 'dv_ttm': '滚动股息率', 'total_share': '总股本（万股）', 'float_share': '流通股本 （万股）'
                                                , 'free_share': '自由流通股本（万股）', 'total_mv': '总市值 （万元）', 'circ_mv': '流通市值（万元）'})

                table = '{}_ts'.format(date)
                # Drop the previous snapshot for this day, if any.
                with engine.begin() as conn:
                    conn.execute(text('drop table if exists {};'.format(table)))
                print('%s is downloading....' % (str(date)))
                df_all.to_sql(table, engine, index=False)
                print('{}成功导入数据库'.format(date))
        finally:
            engine.dispose()

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def get_current_month_start_and_end(self, date):
        """Return the first and last day of the month containing ``date``.

        :param date: anything whose str() starts with 'YYYY-MM-DD' (e.g.
            '2017-09-08', a datetime or a pandas Timestamp)
        :return: ('YYYY-MM-01', 'YYYY-MM-<last day>'), zero-padded
        :raises ValueError: if str(date) does not contain exactly two '-'
            or year/month are not valid numbers
        """
        date_text = str(date)
        if date_text.count('-') != 2:
            raise ValueError('date must look like YYYY-MM-DD, got %r' % date_text)
        year_part, month_part = date_text.split('-')[:2]
        year, month = int(year_part), int(month_part)
        last_day = calendar.monthrange(year, month)[1]
        return '%04d-%02d-01' % (year, month), '%04d-%02d-%02d' % (year, month, last_day)

    @staticmethod
    def _build_message(subject, info, sender, recipient):
        """Build a plain-text mail; ``info`` is converted with str()."""
        msg = MIMEText(str(info))
        msg['Subject'] = subject
        msg['From'] = sender
        msg['To'] = recipient
        return msg

    def sendMail(self, subject, info):
        """Send ``info`` as a plain-text mail via SMTP over SSL (port 465).

        :param subject: mail subject
        :param info: mail body; non-str values are converted with str()
        """
        msg = self._build_message(subject, info, settings.MAIL_FROM, settings.MAIL_TO)
        # The context manager closes the connection even if login/send fails.
        with smtplib.SMTP_SSL(settings.MAIL_HOST, 465) as smtp_server:
            smtp_server.login(settings.MAIL_FROM, settings.MAIL_PASS)
            smtp_server.sendmail(settings.MAIL_FROM, settings.MAIL_TO, msg.as_string())
        print("sendMail pass")


if __name__ == '__test__':
    # Manual driver (enabled by editing the guard): sync the trading calendar
    # month by month, starting at ``start``, appending to the table (type=1).
    data = Data(settings.TS_TOKEN)
    start = "2020-02-01"
    months = 1  # number of consecutive months to sync
    for i in range(months):
        date = pd.to_datetime(start) + relativedelta(months=+i)  # i months after start
        date_str = date.strftime("%Y-%m-%d")
        start_date, end_date = data.get_current_month_start_and_end(date_str)
        print(start_date, end_date)
        data.get_AShareCalendar(start_date=start_date, end_date=end_date, type=1)
        # Pause between months; presumably to stay within data-source rate limits.
        time.sleep(10)
    print("finish")

if __name__ == '__test1__':
    # Manual driver: rebuild tbl_AShareDescription from scratch (type=0).
    data = Data(settings.TS_TOKEN)
    date = datetime.now()
    data.get_AShareDescription(date=date, type=0)
    print("finish")

if __name__ == '__main__':
    # Smoke test of the mail settings: sends one notification mail.
    data = Data(settings.TS_TOKEN)
    data.sendMail("Titan-Data-Sync error", "update tbl_AShareDescription failure")
    print("finish")