import React, { useState } from 'react';
import {
  Typography,
  Card,
  CardContent,
  Grid,
  Button,
  Box,
  Chip,
  LinearProgress,
  TextField,
  Collapse,
  Switch,
  FormControlLabel,
  Select,
  MenuItem,
  InputLabel,
  FormControl,
  IconButton,
  Tooltip,
} from '@mui/material';
import {
  ExpandMore as ExpandMoreIcon,
} from '@mui/icons-material';
    
/**
 * Initial state of the training form.
 *
 * `basic` holds the always-visible fields; `advanced` holds the fields shown
 * under "Advanced Settings". Allowed values for the enum-like string fields:
 *   basic.dataPrecision          'fp16' | 'fp32' | 'bf16'
 *   advanced.optimizer           'adamw_torch' | 'adamw_hf' | 'adam' | 'sgd'
 *   advanced.saveStrategy        'epoch' | 'steps' | 'no'
 *   advanced.precisionStrategy   'no' | 'fp16' | 'bf16'
 *   advanced.lrScheduler         'linear' | 'cosine' | 'cosine_with_restarts' | 'polynomial'
 *
 * NOTE(review): `hfToken` exists in both sections, but the form in this file
 * only edits `advanced.hfToken`; `advanced.precisionStrategy` is never edited
 * here at all. Confirm which copies the `onStartTraining` consumer reads.
 */
const defaultTrainingConfig = {
  basic: {
    outputModelName: '',
    batchSizePerDevice: 16,
    epochs: 3,
    hfToken: '',
    dataPrecision: 'fp16',
  },
  advanced: {
    learningRate: 0.00002,
    pushToHub: false,
    hfToken: '',
    gradientCheckpoint: false,
    optimizer: 'adamw_torch',
    loggingSteps: 500,
    saveStrategy: 'epoch',
    precisionStrategy: 'no',
    maxGradientNorm: 1,
    warmupRatio: 0.1,
    lrScheduler: 'linear',
  },
};

/**
 * Choices for the "Optimizer" select: config `value`, display `label`, and a
 * short `description` (the description is not currently rendered in this file).
 */
const optimizerOptions = [
  ['adamw_torch', 'AdamW (PyTorch)', 'Default optimizer with good general performance'],
  ['adamw_hf', 'AdamW (Hugging Face)', 'Hugging Face implementation of AdamW'],
  ['adam', 'Adam', 'Traditional Adam optimizer'],
  ['sgd', 'SGD', 'Stochastic Gradient Descent'],
].map(([value, label, description]) => ({ value, label, description }));

/**
 * Choices for the "Learning Rate Scheduler" select: config `value`, display
 * `label`, and a short `description` (not currently rendered in this file).
 */
const schedulerOptions = [
  ['linear', 'Linear', 'Linear learning rate decay'],
  ['cosine', 'Cosine', 'Cosine learning rate decay'],
  ['cosine_with_restarts', 'Cosine with Restarts', 'Cosine decay with periodic restarts'],
  ['polynomial', 'Polynomial', 'Polynomial learning rate decay'],
].map(([value, label, description]) => ({ value, label, description }));

/**
 * Choices for the "Data Precision" select: config `value`, display `label`,
 * and a short `description` (not currently rendered in this file).
 */
const precisionOptions = [
  ['fp16', 'FP16', '16-bit floating point precision'],
  ['fp32', 'FP32', '32-bit floating point precision'],
  ['bf16', 'BF16', 'Brain floating point format'],
].map(([value, label, description]) => ({ value, label, description }));

/**
 * Choices for the "Save Strategy" select: config `value`, display `label`,
 * and a short `description` (not currently rendered in this file).
 */
const saveStrategyOptions = [
  ['epoch', 'Epoch', 'Save model at the end of each epoch'],
  ['steps', 'Steps', 'Save model at specified steps'],
  ['no', 'No', 'Do not save the model'],
].map(([value, label, description]) => ({ value, label, description }));

const TrainingFineTuning = ({ 
  selectedModel, 
  selectedDataset,
  activeJobs,
  completedJobs,
  onStartTraining,
}) => {
  const [showAdvanced, setShowAdvanced] = useState(false);
  const [config, setConfig] = useState(defaultTrainingConfig);

  const handleBasicConfigChange = (field) => (event) => {
    setConfig(prev => ({
      ...prev,
      basic: {
        ...prev.basic,
        [field]: event.target.value
      }
    }));
  };

  const handleAdvancedConfigChange = (field) => (event) => {
    setConfig(prev => ({
      ...prev,
      advanced: {
        ...prev.advanced,
        [field]: event.target.value
      }
    }));
  };

  const handleSwitchChange = (section, field) => (event) => {
    setConfig(prev => ({
      ...prev,
      [section]: {
        ...prev[section],
        [field]: event.target.checked
      }
    }));
  };

  const isFormValid = config.basic.outputModelName.trim() !== '';

  return (
    <Grid container spacing={3}>
      {/* Training Configuration */}
      <Grid item xs={12}>
        <Card>
          <CardContent>
            <Typography variant="h6" gutterBottom>Training Configuration</Typography>
            
            {/* Basic Configuration */}
            <Grid container spacing={2}>
              <Grid item xs={12}>
                <TextField
                  fullWidth
                  required
                  label="Output Model Name"
                  value={config.basic.outputModelName}
                  onChange={handleBasicConfigChange('outputModelName')}
                  helperText="Name for the fine-tuned model"
                />
              </Grid>
              <Grid item xs={12} md={6}>
                <TextField
                  fullWidth
                  label="Batch Size per Device"
                  type="number"
                  value={config.basic.batchSizePerDevice}
                  onChange={handleBasicConfigChange('batchSizePerDevice')}
                  InputProps={{
                    inputProps: { min: 1 }
                  }}
                />
              </Grid>
              <Grid item xs={12} md={6}>
                <TextField
                  fullWidth
                  label="Number of Epochs"
                  type="number"
                  value={config.basic.epochs}
                  onChange={handleBasicConfigChange('epochs')}
                  InputProps={{
                    inputProps: { min: 1 }
                  }}
                />
              </Grid>
              <Grid item xs={12} md={6}>
                <FormControl fullWidth>
                  <InputLabel>Data Precision</InputLabel>
                  <Select
                    value={config.basic.dataPrecision}
                    onChange={handleBasicConfigChange('dataPrecision')}
                    label="Data Precision"
                  >
                    {precisionOptions.map(option => (
                      <MenuItem key={option.value} value={option.value}>
                        {option.label}
                      </MenuItem>
                    ))}
                  </Select>
                </FormControl>
              </Grid>
            </Grid>

            {/* Advanced Configuration */}
            <Box sx={{ mt: 3 }}>
              <Button
                onClick={() => setShowAdvanced(!showAdvanced)}
                endIcon={<ExpandMoreIcon />}
              >
                Advanced Settings
              </Button>
              <Collapse in={showAdvanced}>
                <Grid container spacing={2} sx={{ mt: 1 }}>
                  <Grid item xs={12}>
                    <FormControlLabel
                      control={
                        <Switch
                          checked={config.advanced.pushToHub}
                          onChange={handleSwitchChange('advanced', 'pushToHub')}
                        />
                      }
                      label="Push to Hugging Face Hub"
                    />
                  </Grid>
                  {config.advanced.pushToHub && (
                    <Grid item xs={12}>
                      <TextField
                        fullWidth
                        label="Hugging Face Token"
                        type="password"
                        value={config.advanced.hfToken}
                        onChange={handleAdvancedConfigChange('hfToken')}
                      />
                    </Grid>
                  )}
                  <Grid item xs={12}>
                    <FormControlLabel
                      control={
                        <Switch
                          checked={config.advanced.gradientCheckpoint}
                          onChange={handleSwitchChange('advanced', 'gradientCheckpoint')}
                        />
                      }
                      label="Gradient Checkpointing"
                    />
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <TextField
                      fullWidth
                      label="Learning Rate"
                      type="number"
                      value={config.advanced.learningRate}
                      onChange={handleAdvancedConfigChange('learningRate')}
                      InputProps={{
                        inputProps: { min: 0, step: "1e-6" }
                      }}
                    />  
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <FormControl fullWidth>
                      <InputLabel>Optimizer</InputLabel>
                      <Select
                        value={config.advanced.optimizer}
                        onChange={handleAdvancedConfigChange('optimizer')}
                        label="Optimizer"
                      >
                        {optimizerOptions.map(option => (
                          <MenuItem key={option.value} value={option.value}>
                            {option.label}
                          </MenuItem>
                        ))}
                      </Select>
                    </FormControl>
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <FormControl fullWidth>
                      <InputLabel>Learning Rate Scheduler</InputLabel>
                      <Select
                        value={config.advanced.lrScheduler}
                        onChange={handleAdvancedConfigChange('lrScheduler')}
                        label="Learning Rate Scheduler"
                      >
                        {schedulerOptions.map(option => (
                          <MenuItem key={option.value} value={option.value}>
                            {option.label}
                          </MenuItem>
                        ))}
                      </Select>
                    </FormControl>
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <TextField
                      fullWidth
                      label="Logging Steps"
                      type="number"
                      value={config.advanced.loggingSteps}
                      onChange={handleAdvancedConfigChange('loggingSteps')}
                      InputProps={{
                        inputProps: { min: 1 }
                      }}
                    />
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <FormControl fullWidth>
                      <InputLabel>Save Strategy</InputLabel>
                      <Select
                        value={config.advanced.saveStrategy}
                        onChange={handleAdvancedConfigChange('saveStrategy')}
                        label="Save Strategy"
                      >
                        {saveStrategyOptions.map(option => (
                          <MenuItem key={option.value} value={option.value}>
                            {option.label}
                          </MenuItem>
                        ))}
                      </Select>
                    </FormControl>
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <TextField
                      fullWidth
                      label="Max Gradient Norm"
                      type="number"
                      value={config.advanced.maxGradientNorm}
                      onChange={handleAdvancedConfigChange('maxGradientNorm')}
                      InputProps={{
                        inputProps: { min: 0, step: 0.1 }
                      }}
                    />
                  </Grid>
                  <Grid item xs={12} md={6}>
                    <TextField
                      fullWidth
                      label="Warmup Ratio"
                      type="number"
                      value={config.advanced.warmupRatio}
                      onChange={handleAdvancedConfigChange('warmupRatio')}
                      InputProps={{
                        inputProps: { min: 0, max: 1, step: 0.01 }
                      }}
                    />
                  </Grid>
                </Grid>
              </Collapse>
            </Box>

            <Button 
              variant="contained" 
              onClick={() => onStartTraining(config)}
              sx={{ mt: 3 }}
              fullWidth
              disabled={!isFormValid}
            >
              {isFormValid ? 'Choose Training Device(s)' : 'Please Enter Output Model Name'}
            </Button>
          </CardContent>
        </Card>
      </Grid>

      {/* Active Jobs */}
      <Grid item xs={12}>
        <Card>
          <CardContent>
            <Typography variant="h6" gutterBottom>Active Training Jobs</Typography>
            {activeJobs.map(job => (
              <Box key={job.id} sx={{ mb: 3 }}>
                <Box sx={{ display: 'flex', justifyContent: 'space-between', mb: 1 }}>
                  <Typography variant="subtitle1">{job.name}</Typography>
                  <Chip label={job.status} color="primary" size="small" />
                </Box>
                <Typography variant="body2" color="text.secondary">
                  Dataset: {job.dataset}
                </Typography>
                <Box sx={{ mt: 1 }}>
                  <Typography variant="body2">
                    Progress: {job.progress}% (ETA: {job.eta})
                  </Typography>
                  <LinearProgress 
                    variant="determinate" 
                    value={job.progress} 
                    sx={{ mt: 1 }} 
                  />
                </Box>
                <Box sx={{ mt: 2 }}>
                  <Typography variant="body2">Current Metrics:</Typography>
                  <Grid container spacing={2}>
                    <Grid item xs={6}>
                      <Typography variant="body2">
                        Loss: {job.metrics.epochs[job.metrics.currentEpoch - 1].loss.toFixed(4)}
                      </Typography>
                    </Grid>
                    <Grid item xs={6}>
                      <Typography variant="body2">
                        Accuracy: {job.metrics.epochs[job.metrics.currentEpoch - 1].accuracy.toFixed(4)}
                      </Typography>
                    </Grid>
                  </Grid>
                </Box>
              </Box>
            ))}
          </CardContent>
        </Card>
      </Grid>

      {/* Completed Jobs */}
      <Grid item xs={12}>
        <Card>
          <CardContent>
            <Typography variant="h6" gutterBottom>Completed Jobs</Typography>
            {completedJobs.map(job => (
              <Box key={job.id} sx={{ mb: 3 }}>
                <Box sx={{ display: 'flex', justifyContent: 'space-between', mb: 1 }}>
                  <Typography variant="subtitle1">{job.name}</Typography>
                  <Typography variant="body2" color="text.secondary">
                    Completed {job.completedAt}
                  </Typography>
                </Box>
                <Typography variant="body2" color="text.secondary">
                  Dataset: {job.dataset}
                </Typography>
                <Box sx={{ mt: 2 }}>
                  <Typography variant="body2">Final Metrics:</Typography>
                  <Grid container spacing={2}>
                    <Grid item xs={4}>
                      <Typography variant="body2">
                        Accuracy: {job.metrics.finalMetrics.accuracy.toFixed(4)}
                      </Typography>
                    </Grid>
                    <Grid item xs={4}>
                      <Typography variant="body2">
                        Loss: {job.metrics.finalMetrics.loss.toFixed(4)}
                      </Typography>
                    </Grid>
                    <Grid item xs={4}>
                      <Typography variant="body2">
                        F1 Score: {job.metrics.finalMetrics.f1Score.toFixed(4)}
                      </Typography>
                    </Grid>
                  </Grid>
                </Box>
              </Box>
            ))}
          </CardContent>
        </Card>
      </Grid>
    </Grid>
  );
};

export default TrainingFineTuning;
