from __future__ import print_function
import pickle
import paddle
import paddle.fluid as fluid
import paddle.dataset as dataset
from functools import partial
import numpy as np

import sys

# Provides ``Inferencer`` (used by PaddleSentimentAnalysis.initialize).
# Prefer the contrib location; fall back to the older ``paddle.fluid.inferencer``
# module when contrib does not have it. ``sys`` must be imported for the
# stderr warning below -- without it the fallback path raised NameError.
try:
    from paddle.fluid.contrib.inferencer import *
except ImportError:
    print(
        "In the fluid 1.0, the inferencer are moving to paddle.fluid.contrib",
        file=sys.stderr)
    from paddle.fluid.inferencer import *

    
# Number of output classes of the softmax layer; postprocess reports column 0
# as the "positive" probability and column 1 as the "negative" one.
CLASS_DIM = 2
# Width of each word-embedding vector.
EMB_DIM = 128
# Value passed as ``size`` to every fc and dynamic_lstm layer of the network.
HID_DIM = 512
# Total number of (fc, dynamic_lstm) layer pairs built by stacked_lstm_net.
STACKED_NUM = 3


class PaddleSentimentAnalysis(object):
    """Sentiment classifier served through a PaddlePaddle fluid ``Inferencer``.

    The network is a stacked (fc, dynamic_lstm) model over word ids whose
    two-way softmax output is reported by ``postprocess`` as
    positive/negative probabilities.
    """

    def __init__(self):
        self.ctx = None            # fluid place (CUDAPlace or CPUPlace), set in initialize()
        self.model = None          # fluid Inferencer, set in initialize()
        self.labels = None         # word -> id vocabulary loaded from word_dict.pickle
        self.signature = None
        self.param_path = None
        self.initialized = False
        # Batch indices of requests without usable text. Rebuilt by every
        # preprocess() call and consumed by the matching postprocess().
        self.invalid_reqs = set()

    def stacked_lstm_net(self, data, input_dim, class_dim, emb_dim, hid_dim, stacked_num):
        """Build the stacked LSTM classifier on top of the word-id input ``data``.

        Layout: embedding -> ``stacked_num`` (fc, dynamic_lstm) pairs whose LSTM
        direction alternates (forward, reverse, forward, ...) -> max sequence
        pooling of the last fc and last LSTM outputs -> softmax fc.

        Args:
            data: input variable of word ids (declared by inference_program
                as int64 with lod_level=1).
            input_dim: vocabulary size (rows of the embedding table).
            class_dim: number of output classes.
            emb_dim: embedding width.
            hid_dim: ``size`` passed to every fc and dynamic_lstm layer.
            stacked_num: total number of (fc, dynamic_lstm) pairs.

        Returns:
            The softmax prediction variable of width ``class_dim``.
        """
        emb = fluid.layers.embedding(
            input=data, size=[input_dim, emb_dim], is_sparse=True)

        fc1 = fluid.layers.fc(input=emb, size=hid_dim)
        lstm1, _ = fluid.layers.dynamic_lstm(input=fc1, size=hid_dim)

        inputs = [fc1, lstm1]
        for i in range(2, stacked_num + 1):
            # Each new fc consumes both the previous fc and LSTM outputs;
            # even-numbered layers run their LSTM in reverse.
            fc = fluid.layers.fc(input=inputs, size=hid_dim)
            lstm, _ = fluid.layers.dynamic_lstm(
                input=fc, size=hid_dim, is_reverse=(i % 2) == 0)
            inputs = [fc, lstm]

        fc_last = fluid.layers.sequence_pool(input=inputs[0], pool_type='max')
        lstm_last = fluid.layers.sequence_pool(input=inputs[1], pool_type='max')

        prediction = fluid.layers.fc(input=[fc_last, lstm_last],
                                     size=class_dim,
                                     act='softmax')
        return prediction

    def inference_program(self, word_dict):
        """Declare the ``'words'`` input and return the network's prediction.

        ``word_dict`` is only used for its length (the vocabulary size).
        """
        data = fluid.layers.data(
            name="words", shape=[1], dtype="int64", lod_level=1)

        dict_dim = len(word_dict)
        net = self.stacked_lstm_net(data, dict_dim, CLASS_DIM, EMB_DIM, HID_DIM, STACKED_NUM)
        return net

    def initialize(self, context):
        """Load the vocabulary and the trained parameters, then mark the service ready.

        Reads ``model_dir`` and the optional ``gpu_id`` from
        ``context.system_properties``. Runs on CUDA device ``gpu_id`` when it
        is set, otherwise on CPU.
        """
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")
        self.ctx = fluid.CUDAPlace(gpu_id) if gpu_id is not None else fluid.CPUPlace()

        # Vocabulary (word -> id), presumably produced by
        # paddle.dataset.imdb.word_dict() at training time -- confirm.
        # SECURITY: unpickling can execute arbitrary code; model_dir must only
        # ever contain trusted artifacts.
        with open(model_dir + "/word_dict.pickle", "rb") as word_dict_file:
            self.labels = pickle.load(word_dict_file, encoding='utf-8')

        self.model = Inferencer(
            infer_func=partial(self.inference_program, self.labels),
            param_path=model_dir + "/paddle_artifacts",
            place=self.ctx)
        self.initialized = True

    def preprocess(self, data):
        """Turn a batch of raw requests into a LoD tensor of word ids.

        Each request's text comes from its ``"body"`` entry, falling back to
        ``"data"`` when ``"body"`` is missing/None; bytes are decoded as UTF-8.
        Requests with no usable text are recorded in ``self.invalid_reqs``
        (reset for every batch) and replaced by a placeholder, so the tensor
        still holds exactly one sequence per request.
        """
        lod = self._encode(self._parse_requests(data))
        base_shape = [[len(seq) for seq in lod]]
        return fluid.create_lod_tensor(lod, base_shape, self.ctx)

    def _parse_requests(self, data):
        """Split each request into tokens and rebuild ``self.invalid_reqs``."""
        # Reset per batch: a stale index from an earlier batch would otherwise
        # mask a valid request at the same position later on.
        self.invalid_reqs = set()
        req_list = []
        for idx, req in enumerate(data):
            raw = req.get("body")
            if raw is None:
                raw = req.get("data")
            if isinstance(raw, (bytes, bytearray)):
                raw = raw.decode('utf-8')
            words = raw.split() if raw else []
            if not words:
                # Missing, empty or whitespace-only text would give an empty
                # sequence; flag it and feed a placeholder instead.
                self.invalid_reqs.add(idx)
                words = "invalid request".split()
            req_list.append(words)
        return req_list

    def _encode(self, req_list):
        """Map tokens to vocabulary ids, using the ``'<unk>'`` id for unknown words."""
        unk = self.labels['<unk>']
        return [[self.labels.get(word, unk) for word in words] for words in req_list]

    def inference(self, input):
        """Run the Inferencer on the ``'words'`` feed built by preprocess()."""
        return self.model.infer({'words': input})

    def postprocess(self, output, data):
        """Format one response string per request of the batch.

        ``output[0]`` is iterated row by row; column 0 is reported as the
        positive probability and column 1 as the negative one. Rows whose
        index is in ``self.invalid_reqs`` get an error message instead.
        ``data`` is unused.
        """
        ret = []
        for i, r in enumerate(output[0]):
            if i in self.invalid_reqs:
                ret.append("Invalid request received")
            else:
                # str() (not format()) on purpose: it keeps numpy scalars'
                # short repr in the message.
                ret.append("Predict probability of " + str(r[0]) + " to be positive and " +
                           str(r[1]) + " to be negative for review\n")
        return ret


# Module-level service instance shared by every handle() call; handle()
# initializes it lazily on its first call.
_service = PaddleSentimentAnalysis()


def handle(data, context):
    """Request-handler entry point (presumably invoked by the serving framework per batch).

    Initializes the shared service on the first call, then runs
    preprocess -> inference -> postprocess.

    Returns:
        None when ``data`` is None, otherwise the list built by postprocess.
    """
    service = _service
    if not service.initialized:
        service.initialize(context)

    if data is None:
        return None

    model_input = service.preprocess(data)
    model_output = service.inference(model_input)
    return service.postprocess(model_output, data)
