package mainimport ("encoding/json""flag""fmt""log""net/http""time""config""framework/logger""global""models/function""models/schema""github.com/go-redis/redis""github.com/gorilla/websocket""github.com/labstack/echo")var clients = make(map[*websocket.Conn]bool)var broadcast = make(chan Message)var upgrader = websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} //不使用默认设置,如果线上环境可能需要使用默认配置var chananel = make(chan schema.Listening) //数据chanvar configFile *string = flag.String("config", "./bin/etc/conf.yaml", "agency config file")//这是数据库的配置文件解析,单写的时候提出来var agentSlice []map[string]*websocket.Conn //socket对应关系存储//发送消息结构体type Message struct {Message interface{} `json:"message"`SiteId string `json:"site_id"`SiteIndexId string `json:"site_index_id"`Count int64 `json:"count"`}//测试用[正式修改之后可以删除]func hu(w http.ResponseWriter, r *http.Request) {siteid := r.FormValue("site_id")siteIndexId := r.FormValue("site_index_id")fmt.Println(siteIndexId, siteid)s := schema.Listening{"zym", "b", 1}chananel <- s}func main() {//数据库初始化cfg, err := config.ParseConfigFile(*configFile)if err != nil {log.Fatalf("parse config file error:%v\n", err.Error())return}//初始化数据库err = global.InitMysql(cfg.Mysqls)if err != nil {//数据库连接错误global.GlobalLogger.Error("InitDb error:%v\n", err.Error())return}http.HandleFunc("/o", hu)http.HandleFunc("/ws", handleConnections)go handleMessages()err = http.ListenAndServe(cfg.Wesocketport, nil)if err != nil {log.Fatal(err.Error())}}func handleConnections(w http.ResponseWriter, r *http.Request) {//如果限制连接就可以使用ip+port限制,根据ip区分客户端,其他的可以根据r.Request提交的数据查找相应的内容siteId := r.FormValue("site_id")siteIndexId := r.FormValue("site_index_id")//这里是用来唯一区分客户端的判断条件if siteId == "" || siteIndexId == "" {http.Error(w, "site_id and site_index_id must not empty", 403)}//注册成为websocketws, err := upgrader.Upgrade(w, r, nil)if err != nil {global.GlobalLogger.Error("error:%s", err.Error())return}defer ws.Close()//存储连接[todo 这里可能还要考虑map并发读写问题]agent := make(map[string]*websocket.Conn)agent[s] = wsagentSlice = append(agentSlice, agent)clients[ws] = true//监听接收一个[models/schema]schema.Listening,for {var msg Messages := <-chananelif s.Types == 1 {//todo 这里解析取出来的数据可能还需要加工//获取最新的没有确认得公司入款newincome := new(function.MemberCompanyIncomeBean)info, count, err := newincome.GetNotConfirm(s.SiteId, s.SiteIndexId)if err != nil {global.GlobalLogger.Error("error:%s", err.Error())return}msg = Message{SiteIndexId: s.SiteIndexId, SiteId: s.SiteId, Message: info, Count: count}} else if s.Types == 2 {//获取最新的线上入款onLineBean := new(function.OnlineEntryRecordBean)info, count, err := onLineBean.GetNotConfirm(s.SiteId, s.SiteIndexId)if err != nil {global.GlobalLogger.Error("error:%s", err.Error())return}msg = Message{SiteIndexId: s.SiteIndexId, SiteId: s.SiteId, Message: info, Count: count}} else {//获取没有确认得最新的出款管理makeMoney := new(function.MakeMoneyBean)info, count, err := makeMoney.GetOperateRecord(s.SiteId, s.SiteIndexId)if err != nil {global.GlobalLogger.Error("error:%s", err.Error())return}msg = Message{SiteIndexId: s.SiteIndexId, SiteId: s.SiteId, Count: count, Message: info}}broadcast <- msg}}//单点推送func handleMessages() {for {msg := <-broadcastvar pushClient []*websocket.ConnnewS := fmt.Sprintf("%s%s", msg.SiteId, msg.SiteIndexId)lenAgent := len(agentSlice)for i := 0; i < lenAgent; i++ {for k, v := range agentSlice[i] {if newS == k {pushClient = append(pushClient, v)}}}for i := 0; i < len(pushClient); i++ {for client := range clients {if pushClient[i] == client {err := client.WriteJSON(msg)if err != nil {global.GlobalLogger.Error("error:%s", err.Error())client.Close()delete(clients, client)}}}}}}